tensorflow cifar10 возобновляет обучение из файла контрольной точки

При использовании Tensorflow я пытаюсь возобновить обучение CIFAR10, используя файл с контрольной точкой. Ссылаясь на некоторые другие статьи, я безуспешно пробовал tf.train.Saver().restore. Может ли кто-нибудь пролить свет на то, как действовать?

Фрагмент кода из Tensorflow CIFAR10

def train():
  # methods to build graph from the cifar10_train.py
  global_step = tf.Variable(0, trainable=False)
  images, labels = cifar10.distorted_inputs()
  logits = cifar10.inference(images)
  loss = cifar10.loss(logits, labels)
  train_op = cifar10.train(loss, global_step)
  saver = tf.train.Saver(tf.all_variables())
  summary_op = tf.merge_all_summaries()

  init = tf.initialize_all_variables() 
  sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement))
  sess.run(init)


  print("FLAGS.checkpoint_dir is %s" % FLAGS.checkpoint_dir)

  if FLAGS.checkpoint_dir is None:
    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
  else:
    # restoring from the checkpoint file
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)

  # cur_step prints out well with the checkpointed variable value
  cur_step = sess.run(global_step);
  print("current step is %s" % cur_step)

  for step in xrange(cur_step, FLAGS.max_steps):
    start_time = time.time()
    # **It stucks at this call **
    _, loss_value = sess.run([train_op, loss])
    # below same as original

person emerson    schedule 31.05.2016    source источник


Ответы (1)


Проблема, похоже, в том, что эта строка:

tf.train.start_queue_runners(sess=sess)

...выполняется, только если FLAGS.checkpoint_dir is None. Вам все равно нужно будет запускать обработчиков очередей, если вы восстанавливаетесь с контрольной точки.

Обратите внимание, что я бы рекомендовал запускать обработчики очередей после создания tf.train.Saver (из-за состояния гонки в выпущенной версии кода), поэтому лучшей структурой будет следующая:

if FLAGS.checkpoint_dir is not None:
  # restoring from the checkpoint file
  ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
  tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)

# Start the queue runners.
tf.train.start_queue_runners(sess=sess)

# ...

for step in xrange(cur_step, FLAGS.max_steps):
  start_time = time.time()
  _, loss_value = sess.run([train_op, loss])
  # ...
person mrry    schedule 31.05.2016
comment
Спасибо за твой ответ! Это решило проблему. Я думал, что queue_runner отвечает за создание входного изображения (путем искажения), и это не обязательный шаг, поскольку я восстанавливаю из файла контрольной точки. - person emerson; 02.06.2016