как загрузить веса предварительно обученных моделей LSTM в Tensorflow

Я хочу реализовать модель LSTM с предварительно обученными весами в Tensorflow. Эти веса могут поступать от Caffee или Torch.
Я обнаружил, что в файле rnn_cell.py есть ячейки LSTM, такие как rnn_cell.BasicLSTMCell и rnn_cell.MultiRNNCell. Но как я могу загрузить предварительно обученные веса для этих ячеек LSTM.


person Zhiqiang Wan    schedule 25.06.2016    source источник


Ответы (1)


Вот решение для загрузки предварительно обученной модели Caffe. См. полный код здесь, на который есть ссылка в обсуждении в в этой беседе.

net_caffe = caffe.Net(prototxt, caffemodel, caffe.TEST)
caffe_layers = {}

for i, layer in enumerate(net_caffe.layers):
    layer_name = net_caffe._layer_names[i]
    caffe_layers[layer_name] = layer

def caffe_weights(layer_name):
    layer = caffe_layers[layer_name]
    return layer.blobs[0].data

def caffe_bias(layer_name):
    layer = caffe_layers[layer_name]
    return layer.blobs[1].data

#tensorflow uses [filter_height, filter_width, in_channels, out_channels] 2-3-1-0 
#caffe uses [out_channels, in_channels, filter_height, filter_width] 0-1-2-3
def caffe2tf_filter(name):
    f = caffe_weights(name)
    return f.transpose((2, 3, 1, 0))

class ModelFromCaffe():
    def get_conv_filter(self, name):
        w = caffe2tf_filter(name)
        return tf.constant(w, dtype=tf.float32, name="filter")

    def get_bias(self, name):
        b = caffe_bias(name)
        return tf.constant(b, dtype=tf.float32, name="bias")

    def get_fc_weight(self, name):
        cw = caffe_weights(name)
        if name == "fc6":
            assert cw.shape == (4096, 25088)
            cw = cw.reshape((4096, 512, 7, 7)) 
            cw = cw.transpose((2, 3, 1, 0))
            cw = cw.reshape(25088, 4096)
        else:
            cw = cw.transpose((1, 0))

        return tf.constant(cw, dtype=tf.float32, name="weight")

images = tf.placeholder("float", [None, 224, 224, 3], name="images")
m = ModelFromCaffe()

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  batch = cat.reshape((1, 224, 224, 3))
  out = sess.run([m.prob, m.relu1_1, m.pool5, m.fc6], feed_dict={ images: batch })
...
person ssjadon    schedule 27.06.2016
comment
Большое спасибо за ответ. Это мне очень помогает. Но для RNN я не нашел, как инициализировать предварительно тренированные веса. - person Zhiqiang Wan; 30.06.2016
comment
Вы должны создать переменную, используя класс ModelFromCaffe, например. fc6_W = tf.Variable(m.get_fc_weight("fc6"), name="fc6_W") см. документацию здесь. - person ssjadon; 07.07.2016