В этой статье я приведу практический пример (с кодом) того, как можно использовать популярный фреймворк PyTorch для применения Vision Transformer, который был предложен в статье « Изображение стоит 16x16 слов: преобразователи для распознавания изображений в масштабе » (который я рассмотрел в другом посте), к практической задаче компьютерного зрения.

Для этого рассмотрим проблему распознавания рукописных цифр с помощью известного MNIST dataset.

Я хотел бы сразу сделать оговорку, просто чтобы было понятно. Я выбрал набор данных MNIST для этой демонстрации, потому что он достаточно прост, чтобы модель можно было обучить на нем с нуля и использовать для прогнозирования без какого-либо специализированного оборудования в течение нескольких минут, а не часов или дней, так что буквально любой, у кого есть компьютер, может это сделать и увидеть как это работает. Я не особо старался оптимизировать гиперпараметры модели, и у меня определенно не было цели достичь современной точности (в настоящее время около 99,8% для этого набора данных) с помощью этого подхода.

Фактически, хотя я покажу, что Vision Transformer может достигать респектабельной точности 98% + по MNIST, можно утверждать, что это не лучший инструмент для этой работы. Поскольку каждое изображение в этом наборе данных небольшое (всего 28x28 пикселей) и состоит из одного объекта, применение глобального внимания может иметь лишь ограниченную пользу. Я мог бы написать еще один пост позже, чтобы изучить, как эту модель можно использовать на более крупном наборе данных с большими изображениями и большим разнообразием классов. А пока я просто хочу показать, как это работает.

Что касается реализации, я буду полагаться на код из репозитория с открытым исходным кодом this Фила Ванга, в частности на следующий класс Vision Transformer (ViT) из vit_pytorch.py файл:

Как и любой класс модуля нейронной сети PyTorch, он имеет функцию инициализации (__init__), в которой определены все обучаемые параметры и уровни, и функцию вперед, которая устанавливает путь эти уровни объединены в общую архитектуру сети.

Для краткости здесь дается определение только самого класса ViT, без зависимых классов. Если вы хотите использовать этот код на своем компьютере, вам нужно будет импортировать весь файл vit_pytorch.py ​​ (который на удивление мал, всего около сотни строк code; я даю ссылку на мою собственную форкованную версию на GitHub на тот случай, если исходный файл изменится в будущем), а также на последнюю версию PyTorch (я использовал 1.6.0) и библиотека einops, используемая для манипуляций с тензором.

Чтобы начать использовать набор данных MNIST, нам нужно сначала загрузить его, что мы можем сделать следующим образом (с этого момента в посте весь код мой, хотя многие из них вполне стандартные):

Преобразование transform_mnist в приведенном выше коде используется для нормализации данных изображения до нулевого среднего и стандартного отклонения, равного 1, что, как известно, облегчает обучение нейронной сети. Объекты train_loader и test_loader содержат изображения MNIST, уже случайным образом разделенные на пакеты, чтобы их можно было удобно использовать в процедурах обучения и проверки.

Каждый элемент в наборе данных содержит изображение с соответствующей меткой наземной действительности. Цель нашего Transformer после обучения на обучающей части набора данных (60 000 рукописных изображений цифр) будет заключаться в том, чтобы на основе изображения предсказать правильную метку для каждого образца в тестовой части (10 000 изображений).

Мы будем использовать следующую функцию для обучения нашей модели для каждой эпохи:

Функция перебирает каждый пакет в объекте data_loader. Для каждого пакета он вычисляет выходные данные модели (как log_softmax) и отрицательную потерю логарифмической вероятности для этих выходных данных, а затем вычисляет градиенты эта потеря касается каждого параметра обучаемой модели с помощью loss.backward () и обновляет параметры с помощью optimizer.step (). Каждый сотый пакет предоставляет распечатанное обновление на прогресс обучения и добавляет значение текущего проигрыша в список loss_history.

После каждой эпохи обучения мы сможем увидеть, как наша текущая модель работает на тестовом наборе, используя следующую функцию:

Хотя это похоже на описанную выше процедуру обучения, теперь мы не вычисляем никаких градиентов, а вместо этого просто сравниваем выходные данные модели с наземными метками истинности, чтобы вычислить точность и обновить историю потерь.

Когда все функции определены, пора инициализировать нашу модель и запустить обучение. Мы будем использовать следующий код:

Здесь мы определяем нашу модель Vision Transformer с размером фрагмента 7x7 (что для изображения 28x28 будет означать 4 x 4 = 16 фрагментов на изображение), 10 возможных целевых классов (от 0 до 9) и 1 цветовой канал (поскольку изображения являются оттенки серого).

Что касается сетевых параметров, мы используем размер встраивания в 64 единицы, глубину в 6 трансформаторных блоков, 8 трансформаторных головок и 128 единиц в скрытом слое выходной головки MLP. В качестве оптимизатора мы будем использовать Adam (как в статье) со скоростью обучения 0,003. Обучим нашу модель на 25 эпох и посмотрим на результаты.

Нет особого оправдания для использования значений гиперпараметров, указанных выше. Я просто выбрал что-то разумное. Конечно, возможно, что их оптимизация приведет к более высокой точности и / или более быстрой сходимости.

После запуска кода в течение 25 эпох (в обычном бесплатном ноутбуке Google Colab с графическим процессором Tesla T4) он выдал следующий результат:

Что ж, точность 98,36% - это неплохо. Это лучше, чем то, что можно было бы ожидать от полностью подключенной сети (где я получаю около 97,8–97,9% без каких-либо уловок), поэтому, безусловно, есть выгода от уровней внимания. Конечно, как я уже упоминал выше, Vision Transformer не особенно подходит для этой задачи, и даже простая сверточная сеть с несколькими уровнями может достичь точности не ниже 99%. Возможно, эта модель Transformer может работать немного лучше после оптимизации гиперпараметров.

Но дело в том, что это работает, и, как было описано в статье, в применении к более крупным и сложным задачам, он может конкурировать с лучшими сверточными моделями. Надеюсь, это краткое руководство показало, как вы, читатель, можете использовать его в своей работе.