Как заменить ввод сохраненного графика, например. заполнитель с помощью итератора набора данных?

У меня есть сохраненный график Tensorflow, который использует ввод через placeholder с параметром feed_dict.

sess.run(my_tensor, feed_dict={input_image: image})

Поскольку подача данных с помощью Dataset Iterator более эффективна, я хочу загрузить сохраненный график, заменить input_image placeholder с Iterator и бегом. Как я могу это сделать? Есть ли лучший способ сделать это? Ответ с примером кода будет высоко оценен.


person Sam    schedule 16.05.2018    source источник


Ответы (1)


Вы можете добиться этого, сериализовав свой график и повторно импортировав его, используя tf.import_graph_def, у которого есть аргумент input_map, используемый для вставки входных данных в нужных местах.

Для этого вам нужно как минимум знать имена входов, которые вы заменяете, и выходов, которые вы хотите выполнить (соответственно x и y в моих примерах).

import tensorflow as tf

# restore graph (built from scratch here for the example)
x = tf.placeholder(tf.int64, shape=(), name='x')
y = tf.square(x, name='y')

# just for display -- you don't need to create a Session for serialization
with tf.Session() as sess:
  print("with placeholder:")
  for i in range(10):
    print(sess.run(y, {x: i}))

# serialize the graph
graph_def = tf.get_default_graph().as_graph_def()

tf.reset_default_graph()

# build new pipeline
batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
# plug in new pipeline
[y] = tf.import_graph_def(graph_def, input_map={'x:0': batch}, return_elements=['y:0'])

# enjoy Dataset inputs!
with tf.Session() as sess:
  print('with Dataset:')
  try:
    while True:
      print(sess.run(y))
  except tf.errors.OutOfRangeError:
    pass        

Обратите внимание, что узел-заполнитель все еще там, поскольку я не удосужился разобрать здесь graph_def, чтобы удалить его — вы можете удалить его в качестве улучшения, хотя я думаю, что оставить его здесь тоже можно.

В зависимости от того, как вы восстанавливаете график, замена ввода может быть уже встроена в загрузчик, что упрощает задачу (нет необходимости возвращаться к GraphDef). Например, если вы загружаете график из файла .meta, вы можете использовать tf.train.import_meta_graph, который принимает тот же аргумент input_map.

import tensorflow as tf

# build new pipeline
batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
# load your net and plug in new pipeline
# you need to know the name of the tensor where to plug-in your input
restorer = tf.train.import_meta_graph(graph_filepath, input_map={'x:0': batch})
y = tf.get_default_graph().get_tensor_by_name('y:0')

# enjoy Dataset inputs!
with tf.Session() as sess:
  # not needed here, but in practice you would also need to restore weights
  # restorer.restore(sess, weights_filepath)
  print('with Dataset:')
  try:
    while True:
      print(sess.run(y))
  except tf.errors.OutOfRangeError:
    pass        
person P-Gn    schedule 16.05.2018
comment
Можете ли вы дать представление о том, как полностью удалить заполнитель? - person lamo_738; 23.07.2019