TensorFlow tf.data.Dataset и ведение

Для сети LSTM я заметил большие улучшения с ведением.

Я наткнулся на раздел сегментирования в документации TensorFlow, который (tf .contrib).

Хотя в своей сети я использую tf.data.Dataset API, в частности, я работаю с TFRecords, поэтому мой входной конвейер выглядит примерно так

dataset = tf.data.TFRecordDataset(TFRECORDS_PATH)
dataset = dataset.map(_parse_function)
dataset = dataset.map(_scale_function)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.padded_batch(batch_size, padded_shapes={.....})

Как я могу включить метод сегментирования в конвейер tf.data.Dataset?

Если это важно, в каждой записи в файле TFRecords у меня есть длина последовательности, сохраненная как целое число.


person bluesummers    schedule 30.05.2018    source источник


Ответы (1)


Различные bucketing варианты использования с Dataset API хорошо объяснены здесь.

bucket_by_sequence_length() пример:

def elements_gen():
   text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2]]
   label = [1, 2, 1, 2]
   for x, y in zip(text, label):
       yield (x, y)

def element_length_fn(x, y):
   return tf.shape(x)[0]

dataset = tf.data.Dataset.from_generator(generator=elements_gen,
                                     output_shapes=([None],[]),
                                     output_types=(tf.int32, tf.int32))

dataset =   dataset.apply(tf.contrib.data.bucket_by_sequence_length(element_length_func=element_length_fn,
                                                              bucket_batch_sizes=[2, 2, 2],
                                                              bucket_boundaries=[0, 8]))

batch = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:

   for _ in range(2):
      print('Get_next:')
      print(sess.run(batch))

Вывод:

Get_next:
(array([[1, 2, 3, 0, 0],
   [3, 4, 5, 6, 7]], dtype=int32), array([1, 2], dtype=int32))
Get_next:
(array([[1, 2, 0, 0],
   [8, 9, 0, 2]], dtype=int32), array([1, 2], dtype=int32))
person vijay m    schedule 30.05.2018
comment
В моем случае использования на самом деле есть много функций, и одна из них представляет собой последовательность, скажем, ее x['seq'] в каждой записи, как мне применить ее только к этому элементу? - person bluesummers; 31.05.2018
comment
вам нужно изменить вашу elements_gen() функцию на yield(x['seq'], y) - person vijay m; 31.05.2018
comment
У меня нет element_gen(), так как я читаю из файла TFRecords, не могу ли я изменить element_length_func, чтобы он возвращал tf.shape(x['seq'])[0]? А почему вы позвонили del y? - person bluesummers; 31.05.2018
comment
Ссылка ведет на удаленный сайт. - person Pius Friesch; 10.10.2018