VGG, потеря восприятия в керасе

Мне интересно, можно ли добавить пользовательскую модель к функции потерь в keras. Например:

def model_loss(y_true, y_pred):
    inp = Input(shape=(128, 128, 1))
    x = Dense(2)(inp)
    x = Flatten()(x)

    model = Model(inputs=[inp], outputs=[x])
    a = model(y_pred)
    b = model(y_true)

    # calculate MSE
    mse = K.mean(K.square(a - b))
    return mse

Это упрощенный пример. На самом деле я буду использовать сеть VGG в проигрыше, поэтому просто пытаюсь понять механику keras.


person William Falcon    schedule 11.05.2017    source источник
comment
Ты это пробовал? Звучит просто, попробуйте. Но я предлагаю создать модель вне функции потерь. Ваша функция потерь должна начинаться со строки a=model(y_pred).   -  person Daniel Möller    schedule 11.05.2017
comment
Но: вы ожидаете, что эта маленькая модель будет обучаться вместе с моделью, содержащей функцию потерь?? Тогда я бы сказал, что никак.   -  person Daniel Möller    schedule 11.05.2017
comment
нет, это замороженная модель. это для потери VGG   -  person William Falcon    schedule 11.05.2017


Ответы (1)


Обычный способ сделать это — добавить ваш VGG в конец вашей модели, убедившись, что все его слои имеют trainable=False перед компиляцией.

Затем вы пересчитываете свой Y_train.

Предположим, у вас есть эти модели:

mainModel - the one you want to apply a loss function    
lossModel - the one that is part of the loss function you want   

Создайте новую модель, присоединив одну к другой:

from keras.models import Model

lossOut = lossModel(mainModel.output) #you pass the output of one model to the other

fullModel = Model(mainModel.input,lossOut) #you create a model for training following a certain path in the graph. 

Эта модель будет иметь те же веса, что и mainModel и lossModel, и обучение этой модели повлияет на другие модели.

Перед компиляцией убедитесь, что lossModel не поддается обучению:

lossModel.trainable = False
for l in lossModel.layers:
    l.trainable = False

fullModel.compile(loss='mse',optimizer=....)

Теперь настройте свои данные для обучения:

fullYTrain = lossModel.predict(originalYTrain)

И, наконец, пройти обучение:

fullModel.fit(xTrain, fullYTrain, ....)
person Daniel Möller    schedule 11.05.2017
comment
Благодарю. да, я видел, как это делается в других проектах. есть ли причина, по которой это нельзя добавить к символическому графику в loss_fn? и работать с y_pred, y_true напрямую? - person William Falcon; 12.05.2017