Как построить глубокую сеть с таким небольшим количеством образцов для каждого класса?

В последние годы глубокое обучение стало довольно популярным для задач распознавания и классификации изображений из-за его высокой производительности. Однако традиционные подходы к глубокому обучению обычно требуют большого набора данных для обучения модели, чтобы различать очень мало разных классов, что кардинально отличается от того, как люди могут учиться даже на очень немногих примерах.

Немногочисленное или одноразовое обучение - это проблема категоризации, которая направлена ​​на классификацию объектов по ограниченному количеству выборок с конечной целью создания алгоритма обучения, более похожего на человека. В этой статье мы рассмотрим подходы глубокого обучения к решению задачи однократного обучения с использованием специальной сетевой структуры: сиамской сети. Мы построим сеть с помощью PyTorch и протестируем ее на рукописном наборе данных символов Omniglot, а также проведем несколько экспериментов для сравнения результатов различных сетевых структур и гиперпараметров, используя метрику оценки однократного обучения.

Набор данных Omniglot

Набор данных рукописных символов Omniglot - это набор данных для однократного обучения, предложенный Lake et al. Он содержит 1623 различных рукописных символа из 50 различных серий алфавитов, каждый из которых был написан от руки 20 разными людьми. Каждое изображение имеет размер 105x105 пикселей. 50 алфавитов разделены в соотношении 30:20 для обучения и тестирования, что означает, что тестовый набор состоит из совершенно нового набора символов, которые ранее не были видны.

Вычислительная среда

Обучение и эксперимент проводились исключительно с помощью Google Colab с использованием ряда графических процессоров, включая Tesla K80 и P100. Мы использовали библиотеки, включая Numpy, Matplotlib и PyTorch.

Метод

Традиционные глубокие сети обычно не работают с однократным или несколькими выстрелами, так как очень мало образцов в классе с большой вероятностью вызовут переоснащение. Чтобы предотвратить проблему переобучения и распространить ее на невидимых персонажей, мы предложили использовать сиамскую сеть.

На рис. 1. представлена ​​базовая архитектура сверточной сиамской сети. В отличие от традиционных CNN, которые принимают на входе 1 изображение для генерации горячего вектора, предлагающего категорию, к которой принадлежит изображение, сиамская сеть принимает 2 изображения и передает их в 2 CNN с той же структурой. Выходные данные будут объединены вместе, в данном случае из-за их абсолютных различий, и переданы на полностью связанные слои для вывода одного числа, представляющего сходство двух изображений. Большее число означает, что два изображения более похожи.

Вместо того, чтобы узнавать, какое изображение принадлежит к какому классу, сиамская сеть учится определять «сходство» между двумя изображениями. После обучения, получив полностью новое изображение, сеть могла затем сравнить изображение с изображением из каждой из категорий и определить, какая категория наиболее похожа на данное изображение.

Предварительная обработка и создание набора данных

Загрузчик данных обучения и проверки

Чтобы обучить сиамскую сеть, мы должны сначала сгенерировать правильные входные данные (попарно) и определить наземную метку истинности для модели.

Сначала мы определяем два изображения, которые принадлежат одному и тому же символу в одном алфавите, чтобы иметь сходство 1 и 0 в противном случае, как показано на рисунке 3. После этого мы случайным образом выбираем пару изображений для ввода в сеть на основе четности index на итерации загрузчика данных. Другими словами, если текущая итерация - нечетное число, мы получаем пару изображений одного и того же символа, и наоборот. Это гарантирует, что наш обучающий набор данных сбалансирован для обоих типов выходных данных. Оба изображения проходят через одно и то же преобразование изображения, поскольку цель состоит в том, чтобы определить сходство двух изображений, поэтому вводить их в разные преобразования изображений не имеет смысла.

Ниже приведен код для создания обучающего набора:

Мы создали 10000 пар этих данных в качестве обучающего набора, который затем случайным образом разделяется на обучение и проверку с соотношением 80:20.

Тестовый загрузчик

Оценка производительности сети при однократном обучении может быть выполнена с помощью n-образных показателей оценки однократного обучения, где мы находим n изображений, представляющих n категорий, и одно основное изображение, которое принадлежит одной из n категорий. Для нашей сиамской сети мы вычислили сходство основных изображений со всеми n изображениями, и пара с наибольшим сходством означает, что основное изображение принадлежит классу.

Загрузчик теста был структурирован таким образом, чтобы поддерживать вышеупомянутую оценку, где берется случайное основное изображение и также извлекаются n изображений, представляющих n категорий, одно из которых является из той же категории основного изображения.

Ниже приведен код для создания набора тестов:

Для нашего финального тестирования мы расширили нашу сеть до 4-этапного однократного обучения с размером тестового набора 1000 и 20-этапного с размером 200.

Эксперимент

Эксперимент 1. Традиционная сиамская сеть для однократного обучения

