Настройка гиперпараметров GAN

введите здесь описание изображения  введите описание изображения здесь

Как показано на двух изображениях выше, во время обучения модели DCGAN градиент не является стабильным и сильно колеблется, по этой причине модель не может нарисовать идеальное изображение, даже чтобы нарисовать изображение, распознаваемое человеческим глазом. Кто-нибудь может сказать мне, как настроить параметр, такой как процент отсева, скорость обучения или что-то еще, чтобы модель работала лучше? Я буду вам очень благодарен! Вот модель, которую я сделал раньше (Build with Keras):

дискриминатор:

скорость обучения 0,0005

процент отсева составляет 0,6

batch_size - 25

dis=Sequential()

dis.add(Conv2D(depth*1, 5, strides=2, input_shape=(56,56,3),padding='same',kernel_initializer='RandomNormal', bias_initializer='zeros'))

dis.add(LeakyReLU(alpha=alp))

dis.add(Dropout(dropout))

dis.add(Conv2D(depth*2, 5, strides=2, padding='same',kernel_initializer='RandomNormal', bias_initializer='zeros'))

dis.add(LeakyReLU(alpha=alp))

dis.add(Dropout(dropout))

dis.add(Conv2D(depth*4, 5, strides=2, padding='same',kernel_initializer='RandomNormal', bias_initializer='zeros'))

dis.add(LeakyReLU(alpha=alp))

dis.add(Dropout(dropout))

dis.add(Conv2D(depth*8,5,strides=1,padding='same',kernel_initializer='RandomUniform', bias_initializer='zeros'))

dis.add(LeakyReLU(alpha=alp))

dis.add(Dropout(dropout))

dis.add(Flatten())

dis.add(Dense(1))

dis.add(Activation('sigmoid'))

dis.summary()

dis.compile(loss='binary_crossentropy',optimizer=RMSprop(lr=d_lr))

генератор и модель GAN:

скорость обучения 0,0001

импульс 0,9

gen=Sequential()

gen.add(Dense(dim*dim*dep,input_dim=100))

gen.add(BatchNormalization(momentum=momentum))

gen.add(Activation('relu'))

gen.add(Reshape((dim,dim,dep)))

gen.add(Dropout(dropout))

gen.add(UpSampling2D())

gen.add(Conv2DTranspose(int(dep/2),5,padding='same',kernel_initializer='RandomNormal', bias_initializer='RandomNormal'))

gen.add(BatchNormalization(momentum=momentum))

gen.add(Activation('relu'))

gen.add(UpSampling2D())

gen.add(Conv2DTranspose(int(dep/4),5,padding='same',kernel_initializer='RandomNormal', bias_initializer='RandomNormal'))

gen.add(BatchNormalization(momentum=momentum))

gen.add(Activation('relu'))

gen.add(UpSampling2D())

gen.add(Conv2DTranspose(int(dep/8),5,padding='same',kernel_initializer='RandomNormal', bias_initializer='RandomNormal'))

gen.add(BatchNormalization(momentum=momentum))

gen.add(Activation('relu'))

gen.add(Conv2DTranspose(3,5,padding='same',kernel_initializer='RandomNormal', bias_initializer='RandomNormal'))

gen.add(Activation('sigmoid'))

gen.summary()


GAN=Sequential()

GAN.add(gen)

GAN.add(dis)

GAN.compile(loss='binary_crossentropy',optimizer=RMSprop(lr=g_lr))

person Ezreal    schedule 24.09.2017    source источник
comment
Дискриминатор вроде исправен, но не генератор. Я думал, вы можете тренировать генератор дольше и снизить скорость обучения.   -  person Lerner Zhang    schedule 15.01.2018


Ответы (1)


Стабильное обучение GAN - это открытая исследовательская проблема. Тем не менее я могу дать вам два совета. Если вы придерживаетесь исходной программы обучения GAN и не имеете абсолютных знаний о том, что делаете, используйте архитектуру DCGAN с доступными гиперпараметрами, как описано в их статье (https://arxiv.org/pdf/1511.06434.pdf%C3%AF%C2%BC%E2%80%B0). Обучение GAN очень нестабильно, и использование других гиперпараметров приведет к коллапсу режима или исчезновению градиентов.

Более простой путь с GAN - использовать Wasserstein GAN. Они довольно стабильны при использовании нестандартной архитектуры. Однако я настоятельно рекомендую использовать гиперпараметр, предложенный в их статье, потому что для меня обучение также рухнуло для разных гиперпараметров. Улучшенный Wasserstein GAN: [https://arxiv.org/pdf/1704.00028.pdf]

person Thomas Pinetz    schedule 24.09.2017
comment
Спасибо, вместо этого я попробую Wasserstein GAN. - person Ezreal; 25.09.2017