Как подсчитать общее количество обучаемых параметров в модели тензорного потока, определенной с помощью графа, загруженного из файла .pb?

Я хочу подсчитать параметры в модели тензорного потока. Это похоже на существующий вопрос следующим образом.

Как подсчитать общее количество обучаемых параметров в модели тензорного потока?

Но если модель определяется с помощью графа, загруженного из файла .pb, все предложенные ответы не работают. В основном я загрузил график следующей функцией.

def load_graph(model_file):

  graph = tf.Graph()
  graph_def = tf.GraphDef()

  with open(model_file, "rb") as f:
    graph_def.ParseFromString(f.read())

  with graph.as_default():
    tf.import_graph_def(graph_def)

  return graph

Одним из примеров является загрузка файла Frozen_graph.pb для переобучения в tensorflow-for-poets-2.

https://github.com/googlecodelabs/tensorflow-for-poets-2


person Yanjun    schedule 03.05.2018    source источник
comment
Я действительно не понимаю, как ответы на другой вопрос не работают. Когда у вас есть график, вам просто нужно получить обучаемые переменные этого конкретного графика. Что вы пробовали после вызова этой функции? Можете ли вы предоставить образец файла .pbtxt, который воспроизводит проблему?   -  person E_net4 the curator    schedule 03.05.2018


Ответы (1)


Насколько я понимаю, у GraphDef недостаточно информации для описания Variables. Как объясняется здесь, вам понадобится MetaGraph, который содержит как GraphDef, так и CollectionDef, что является карта, которая может описать Variables. Таким образом, следующий код должен дать нам правильное количество обучаемых переменных.

Экспорт метаграфа:

import tensorflow as tf

a = tf.get_variable('a', shape=[1])
b = tf.get_variable('b', shape=[1], trainable=False)
init = tf.global_variables_initializer()
saver = tf.train.Saver([a])

with tf.Session() as sess:
    sess.run(init)
    saver.save(sess, r'.\test')

Импортируйте MetaGraph и подсчитайте общее количество обучаемых параметров.

import tensorflow as tf

saver = tf.train.import_meta_graph('test.meta')

with tf.Session() as sess:
    saver.restore(sess, 'test')

total_parameters = 0
for variable in tf.trainable_variables():
    total_parameters += 1
print(total_parameters)
person Y. Luo    schedule 03.05.2018
comment
Означает ли это, что файл .pb не содержит обучаемых_переменных? Большое спасибо. - person Yanjun; 04.05.2018
comment
@Yanjun Я думаю, что все узлы есть. Но вы не можете сказать, какая из них обучаемая_переменная. Или их нет в tf.trainable_variables() после загрузки. - person Y. Luo; 04.05.2018