Основная часть сиамской сети - это двойная сверточная архитектура, которая была показана ранее. Первая сверточная архитектура, которую мы попытаемся построить, принадлежит Koch et al. в своей статье «Сиамские нейронные сети для одноразового распознавания изображений», как показано на рисунке 2. Следует отметить, что после сглаживания абсолютные различия между двумя сверточными ветвями передаются на полностью связанный слой, а не только на один ввод изображения.

Сеть в PyTorch построена следующим образом:

и мы можем выполнять обучение со следующей функцией:

Настройка гиперпараметров

Размер пакета: поскольку мы изучаем, насколько похожи два изображения, размер пакета должен быть довольно большим, чтобы модель была универсальной, особенно для такого набора данных, как этот, с множеством разных категорий. Поэтому мы использовали размер партии 128.

Скорость обучения: мы протестировали несколько скоростей обучения от 0,001 до 0,0005 и выбрали 0,0006, который обеспечил наилучшую скорость уменьшения потерь.

Оптимизатор и потери: мы использовали традиционный оптимизатор Адама для этой сети с потерями двоичной кросс-энтропии (BCE) с логитами.

Результаты

Сеть обучена 30 эпох. На рисунке 3 показан график потерь при обучении и проверке после каждой эпохи, который, как мы можем видеть, показывает резкое снижение и сближение к концу. Потеря валидации обычно уменьшается вместе с потерей обучения, что указывает на то, что за время обучения не произошло переобучения. Во время обучения будет сохранена модель с наименьшими потерями при проверке. Мы использовали потерю проверки вместо потери обучения, поскольку это показатель того, что модель хорошо работает не только на обучающем наборе, что, вероятно, является случаем переобучения.

Эксперимент 2. Добавление пакетной нормализации

Чтобы еще больше улучшить сеть, мы можем добавить пакетную нормализацию, которая предположительно сделает процесс конвергенции более быстрым и стабильным. На рисунке 4 представлена ​​обновленная архитектура с BatchNorm2d после каждого сверточного слоя.

Результаты

Как и ожидалось, потери уменьшились намного быстрее как для потерь при обучении, так и для потерь при проверке по сравнению с исходной сетью. Получив лучший результат, мы решили также обучить модель для большего количества эпох, чтобы увидеть, будет ли она работать лучше, чем эксперимент 1.

Как показано на графике потерь, результаты были немного лучше, чем исходный результат эксперимента 1. Поскольку потери медленно сходятся между эпохами 40 и 50, мы прекратили обучение на 50-й эпохе. На данный момент это лучший результат, которого мы достигли.

Эксперимент 3. Замена ConvNet на облегченный VGG16

После того, как исходная сеть заработала достаточно хорошо, мы также можем протестировать различные хорошо зарекомендовавшие себя CNN для нашей сиамской сети и посмотреть, сможем ли мы достичь лучших результатов. С маленьким размером изображения 105x105 мы хотели использовать сеть, которая сравнительно меньше, с небольшим количеством слоев, но все же давала достойные результаты, и поэтому мы позаимствовали сетевую архитектуру VGG16.

Исходный VGG16 все еще был слишком большим для нашего размера, где последние 5 сверточных слоев имели дело только с отдельными пикселями, и поэтому мы удалили их, получив в итоге следующую сеть:

Результаты

Как показано на графике потерь, потери при обучении уменьшаются значительно медленнее, чем в предыдущих экспериментах. Это может быть связано с тем, что размер ядра сверточных слоев довольно мал (3x3), что дает небольшое принимающее поле. Для задачи вычисления сходства между двумя изображениями, возможно, может быть полезно взглянуть на «большую картину» двух изображений вместо того, чтобы сосредотачиваться на мелких деталях, и, следовательно, большее воспринимающее поле, предложенное в исходной сети, сработало лучше.

Оценка модели

Код для оценки сети реализован следующим образом:

Одноразовое обучение в четырех направлениях

Сначала мы протестировали четырехстороннее однократное обучение с использованием совершенно нового набора изображений для оценки, где все тестовые изображения не использовались во время обучения, и модели также не были известны персонажи. Результаты показали точность примерно 90%, что говорит о том, что модель довольно хорошо обобщалась на невидимые наборы данных и категории, достигая нашей цели - однократного обучения на наборе данных Omniglot.

Одноразовое обучение из 20 способов

После этого мы провели оценку обучения за один прием с 20 участниками для 200 подходов. Где результат вернулся, чтобы по-прежнему составлять около 86%. Мы сравнили результаты с исходными данными, предоставленными Lake et al:

Хотя мы не превзошли или не повторили предложенную точность статьи, которая составляет 92% (возможно, из-за таких деталей, как разная скорость обучения слоев), на самом деле мы были довольно близки к ней.

Кроме того, наша модель на самом деле работает намного лучше, чем многие другие модели, включая обычную сиамскую сеть и ближайшего соседа.

Заключение

Вот и все! Вот как построить сверточную сиамскую сеть для однократного изучения набора данных Omniglot. Полный код также размещен на Github в следующем каталоге:



Спасибо, что зашли так далеко 🙏! Я буду публиковать больше сообщений о различных областях компьютерного зрения / глубокого обучения, обязательно ознакомьтесь с другой моей статьей о 3D-реконструкции!