Передача обучения с помощью tf.estimator.Estimator framework

Я пытаюсь передать обучение модели Inception-resnet v2, предварительно обученной в imagenet, используя мой собственный набор данных и классы. Моя первоначальная кодовая база была модификацией образца tf.slim, который я больше не могу найти, и теперь я пытаюсь переписать тот же код, используя структуру tf.estimator.*.

Однако я сталкиваюсь с проблемой загрузки только некоторых весов из предварительно обученной контрольной точки, инициализируя оставшиеся слои их инициализаторами по умолчанию.

Исследуя проблему, я нашел эту проблему GitHub и этот вопрос, оба упоминают о необходимости использовать tf.train.init_from_checkpoint в моем model_fn. Я пытался, но, учитывая отсутствие примеров в обоих, думаю, что-то не так.

Это мой минимальный пример:

import sys
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
import numpy as np

import inception_resnet_v2

NUM_CLASSES = 900
IMAGE_SIZE = 299

def input_fn(mode, num_classes, batch_size=1):
  # some code that loads images, reshapes them to 299x299x3 and batches them
  return tf.constant(np.zeros([batch_size, 299, 299, 3], np.float32)), tf.one_hot(tf.constant(np.zeros([batch_size], np.int32)), NUM_CLASSES)


def model_fn(images, labels, num_classes, mode):
  with tf.contrib.slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):
    logits, end_points = inception_resnet_v2.inception_resnet_v2(images,
                                             num_classes, 
                                             is_training=(mode==tf.estimator.ModeKeys.TRAIN))
  predictions = {
      'classes': tf.argmax(input=logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
  variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
  scopes = { os.path.dirname(v.name) for v in variables_to_restore }
  tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt',
                                {s+'/':s+'/' for s in scopes})
  
  tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
  total_loss = tf.losses.get_total_loss()    #obtain the regularization losses as well
  
  # Configure the training op
  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_or_create_global_step()
    optimizer = tf.train.AdamOptimizer(learning_rate=0.00002)
    train_op = optimizer.minimize(total_loss, global_step)
  else:
    train_op = None
  
  return tf.estimator.EstimatorSpec(
    mode=mode,
    predictions=predictions,
    loss=total_loss,
    train_op=train_op)

def main(unused_argv):
  # Create the Estimator
  classifier = tf.estimator.Estimator(
      model_fn=lambda features, labels, mode: model_fn(features, labels, NUM_CLASSES, mode),
      model_dir='model/MCVE')

  # Train the model  
  classifier.train(
      input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN, NUM_CLASSES, batch_size=1),
      steps=1000)
    
  # Evaluate the model and print results
  eval_results = classifier.evaluate(
      input_fn=lambda: input_fn(tf.estimator.ModeKeys.EVAL, NUM_CLASSES, batch_size=1))
  print()
  print('Evaluation results:\n    %s' % eval_results)
 
if __name__ == '__main__':
  tf.app.run(main=main, argv=[sys.argv[0]])

где inception_resnet_v2 — это реализация модели в репозитории моделей Tensorflow.

Если я запускаю этот скрипт, я получаю кучу информации из журнала init_from_checkpoint, но затем во время создания сеанса кажется, что он пытается загрузить веса Logits из контрольной точки и терпит неудачу из-за несовместимых форм. Это полная трассировка:

Traceback (most recent call last):

  File "<ipython-input-6-06fadd69ae8f>", line 1, in <module>
    runfile('C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py', wdir='C:/Users/1/Desktop/transfer_learning_tutorial-master')

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 710, in runfile
    execfile(filename, namespace)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 101, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)

  File "C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py", line 77, in <module>
    tf.app.run(main=main, argv=[sys.argv[0]])

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\platform\app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))

  File "C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py", line 68, in main
    steps=1000)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 302, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 780, in _train_model
    log_step_count_steps=self._config.log_step_count_steps) as mon_sess:

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 368, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 673, in __init__
    stop_grace_period_secs=stop_grace_period_secs)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 493, in __init__
    self._sess = _RecoverableSession(self._coordinated_creator)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 851, in __init__
    _WrappedSession.__init__(self, self._create_session())

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 856, in _create_session
    return self._sess_creator.create_session()

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 554, in create_session
    self.tf_sess = self._session_creator.create_session()

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 428, in create_session
    init_fn=self._scaffold.init_fn)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\session_manager.py", line 279, in prepare_session
    sess.run(init_op, feed_dict=init_feed_dict)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 889, in run
    run_metadata_ptr)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1120, in _run
    feed_dict_tensor, options, run_metadata)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1317, in _do_run
    options, run_metadata)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1336, in _do_call
    raise type(e)(node_def, op, message)

InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [900] rhs shape= [1001]    [[Node: Assign_1145 = Assign[T=DT_FLOAT,
_class=["loc:@InceptionResnetV2/Logits/Logits/biases"], use_locking=true, validate_shape=true,
_device="/job:localhost/replica:0/task:0/device:CPU:0"](InceptionResnetV2/Logits/Logits/biases, checkpoint_initializer_1145)]]

Что я делаю неправильно при использовании init_from_checkpoint? Как именно мы должны использовать его в нашем model_fn? И почему оценщик пытается загрузить веса Logits' из контрольной точки, когда я прямо говорю ему не делать этого?

Обновлять:

После предложения в комментариях я попробовал альтернативные способы вызова tf.train.init_from_checkpoint.

Использование {v.name: v.name}

Если, как предлагается в комментарии, я заменяю вызов на {v.name:v.name for v in variables_to_restore}, я получаю эту ошибку:

ValueError: Assignment map with scope only name InceptionResnetV2/Conv2d_2a_3x3 should map
to scope only InceptionResnetV2/Conv2d_2a_3x3/weights:0. Should be 'scope/': 'other_scope/'.

Использование {v.name: v}

Если вместо этого я попытаюсь использовать сопоставление name:variable, я получу следующую ошибку:

ValueError: Tensor InceptionResnetV2/Conv2d_2a_3x3/weights:0 is not found in
inception_resnet_v2_2016_08_30.ckpt checkpoint
{'InceptionResnetV2/Repeat_2/block8_4/Branch_1/Conv2d_0c_3x1/BatchNorm/moving_mean': [256], 
'InceptionResnetV2/Repeat/block35_9/Branch_0/Conv2d_1x1/BatchNorm/beta': [32], ...

Ошибка продолжает перечислять то, что я думаю, все имена переменных в контрольной точке (или это могут быть области видимости?).

Обновление (2)

Изучив последнюю ошибку здесь выше, я вижу, что InceptionResnetV2/Conv2d_2a_3x3/weights является в списке переменных с контрольными точками. Проблема в том, что :0 в конце! Сейчас я проверю, действительно ли это решает проблему, и опубликую ответ, если это так.


person GPhilo    schedule 18.12.2017    source источник
comment
Есть ли контрольные точки в каталоге оценщика model/MCVE?   -  person kww    schedule 19.12.2017
comment
Нет, каталог пуст   -  person GPhilo    schedule 19.12.2017
comment
Возможно, строка scopes = { os.path.dirname(v.name) for v in variables_to_restore } добавляет InceptionResnetV2 в список областей действия, поэтому все переменные под InceptionResnetV2/ будут загружены. Вместо того, чтобы создавать список областей, вы можете попробовать перечислить переменные напрямую: tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt', {v.name:v.name for v in variables})   -  person kww    schedule 19.12.2017
comment
Это возможно, да. Однако, если я попытаюсь использовать предложенный вами код, я получу эту ошибку: ValueError: Assignment map with scope only name InceptionResnetV2/Conv2d_2a_3x3 should map to scope only InceptionResnetV2/Conv2d_2a_3x3/weights:0. Should be 'scope/': 'other_scope/'.. Кажется, имена переменных должны использоваться по-другому   -  person GPhilo    schedule 19.12.2017
comment
Если вы делаете полный переход с slim, рассмотрите возможность использования tf.contrib.framework.get_variables_to_restore. Это похоже, но просто вопрос бухгалтерии (раздражает).   -  person Varun    schedule 09.06.2018


Ответы (2)


Благодаря комментарию @KathyWu я встал на правильный путь и нашел проблему.

Действительно, способ, которым я вычислял scopes, включал область InceptionResnetV2/, которая инициировала загрузку всех переменных "под" областью действия (т. е. всех переменных в сети). Однако заменить его правильным словарем было непросто.

Из возможных режимов области действия, которые init_from_checkpoint accepts, мне пришлось использовать: 'scope_variable_name': variable, но без использования фактического атрибута variable.name.

variable.name выглядит так: 'some_scope/variable_name:0'. Этот :0 отсутствует в имени контрольной точки, поэтому использование scopes = {v.name:v.name for v in variables_to_restore} вызовет ошибку "Переменная не найдена".

Хитрость, чтобы заставить его работать, заключалась в том, чтобы удалить тензорный индекс из имени:

tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt', 
                              {v.name.split(':')[0]: v for v in variables_to_restore})
person GPhilo    schedule 19.12.2017

Я обнаружил, что {s+'/':s+'/' for s in scopes} не работает, только потому, что variables_to_restore включает что-то вроде "global_step", поэтому области охвата включают глобальные области, которые могут включать все. Вам нужно распечатать variables_to_restore, найти "global_step" вещь и положить ее в "exclude".

person hai lee    schedule 12.06.2018