Резюме: я пытаюсь переобучить простую CNN для MNIST без использования высокоуровневого API. Мне уже удалось это сделать, переобучив всю сеть, но моя текущая цель - переобучить только последние один или два полностью подключенных уровня.
На данный момент работают: Допустим, у меня есть CNN со следующей структурой
- Сверточный слой
- RELU
- Уровень объединения
- Сверточный слой
- RELU
- Уровень объединения
- Полностью связанный слой
- RELU
- Слой исключения
- Полностью подключенный слой до 10 выходных классов
Моя цель - переобучить либо последний полностью подключенный слой, либо последние два полностью подключенных слоя.
Пример сверточного слоя:
W_conv1 = tf.get_variable("W", [5, 5, 1, 32],
initializer=tf.truncated_normal_initializer(stddev=np.sqrt(2.0 / 784)))
b_conv1 = tf.get_variable("b", initializer=tf.constant(0.1, shape=[32]))
z = tf.nn.conv2d(x_image, W_conv1, strides=[1, 1, 1, 1], padding='SAME')
z += b_conv1
h_conv1 = tf.nn.relu(z + b_conv1)
Пример полностью подключенного слоя:
input_size = 7 * 7 * 64
W_fc1 = tf.get_variable("W", [input_size, 1024], initializer=tf.truncated_normal_initializer(stddev=np.sqrt(2.0/input_size)))
b_fc1 = tf.get_variable("b", initializer=tf.constant(0.1, shape=[1024]))
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
Мое предположение: при выполнении обратного распространения ошибки для нового набора данных я просто убеждаюсь, что мои веса W и b (из W * x + b) зафиксированы в не полностью связанных слоях.
Первая мысль о том, как это сделать: сохраните W и b, выполните шаг обратного распространения ошибки и замените новые W и b старыми в слоях, которые я не хочу менять.
Мои мысли об этом первом подходе:
- Это требует больших вычислительных ресурсов и расходует память. Все преимущество выполнения только последнего слоя состоит в том, что не нужно делать остальные.
- Обратное распространение может работать по-другому, если не применяется ко всем слоям?
Мой вопрос:
- Как правильно переобучить определенные слои в нейронной сети, если не используются высокоуровневые API. Приветствуются как концептуальные, так и программные ответы.
P.S. Полностью осознаю, как это можно сделать с помощью высокоуровневых API. Пример: https://towardsdatascience.com/how-to-train-your-model-dramatically-faster-9ad063f0f718. Просто не хочу, чтобы нейронные сети были волшебством, я хочу знать, что на самом деле происходит