Как сохранить и восстановить модель tf.estimator.Estimator с помощью export_savedmodel?

Недавно я начал использовать Tensorflow и пытаюсь использовать объекты tf.estimator.Estimator. Я хотел бы сделать что-то априори вполне естественное: после обучения моего классификатора, т.е. экземпляра tf.estimator.Estimator (методом train), я хотел бы сохранить его в файл (независимо от расширения), а затем перезагрузить это позже, чтобы предсказать метки для некоторых новых данных. Поскольку официальная документация рекомендует использовать Estimator API, я думаю, что что-то столь же важное должно быть реализовано и задокументировано.

На какой-то другой странице я видел, что это можно сделать export_savedmodel (см. ">официальная документация), но я просто не понимаю документацию. Нет объяснения, как использовать этот метод. Что такое аргумент serving_input_fn? Я никогда не сталкивался с этим в учебнике Создание пользовательских оценщиков или в любом из руководств, которые я читал. Погуглив, я обнаружил, что около года назад оценщики были определены с использованием другого класса (tf.contrib.learn.Estimator), и похоже, что tf.estimator.Estimator повторно использует некоторые из предыдущих API. Но я не нахожу четких объяснений в документации по этому поводу.

Может ли кто-нибудь привести мне пример игрушки? Или объясните мне, как определить/найти этот serving_input_fn?

И как тогда снова загрузить обученный классификатор?

Спасибо за помощь!

Изменить: я обнаружил, что не обязательно использовать export_savemodel для сохранения модели. На самом деле это делается автоматически. Затем, если позже мы определим новый оценщик с тем же аргументом model_dir, он также автоматически восстановит предыдущий оценщик, как объяснено здесь.


person GWa    schedule 13.07.2018    source источник


Ответы (1)


Как вы поняли, оценщик автоматически сохраняет и восстанавливает модель во время обучения. export_savemodel может быть полезен, если вы хотите развернуть свою модель в поле (например, предоставить лучшую модель для обслуживания Tensorflow).

Вот простой пример:

est.export_savedmodel(export_dir_base=FLAGS.export_dir, serving_input_receiver_fn=serving_input_fn)

def serving_input_fn(): inputs = {'features': tf.placeholder(tf.float32, [None, 128, 128, 3])} return tf.estimator.export.ServingInputReceiver(inputs, inputs)

По сути, serve_input_fn отвечает за замену конвейеров набора данных заполнителем. В развертывании вы можете передавать данные в этот заполнитель в качестве входных данных для вашей модели для вывода или прогнозирования.

person Omid Sakhi    schedule 13.09.2018