API набора данных Tensorflow с ListDirectory

Я хочу использовать Tensorflow Dataset API для создания одного пакета для каждой папки (каждая папка, содержащая изображения). У меня есть следующий простой фрагмент кода:

import tensorflow as tf
import os
import pdb

def parse_file(filename):
    image_string = tf.read_file(filename)
    image_decoded = tf.image.decode_png(image_string)
    image_resized = tf.image.resize_images(image_decoded, [48, 48])
    return image_resized #, label

def parse_dir(frame_dir):
    filenames = tf.gfile.ListDirectory(frame_dir)
    batch = tf.constant(5)
    batch = tf.map_fn(parse_file, filenames)
    return batch

directory = "../Detections/NAC20171125"
# filenames = tf.constant([os.path.join(directory, f) for f in os.listdir(directory)])
frames = [os.path.join(directory, str(f)) for f in range(10)]


dataset = tf.data.Dataset.from_tensor_slices((frames))
dataset = dataset.map(parse_dir)

dataset = dataset.batch(256)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()


with tf.Session() as sess:
    sess.run(iterator.initializer)
    while True:
        try:
            batch = sess.run(next_element)
            print(batch.shape)
        except tf.errors.OutOfRangeError:
            break

Однако tf.gfile.ListDirectory (в parse_dir) ожидает обычную строку вместо Tensor. Итак, теперь ошибка

TypeError: Expected binary or unicode string, got <tf.Tensor 'arg0:0' shape=() dtype=string>

Есть простой способ решить эту проблему?


person Derk    schedule 19.12.2017    source источник


Ответы (1)


Проблема здесь в том, что tf.gfile.ListDirectory() - это функция Python, которая ожидает Python строка, а аргумент frame_dir для parse_dir() - это tf.Tensor. Поэтому вам потребуется эквивалентная операция TensorFlow для вывода списка файлов в каталоге и tf.data.Dataset.list_files() (на основе tf.matching_files()), вероятно, является ближайшим эквивалентом.

directory = "../Detections/NAC20171125"
frames = [os.path.join(directory, str(f)) for f in range(10)]

# Start with a dataset of directory names.
dataset = tf.data.Dataset.from_tensor_slices(frames)

# Maps each subdirectory to the list of files in that subdirectory and flattens
# the result.
dataset = dataset.flat_map(lambda dir: tf.data.Dataset.list_files(dir + "/*"))

# Maps each filename to the parsed and resized image data.
dataset = dataset.map(parse_file)

dataset = dataset.batch(256)

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
person mrry    schedule 19.12.2017