Первый случай: мы хотим, чтобы результат имел фиксированный размер партии
В этом случае генератор генерирует значения формы [None, 48, 48, 3]
, где первое измерение может быть любым. Мы хотим сгруппировать это так, чтобы на выходе было [batch_size, 48, 48, 3]
. Если мы используем напрямую tf.data.Dataset.batch
, у нас будет ошибка, поэтому сначала нам нужно разблокировать.
Для этого мы можем использовать tf.contrib.data.unbatch
, например, перед пакетной обработкой:
dataset = dataset.apply(tf.contrib.data.unbatch())
dataset = dataset.batch(batch_size)
Вот полный пример, в котором генератор выдает [1]
, [2, 2]
, [3, 3, 3]
и [4, 4, 4, 4]
.
Мы не можем напрямую пакетировать эти выходные значения, поэтому мы разупаковываем, а затем группируем их:
def gen():
for i in range(1, 5):
yield [i] * i
# Create dataset from generator
# The output shape is variable: (None,)
dataset = tf.data.Dataset.from_generator(gen, tf.int64, tf.TensorShape([None]))
# The issue here is that we want to batch the data
dataset = dataset.apply(tf.contrib.data.unbatch())
dataset = dataset.batch(2)
# Create iterator from dataset
iterator = dataset.make_one_shot_iterator()
x = iterator.get_next() # shape (None,)
sess = tf.Session()
for i in range(5):
print(sess.run(x))
Это напечатает следующий вывод:
[1 2]
[2 3]
[3 3]
[4 4]
[4 4]
Второй случай: мы хотим объединить партии переменного размера
Обновление (30.03.2018): я удалил предыдущий ответ, в котором использовалось сегментирование, которое значительно снижает производительность (см. комментарии).
В этом случае мы хотим объединить фиксированное количество пакетов. Проблема в том, что эти партии имеют переменные размеры. Например, набор данных дает [1]
и [2, 2]
, и мы хотим получить [1, 2, 2]
в качестве вывода.
Вот быстрый способ решить эту проблему - создать новый генератор, обернутый вокруг исходного. Новый генератор будет выдавать пакетные данные. (Спасибо Гийому за идею)
Вот полный пример, где генератор дает [1]
, [2, 2]
, [3, 3, 3]
и [4, 4, 4, 4]
.
def gen():
for i in range(1, 5):
yield [i] * i
def get_batch_gen(gen, batch_size=2):
def batch_gen():
buff = []
for i, x in enumerate(gen()):
if i % batch_size == 0 and buff:
yield np.concatenate(buff, axis=0)
buff = []
buff += [x]
if buff:
yield np.concatenate(buff, axis=0)
return batch_gen
# Create dataset from generator
batch_size = 2
dataset = tf.data.Dataset.from_generator(get_batch_gen(gen, batch_size),
tf.int64, tf.TensorShape([None]))
# Create iterator from dataset
iterator = dataset.make_one_shot_iterator()
x = iterator.get_next() # shape (None,)
with tf.Session() as sess:
for i in range(2):
print(sess.run(x))
Это напечатает следующий вывод:
[1 2 2]
[3 3 3 4 4 4 4]
person
Olivier Moindrot
schedule
28.03.2018