Простое руководство по интерпретации того, что изучает сверточная нейронная сеть с помощью Pytorch

Это третье сообщение в серии руководств по созданию моделей глубокого обучения с помощью Pytorch. Ниже представлена ​​полная серия:

  1. Учебник по Pytorch для начинающих
  2. Понять тензорные размерности в моделях DL
  3. CNN и визуализация функций (этот пост)
  4. Гиперпараметрическая настройка с Optuna
  5. Перекрестная проверка K-сгиба
  6. Сверточный автоэнкодер
  7. Автоэнкодер с шумоподавлением
  8. Вариационный автоэнкодер

Цель этой серии - сделать Pytorch более интуитивно понятным и доступным с помощью примеров реализации. В Интернете есть множество руководств по использованию Pytorch для построения многих типов сложных моделей, но в то же время это может сбивать с толку, потому что при переходе от учебника к другому всегда есть небольшие различия. В этой серии я хочу начать с самых простых тем к более сложным.

Введение

Сверточная нейронная сеть - это особый тип искусственной нейронной сети, широко применяемый для распознавания изображений. Успех этой архитектуры начался в 2015 году, когда благодаря этому подходу была решена задача классификации изображений ImageNet.

Как вы, наверное, знаете, эти методы очень эффективны и хороши для прогнозирования, но в то же время их трудно интерпретировать. По этой причине их еще называют моделями черного ящика.

Безусловно, существуют методы, не зависящие от модели, такие как LIME и графики частичной зависимости, которые можно применять в любой модели. Но в этом случае имеет смысл применить интерпретируемые методы, разработанные специально для нейронных сетей. В отличие от моделей машинного обучения, сверточные нейронные сети изучают абстрактные функции из необработанных пикселей изображения [1].

В этом посте я сосредоточусь на том, как сверточные нейронные сети изучают функции. Это возможно благодаря пошаговой визуализации изученных функций. Прежде чем показывать реализации с Pythorch, я объясню, как работает CNN, а затем визуализирую карты функций и поля Receptive, изученные CNN, обученной задаче классификации.

Table of Content:
1. What is CNN
2. Define and train CNN on MNIST
3. Evaluate model on test set
4. Visualize Filters
5. Visualize Feature Maps

1. Что такое CNN?

CNN состоят из строительных блоков: сверточных слоев, слоев объединения и полносвязных слоев. Основная функция сверточного слоя - извлечение объектов или так называемых карт объектов. Как он может это сделать? Он использует несколько фильтров из набора данных [2].

После этого размерность карт признаков в результате операции свертки уменьшается на слой объединения. Наиболее часто используемая операция объединения - это Maxpooling, которая выбирает наиболее значимое значение пикселя в каждом фрагменте фильтра карты функций. Итак, эти два типа слоев полезны для извлечения признаков.

В отличие от сверточных и объединяющих слоев, полностью связанный слой отображает извлеченные объекты в окончательный результат, например, классифицируя изображение MNIST как одну из 10 цифр.

В цифровых изображениях двумерная сетка хранит значения пикселей. Его можно рассматривать как массив чисел. Ядро, представляющее собой небольшую сетку, обычно размером 3x3, применяется в каждой позиции изображения. По мере того, как вы погружаетесь в более глубокие слои, функции становятся все более и более сложными.

Производительность модели достигается с помощью функции потерь, которая представляет собой разницу между выходными данными и целевыми метками, посредством прямого распространения в обучающем наборе, а параметры, такие как веса и смещение, обновляются посредством обратного распространения с помощью алгоритма градиентного спуска.

2. Определите и обучите CNN на наборе данных MNIST.

Давайте сначала импортируем библиотеки и набор данных. torch - это модуль, который предоставляет структуры данных для тензоров с одним или несколькими измерениями.

Наиболее важные библиотеки:

  • torchvision состоит из популярных наборов данных, известных архитектур моделей и стандартных преобразований изображений. В нашем случае он предоставляет нам набор данных MNIST.
  • torch.nn содержит классы и функции, которые помогут вам построить сверточную нейронную сеть.
  • torch.optim предоставляет все оптимизаторы, такие как Adam.
  • torch.nn.functional используется для импорта функций, таких как выпадение, свертка, объединение, нелинейные функции активации и функции потерь.

