Отключает ли tflearn.models.dnn.DNN автоматически выпадающие слои и нормализацию пакетов при прогнозировании?

Я новичок в нейронных сетях, поэтому я решил использовать Tflearn, потому что он интуитивно понятен. Однако ответа на свой вопрос я не нашел. В документации tflearn приводится следующий пример того, как глубокая нейронная сеть может что-то предсказывать:

network = ...
model = DNN(network)
model.load('model.tflearn')
model.predict(X)

Я вставил в сеть несколько слоев пакетной нормализации, потому что моя модель выглядела переобученной. Будет ли model.predict() автоматически «сообщать» слою пакетной нормализации, чтобы он не вел себя как на этапе обучения? Или мне нужно как-то указать это с помощью tflearn.config.is_training (is_training=False, session=None)?

Если да, то знаете ли вы, где я должен поставить эту строку? И как мне создать свою сессию, чтобы она работала так же, как мой код. На данный момент это в основном выглядит как пример на tflearn.org:

net = tflearn.input_data(shape=[None, 784])
net = tflearn.fully_connected(net, 64)
net = tflearn.dropout(net, 0.5)
net = tflearn.fully_connected(net, 10, activation='softmax')
net = tflearn.regression(net, optimizer='adam', 
loss='categorical_crossentropy')

model = tflearn.DNN(net)
model.fit(X, Y)

за исключением того, что я использую слой пакетной нормализации и нейронную сеть для аппроксимации функции. К сожалению, я не могу опубликовать свой код прямо сейчас, так как он находится на другом компьютере, но в основном он такой же.

Может ли кто-нибудь помочь мне с этим вопросом?

Заранее спасибо!


person uhu123    schedule 09.07.2018    source источник


Ответы (1)


Вам нужно установить tflearn.is_training в True или False, когда вы тренируетесь и прогнозируете, а tflearn позаботится обо всем остальном. Как только вы определите свою модель, вы можете обучить ее следующим образом:

with tf.Session() as sess:
    tflearn.is_training(True, session=sess)
 model.fit(X, Y)

а затем предсказать, используя:

with tf.Session as sess:
    tflearn.is_training(False, session=sess)
model.predict(X)
person Paul W    schedule 10.07.2018