Есть ли способ использовать tf.data.Dataset внутри другого набора данных в Tensorflow?

Я делаю сегментацию. В каждой обучающей выборке есть несколько изображений с масками сегментации. Я пытаюсь написать input_fn, чтобы объединить все изображения масок в одно для каждой обучающей выборки. Я планировал использовать два Datasets, один из которых выполняет итерацию по папкам с образцами, а другой считывает все маски как один большой пакет, а затем объединяет их в один тензор.

Я получаю сообщение об ошибке при вызове вложенного make_one_shot_iterator. Я знаю, что этот подход немного натянут и, скорее всего, наборы данных не предназначены для такого использования. Но как мне тогда подойти к этой проблеме, чтобы избежать использования tf.py_func?

Вот упрощенная версия набора данных:

def read_sample(sample_path):
    masks_ds = (tf.data.Dataset.
        list_files(sample_path+"/masks/*.png")
        .map(tf.read_file)
        .map(lambda x: tf.image.decode_image(x, channels=1))
        .batch(1024)) # maximum number of objects
    masks = masks_ds.make_one_shot_iterator().get_next()

    return tf.reduce_max(masks, axis=0)

ds = tf.data.Dataset.from_tensor_slices(tf.glob("../input/stage1_train/*"))
ds.map(read_sample)
# ...
sample = ds.make_one_shot_iterator().get_next()
# ...

person Piotr Czapla    schedule 27.02.2018    source источник


Ответы (1)


Если вложенный набор данных содержит только один элемент, вы можете использовать _1 _ во вложенном наборе данных вместо создания итератора:

def read_sample(sample_path):
    masks_ds = (tf.data.Dataset.list_files(sample_path+"/masks/*.png")
                .map(tf.read_file)
                .map(lambda x: tf.image.decode_image(x, channels=1))
                .batch(1024)) # maximum number of objects
    masks = tf.contrib.data.get_single_element(masks_ds)
    return tf.reduce_max(masks, axis=0)

ds = tf.data.Dataset.from_tensor_slices(tf.glob("../input/stage1_train/*"))
ds = ds.map(read_sample)
sample = ds.make_one_shot_iterator().get_next()

Кроме того, вы можете использовать tf.data.Dataset.flat_map(), _ 4_ или _ 5_ преобразованиеw для выполнения вложенных Dataset вычислений внутри функции и сглаживания результата в один Dataset. Например, чтобы получить все образцы в одном Dataset:

def read_all_samples(sample_path):
    return (tf.data.Dataset.list_files(sample_path+"/masks/*.png")
            .map(tf.read_file)
            .map(lambda x: tf.image.decode_image(x, channels=1))
            .batch(1024)) # maximum number of objects

ds = tf.data.Dataset.from_tensor_slices(tf.glob("../input/stage1_train/*"))
ds = ds.flat_map(read_all_samples)
sample = ds.make_one_shot_iterator().get_next()
person mrry    schedule 28.02.2018