Цель
Целью сети был перевод между двумя наборами изображений в разных «доменах». Например, перевод:
- с зимы на лето
- живое действие на анимацию
- зебра на лошадь
Эти изображения не контролируются, что означает, что им не нужно иметь совпадающие пары в каждом домене.
UNET-UNIT
Плюсы:
Другие сети, которые делают это, такие как CycleGAN и UNIT, занимают более 200 эпох для получения результатов высокого качества. Я смог получить хорошие результаты, используя несколько описанных ниже изменений всего за 2 эпохи!
Минусы:
Другие сети могут выполнять более подробные преобразования изображений в изображения, такие как улыбка на хмурый взгляд, от животного к животному. Я не пробовал тренироваться более 60 эпох, но пока эта сеть не может воспроизвести эти высококачественные переводы.
Результаты
- Зима к лету
- Живое действие на анимацию
3. Объяснение
Основным источником вдохновения для этой сети послужила архитектура NVIDIA UNIT, выпущенная в 2017 году (см. Здесь: https://github.com/mingyuliutw/UNIT/), а также UNET от FastAI. Основные изменения, внесенные в АГРЕГАТ:
- Пропускать соединения между уровнями кодировщика и декодера
- Предтренировочный дискриминатор
- PixelShuffle повышающая дискретизация слоев
- Предварительно обученный вырезать resnet34 корпус для энкодеров
- Иногда бывает полезен генератор до и после обучения, чтобы просто восстановить изображения для пары эпох.
Архитектура UNIT-UNET состоит из трех основных частей.
Кодировщики изображений (E1, E2)
- Кодеры изображений состоят из предварительно обученного тела resnet34 с удаленными последними тремя сверточными слоями. После кодировщиков добавлен дополнительный сверточный слой, который используется обеими сетями.
- Вход - это изображение (3224224), а выход - 512-канальный тензор, предназначенный для изображения, закодированного в общее пространство для обоих доменов (по сути, это место, где оба входа домена могут отображаться, и где закодированное изображение может быть отображено на выход любого домена. )
Декодеры изображений (G1, G2)
- Декодеры изображений состоят из слоев с повышенной дискретизацией PixelShuffle вместе с сохраненными слоями кодеров для воспроизведения изображений в целевом домене.
- Входное изображение - 512-канальное закодированное изображение, а выходное - (3 224 224) изображения.
Дискриминатор (D1, D2)
- Дискриминатор - это обычная сверточная сеть с двоичной классификацией, которая различает изображения в обоих изображениях.
Поток
- Во-первых, одно входное изображение из каждого домена пропускается через соответствующие кодеры и сопоставляется с общим скрытым пространством.
- Оба закодированных изображения проходят через оба декодера для получения четырех изображений, которые, как предполагается, будут двумя реконструированными изображениями и двумя переведенными изображениями.
- Затем переведенные изображения снова пропускаются через кодировщики (aToB проходит через B, а bToA проходит через A). Эти два изображения затем переводятся, чтобы вернуться к исходным входам. Теперь это циклические изображения.
- Убыток определяется по:
- Вывод дискриминатора из переведенных изображений
- Сходство входов и реконструированных изображений
- Сходство между входами и зацикленными изображениями.
Выполнение
(Весь код можно найти на Github здесь, написанный на Pytorch с использованием библиотеки FastAI)
DataBunch
Чтобы это работало, мне нужна настраиваемая панель данных, которая будет выводить по одному изображению из каждой категории (X1 и X2). Под капотом это означает, что мне нужен мой один тензор из каждого домена, объединенный вместе каждый раз, когда моя сеть запрашивает элемент.
У FastAI есть отличная документация, объясняющая, как именно это сделать, поэтому я не буду повторять то, что они сказали. Если вас интересует полное объяснение, см .: https://docs.fast.ai/tutorial.itemlist.html
Генератор
Сердце сети - класс генератора. Это класс, который:
- Вводит два изображения (по одному из каждого домена)
- Запускает их через кодировщик в (надеюсь) общее скрытое пространство
- Берет каждое закодированное изображение и проходит через оба декодера A и B для создания двух реконструкций (aToA, bToB) и переводов (aToB, bToA)
Кое-что отметить
- Конец кодировщика, несколько средних сверток и начало декодера являются общими для обоих доменов (оба домена проходят через одни и те же уровни). Это должно способствовать созданию общих представлений обоих изображений, когда они находятся в закодированном состоянии.
UNET
- Этот блок кода сохраняет 3 уровня декодера resnet34 для каждого декодера. Слои сохраняются в self.sfs (A / B) и являются тензорами, в которых размер карты функций изменяется в декодере.
- Затем слои объединяются с промежуточным уровнем в функциях декодера. Декодированные изображения используют то, что было их исходным изображением, в качестве пропуска соединений, т.е. aToA и aToB будут использовать слои из bodyA
Дискриминатор
Дискриминатор - это простой двоичный классификатор, который принимает (3 224 224) изображения и выдает на выходе тензор 1 ранга. Дискриминатор обучен различать два домена.
GAN Wrapper
Поскольку нам нужно иметь возможность управлять обучением генератора отдельно, нам нужна оболочка GAN для управления потоком в сети. Я смог использовать модуль FastAI GAN с некоторыми настройками:
- GAN Loss был переписан с учетом потока в нашей сети (код является прямым переводом диаграммы):
- Я сделал параметр, определяющий количество пакетов генератора и дискриминатора, которые необходимо выполнить при запуске GAN. На протяжении большей части моих тренировок лучше всего работали 3 диска и 1 поколение.
- Я подключил последний вывод генератора после очень эпохи и использовал обратный вызов конца эпохи для отображения четырех сгенерированных изображений:
Функции потерь
Для нашей сети на основе GAN нам потребовалась функция потерь для обучения дискриминатора и функция потерь для обучения генератора:
Дискриминатор:
Функция потерь дискриминатора принимает выходные данные соответствующего критика для каждого из входных изображений домена, а также выходные данные критика для каждого из переведенных изображений. Затем используется среднеквадратическая ошибка для сравнения векторов с включенным дискриминатором, чтобы узнать разницу между двумя доменами.
Генератор:
Генератор получает предсказание дискриминатора переведенных изображений, входных изображений, восстановленных изображений (aToA и bToB) и циклических изображений.
Генератор использует ту же среднеквадратичную ошибку, чтобы оценить предсказание критика переведенных изображений с целью убедить критика в их истинности (выходной вектор 0). Он также использует потерю L1 для сравнения схожести входов с восстановленными и зацикленными выходами. Это сделано потому, что потеря MSE (или L2) недостаточно точна для сравнения значений пикселей.
Обучение
Как уже упоминалось, большое улучшение, сделанное с этой сетью, связано с ее обучением. Хорошие результаты появляются в эпоху 2 и быстро увеличиваются примерно до эпохи 10. На этом этапе я обнаружил, что тренировка замедляется, но продолжает улучшаться после 30 лет.
- Следует отметить, что иногда я обнаруживал, что иногда это обучение приводит к тому, что изображения оказываются в определенном цвете / оттенке или даже просто становятся полностью черными. Если я перезапускаю тренировку, она обычно улучшается. Не уверен, что это функция большинства сетей, но требуется дальнейшее исследование.
Один эксперимент, который я пробовал и который имел умеренный успех, состоял в том, что после обучения сети на моем UNET-UNIT я провел несколько периодов обучения, где единственной функцией потерь было минимизировать разницу между входными данными и восстановленными изображениями. Этот «postGANTrainer» в некоторых случаях успешно повышал резкость изображений, но не приносил никакой пользы в других.
Чтобы проверить свои результаты после тренировки, я сделал эти небольшие блоки кода:
Следующие шаги
Возможные следующие шаги в продолжении этого проекта:
- Реализация NVIDIA FUNIT для перевода изображений с очень ограниченными данными с использованием тех же адаптаций здесь
- Добавлены в AdaIN связи между кодировщиком и декодером
- Добавление еще одного критика, чтобы различать две области в потере генератора, чтобы улучшить результаты перевода
Я относительно новичок в глубоком обучении, и я впервые реализую статью в FastAI. Если вы заметили какие-либо ошибки или способы, которыми я могу улучшить свой код и / или эту статью, сообщите мне! Спасибо за чтение :)