Для сети LSTM я заметил большие улучшения с ведением.
Я наткнулся на раздел сегментирования в документации TensorFlow, который (tf .contrib).
Хотя в своей сети я использую tf.data.Dataset
API, в частности, я работаю с TFRecords, поэтому мой входной конвейер выглядит примерно так
dataset = tf.data.TFRecordDataset(TFRECORDS_PATH)
dataset = dataset.map(_parse_function)
dataset = dataset.map(_scale_function)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.padded_batch(batch_size, padded_shapes={.....})
Как я могу включить метод сегментирования в конвейер tf.data.Dataset
?
Если это важно, в каждой записи в файле TFRecords у меня есть длина последовательности, сохраненная как целое число.