Мы загружаем обучающие и тестовые наборы данных и преобразуем наборы данных изображений в Tensor. Нам не нужно нормализовать изображения, потому что наборы данных уже содержат изображения в оттенках серого. После того, как мы разделим обучающий набор данных на обучающий и проверочный наборы, random_split предоставит случайное разделение для этих двух наборов. DataLoader используется для создания загрузчиков данных для наборов для обучения, проверки и тестирования, которые разделены на мини-пакеты. batchsize - это количество образцов, используемых в одной итерации во время обучения модели.

Мы определяем архитектуру CNN.

Мы можем легко распечатать CNN, чтобы иметь быстрый обзор:

Вы можете видеть, что есть два сверточных слоя и два полностью связанных слоя. За каждым сверточным слоем следует функция активации ReLU и слой maxpooling. Функция view преобразует данные в одномерный массив, который будет передан на линейный уровень. Второй полностью связанный слой, также называемый выходным слоем, классифицирует изображение как одну из 10 цифр.

Мы определяем строительные блоки, которые будут использоваться для обучения CNN:

  • torch.device для обучения модели с помощью аппаратного ускорителя, такого как графический процессор
  • Сеть CNN, которая будет перенесена на устройство
  • Потеря кросс-энтропии и оптимизатор Адама

Теперь мы можем обучить сеть на обучающем наборе и оценить его на проверочном наборе:

Обучающий код можно разбить на две части.

Прямое распространение:

  1. Мы передаем входные изображения в сеть с model(images)
  2. Потери вычисляются путем вызова criterion(outputs,labels), где выходные данные составляют прогнозируемый класс, а метки составляют целевой класс.

Обратное распространение:

3. Градиент очищен, чтобы другие значения не накапливались с optimizer.zero_grad().

4. loss.backward() используется для выполнения обратного распространения и вычисляет градиент на основе потерь.

5. optimizer.step() всегда после вычисления градиента. Он перебирает все параметры и обновляет их значения.

Функция потерь и точность рассчитываются как для обучающего, так и для проверочного набора.

3. Оцените модель на тестовом наборе.

После обучения модели мы можем оценить производительность на тестовом наборе:

Разобьем тестовый код на маленькие кусочки:

  • torch.no_grad() используется для отключения отслеживания градиента, нам больше не нужно вычислять градиенты, поскольку модель уже обучена
  • передать входные изображения в сеть
  • рассчитать тестовый проигрыш, добавив loss.item()*images.size(0)
  • рассчитать точность теста, добавив (predicted==labels).sum().item()

4. Визуализируйте фильтры

Мы можем визуализировать изученные фильтры, используемые CNN для свертки карт функций, содержащих извлеченные функции из предыдущего слоя. Эти фильтры можно получить, перебирая все слои моделей, list(model.children()). Если слой сверточный, мы можем сохранить вес в списке model_weights, который будет содержать фильтры, используемые в двух сверточных слоях.

Ниже я показываю форму найденных фильтров.

Теперь мы можем, наконец, визуализировать изученные фильтры первого сверточного слоя:

Теперь настала очередь визуализировать фильтры второго сверточного слоя.

5. Визуализируйте карты функций

Карта функций, также называемая картой активации, получается с помощью операции свертки, применяемой к входным данным с помощью фильтра / ядра. Ниже мы определяем функцию для извлечения функций, полученных после применения функции активации.

Из набора обучающих данных мы берем изображение, представляющее цифру 9. Итак, мы визуализируем карты характеристик, полученные для этого изображения, в первом сверточном слое.

Теперь мы визуализируем карты характеристик, полученные для того же изображения во втором сверточном слое.

Заключительная мысль:

Поздравляю! Вы научились визуализировать изученные функции CNN с помощью Pytorch. Сеть изучает новые и все более сложные функции в своих сверточных слоях. От первого сверточного слоя до второго сверточного слоя вы можете увидеть различия в этих функциях. Чем дальше вы продвигаетесь по сверточным слоям, тем более абстрактными будут объекты. Код Github находится здесь. Спасибо за прочтение. Хорошего дня.

Использованная литература:

[1] https://christophm.github.io/interpretable-ml-book/cnn-features.html#feature-visualization

[2] https://insightsimaging.springeropen.com/articles/10.1007/s13244-018-0639-9

Вам понравилась статья? Станьте участником и получайте неограниченный доступ к новым сообщениям о данных каждый день!