Pytorch: Реальная пошаговая реализация CNN на MNIST

Вот краткое руководство о том, как и о преимуществах реализации CNN в PyTorch. Мы перебираем строку за строкой, чтобы вы могли избежать всех ошибок при реализации!

В этой статье мы возьмем на себя задачу реализации сверточной нейронной сети в Pytorch! Я действительно хотел написать на такую ​​тему из-за огромного количества необъяснимых и полных ошибок реализации, которые роятся по всему Интернету и мешают большинству людей быстро начать свои собственные реализации.

Однако обратите внимание, что во время написания я предполагаю, что читатель имеет некоторые базовые знания в области нейронных сетей и CNN, если нет, то просмотрите ссылки в нижней части статьи для лучшего понимания, прежде чем начать.

Статья разделена на эти 5 частей:

  • 1] Как подготовить данные к правильному формату? : MNIST
  • 2] Представление архитектуры CNN: наследование классов
  • 3] Как тренироваться на данной модели?
  • 4] Как построить график сохраненных потерь поездов и потерь при проверке?
  • 5] Как использовать вашу модель, чтобы сделать вывод?

Итак, поехали!

1] Как подготовить ваши данные к правильному формату? : MNIST

Прежде всего, важно знать, что PyTorch имеет свою собственную структуру данных, которая является тензорами. Это очень похоже на массивы NumPy, но не совсем. И, как упоминалось в заголовке, мы будем использовать набор данных MNIST Digit Recognizer, который вы можете найти на Kaggle.

Набор цифр представлен в формате .csv с 784 столбцами (не включая столбец индекса), поэтому сначала нам нужно преобразовать его в изображения, как здесь.

После этого вы сможете построить точки данных и просмотреть данные через изображения. (рисунок 1)

К сожалению, текущий формат данных несовместим с моделью. И что нужно сделать, так это преобразовать данные в тензоры (формат torch). Также обратите внимание, что здесь мы уже сделали предположение о форме входов для нашей NN, мы используем (1,28,28), что означает, что существует только один канал и, следовательно, изображения находятся в оттенках серого, как вы можете видеть на рис.1, вы можете изменить его на (3,28,28), но тогда вам придется изменить исходную форму ввода. Опять же, модель требует, чтобы каждая точка данных имела форму (номер канала, ширина, высота).

Предупреждение: у меня была эта проблема раньше, элементы в обучающем наборе нужно преобразовать в float, а не в long, иначе позже появится ошибка.

И, наконец, чтобы облегчить цикл по наборам данных во время обучения, мы фиксируем batch_size = 100 и подготавливаем 100 точек данных для каждой эпохи. Подготовка осуществляется следующим образом.

2] Представление архитектуры CNN: наследование классов

Мы используем здесь довольно классическую архитектуру, которая изображена здесь (рис.2):

  • Два слоя 2dConvolutoin с размером фильтра свертки (3x3) и выходной фильтрованный массив 16/32
  • Два слоя 2dMaxPool с размером фильтра (2x2)
  • Функция активации Relu во всем
  • Уровень FC из 800 узлов

Скорее всего, архитектура не оптимизирована и не подходит для задачи классификации цифр, но это не является целью данной статьи. Несмотря на это, вот реализация.

Сделаем несколько замечаний по поводу реализации:

Слои свертки (Conv2d) получают в качестве аргументов input_channels и output_channels, которые представляют собой количество отфильтрованных (фильтром 3x3) тензоров соответственно из предыдущего слоя и на текущем слое. Ядро - это размер фильтра, который мы используем в текущем фильтре. Шаг - это шаг сдвига, который вы выполняете для матрицы точек данных, когда вы выполняете умножение точки данных и фильтра на вход. А заполнение - это количество столбцов, которые вы добавляете, когда фильтр накладывается на исходное изображение.

Слои MaxPool (MaxPool2d) получают в качестве аргумента размер ядра, который снова является размером фильтра.

Полностью сверточный слой (линейный) получает в качестве аргумента количество узлов из предыдущего слоя и количество узлов, которые он имеет в настоящее время.

Flatten (out.view (out.size (0), - 1)) просто выравнивает изображения. Это означает, что мы переходим, например, от (8400,1,28,28) к (8400,784).

Теперь о гиперпараметрах, определенных вне класса:

Количество эпох (num_epochs) не требует пояснений.

Функция потерь (ошибка), в нашем случае это потеря перекрестной энтропии.

Скорость обучения (скорость обучения), равная 0,001

Оптимизатор (оптимизатор), в нашем случае это стохастический градиентный спуск.

3] Как тренироваться на данной модели?

Теперь, когда у нас есть модель, нам нужно найти правильные веса для всех параметров этой модели. Вот код, который мы объясним.

За одну итерацию выполняется все следующее:

  • Очищаем предыдущий градиент (zero_grad)
  • Выполняем прямую связь и вычисляем потерю (модель (поезд) и ошибка (выходы, метки))
  • Из потерь вычисляем новые градиенты (.backwards ())
  • И увеличиваем веса (optimizer.step ())

Аналогичные задачи выполняются для тестового набора, чтобы получить потерю валидации во время обучения. Мы также распечатываем подробные сведения каждые 500 эпох.

4] Как построить график сохраненных потерь поездов и потерь при проверке?

С сохраненными значениями, которые мы получили ранее во время обучения. Затем нам просто нужно построить график, используя следующий код.

Из чего мы получаем следующие графики для нашего первого прогона.

5] Как использовать вашу модель, чтобы сделать вывод?

Допустим, вы хотите предсказать одно изображение, используя только что обученную модель. Для этого необходимо вернуть вашему изображению правильную форму ввода для вашей сети (пока это нормально), но затем не забудьте преобразовать его в тензор перед его использованием.

#We load the image
img = X_cv[0] #shape (784,1)
img = img.reshape(1, 1, 28, 28) #shape (1,1,28,28)
img  = torch.from_numpy(img).float() #tensor
#We do the prediction here and we do + 1 because we start from 0
prediction = model(img).detach().numpy()[0].argmax() + 1

Как вы можете видеть в этом примере, способ получения прогноза состоит в том, чтобы сначала выполнить прямую связь с использованием модели с изображением в качестве аргумента. Прямая связь дает нам распределение по 10 меткам (10 цифр), и, таким образом, результат, который мы должны выбрать, является максимальной вероятностью. Прогнозируемая метка тогда является i-м элементом классов классификации с индексом i максимальной вероятности выхода.



Спасибо за чтение и, пожалуйста, рассмотрите возможность подписки на мою среду и мой Github! Вот еще несколько статей, которые могут вас заинтересовать!