Я делаю сегментацию. В каждой обучающей выборке есть несколько изображений с масками сегментации. Я пытаюсь написать input_fn
, чтобы объединить все изображения масок в одно для каждой обучающей выборки. Я планировал использовать два Datasets
, один из которых выполняет итерацию по папкам с образцами, а другой считывает все маски как один большой пакет, а затем объединяет их в один тензор.
Я получаю сообщение об ошибке при вызове вложенного make_one_shot_iterator
. Я знаю, что этот подход немного натянут и, скорее всего, наборы данных не предназначены для такого использования. Но как мне тогда подойти к этой проблеме, чтобы избежать использования tf.py_func?
Вот упрощенная версия набора данных:
def read_sample(sample_path):
masks_ds = (tf.data.Dataset.
list_files(sample_path+"/masks/*.png")
.map(tf.read_file)
.map(lambda x: tf.image.decode_image(x, channels=1))
.batch(1024)) # maximum number of objects
masks = masks_ds.make_one_shot_iterator().get_next()
return tf.reduce_max(masks, axis=0)
ds = tf.data.Dataset.from_tensor_slices(tf.glob("../input/stage1_train/*"))
ds.map(read_sample)
# ...
sample = ds.make_one_shot_iterator().get_next()
# ...