Недавно я участвовал в конкурсе по машинному обучению на Kaggle. Это был трек Google Open Images 2019 Challenge - Object Detection. Я был очень ограничен во времени: у меня было всего четыре недели, а обучение одной модели занимает около двух недель с 4 GPU 1080Ti. Но мне посчастливилось получить несколько кредитов TPU от Google. Поэтому я решил попробовать. Вот история.

Что такое ТПУ?

TPU - это специализированные микросхемы, которые хорошо справляются с матричным умножением. Они идеально подходят для глубокого обучения. На одном TPU модель обычно обучается в 10–100 раз быстрее, чем на одном графическом процессоре.

В настоящее время есть TPU v2 и v3; v3 имеет вдвое больше памяти и работает быстрее (в 2 раза быстрее, согласно спецификациям; я наблюдал увеличение скорости в 1,5 раза). Я использовал конфигурации v2–8 и v2–8 (8 ядер на чип, это минимальное количество).

TPU работает как выделенный компьютер в центре обработки данных Google. У него нет хранилища, поэтому любые данные должны передаваться в облачное хранилище Google и исходить из него. Чтобы использовать TPU, мы должны создать пару виртуальной машины и виртуального устройства TPU в панели управления GCP. Вы запускаете свою программу на своей виртуальной машине, она подключается к устройству TPU и просто отправляет ему команды. Таким образом, входные данные считываются из Google Cloud Storage, журналы передаются обратно в управляющую виртуальную машину, а все результаты сохраняются в Google Cloud Storage (включая данные TensorBoard и сохраненные модели).

API и ограничения

Хорошо, TPU мощные, но все имеет свою цену. К сожалению, реализация PyTorch не была полностью готова, поэтому мне пришлось использовать TensorFlow. TensorFlow 2.0 также не был готов с точки зрения поддержки TPU, поэтому я остановился на старом-добром TF 1.14.

В настоящее время TPU требует полностью статического графа TF. Это подразумевает некоторые ограничения на модели, которые вы можете использовать. Например, я не мог использовать произвольный размер входного изображения для RetinaNet. Жесткий отрицательный майнинг невозможен.

Входные данные передаются через сеть, поэтому формат данных должен быть эффективным. Незакрепленные файлы изображений, вероятно, приведут к тому, что сеть станет узким местом. Лучший способ сделать это - использовать файлы TFRecord. Здесь мы также теряем некоторую свободу: я не знаю, как реализовать сбалансированную выборку с TFRecords. На самом деле, у меня есть идея: может быть, мы сможем добиться этого с помощью специально созданных файлов TFRecord, по одному на класс (я не пробовал).

Модели обнаружения объектов

Я использовал репо от Google: https://github.com/tensorflow/tpu/tree/master/. Он реализует ResNet и EfficientNet для классификации изображений и RetinaNet (https://arxiv.org/abs/1708.02002) для обнаружения объектов. Вот моя вилка этого репо с несколькими исправлениями и настройками: https://github.com/artyompal/tpu_models.

Это репо поставляется с очень хорошими учебниками, которые я настоятельно рекомендую, если вы хотите попробовать обучение на TPU:

Подробнее об этой конкретной реализации RetinaNet

Первоначально единственной поддерживаемой магистралью была ResNet50 / 101/152/200. Я добавил поддержку EfficientNet. Я также добавил поддержку магистрали SE-ResNext, но она не окончательная (работает слишком медленно, возможно, неправильный порядок каналов изображения).

Кроме того, эта реализация поддерживает обычную FPN и ее лучшую версию, называемую NAS-FPN: https://arxiv.org/abs/1904.07392. Последний работает значительно лучше. Как следует из названия, это результат поиска нейронной архитектуры для лучшей архитектуры FPN.

Кроме того, магистраль RetinaNet и ResNet поддерживает DropBlock: https://arxiv.org/abs/1810.12890. Это лучшая альтернатива отсеву в сверточных сетях, и она действительно улучшает оценку как для моделей классификации, так и для моделей обнаружения объектов.

Набор данных Open Images

Это огромный набор данных с 1,8 млн изображений, аннотированных 12 млн ограничивающих рамок. Всего 500 классов. Этикетки шумные и не всегда правильные. Набор данных имеет серьезный дисбаланс классов: AFAIR, наиболее распространенный класс имеет 440K аннотаций, а наименее часто используемый класс - только 22. Чтобы справиться с дисбалансом классов, я разделил классы на 6 групп по частоте.

Кроме того, классы наборов данных образуют иерархию:

Я использовал только листовые классы для обучения, что на самом деле было ошибкой. Результат мог бы быть лучше, если бы в модели было больше данных. Проблема в том, что не все объекты строго помечены. Например, некоторые люди обозначены как Man, некоторые - как Woman, а некоторые - как Person (родитель обоих этих классов).

Мой тренировочный процесс

По сути, я тренировал RetinaNet со всеми этими прибамбасами в некоторых комбинациях. Обучение длилось 1-2 дня на одном ТПУ. TPUv3 допускает размер пакета до 64 с образами 1024x1024 (у него 16 ГБ памяти HBM на каждое ядро, как минимум 8 ядер).

Еще одна проблема заключалась в том, что у меня были предварительно натренированные веса только для ResNet50. Но у меня было много возможностей TPU, поэтому я просто загрузил ImageNet и сделал свои собственные предварительно обученные модели с ResNet101 / 152/200. Это действительно отлично сработало.

Я добавил поддержку магистрали EfficientNet, но, к сожалению, она не сработала. Я обучил несколько моделей с помощью EfficientNet, но прошлой ночью столкнулся с какой-то странной ошибкой экспорта, поэтому они не были частью окончательного решения.

Ансамбль

Поскольку это было для Kaggle, мы должны построить ансамбль мощных моделей, чтобы получить наилучший результат. Для этого можно использовать подавление Non-Max, но лучшим подходом является Soft-NMS (https://arxiv.org/abs/1704.04503). Код выглядит так: https://github.com/artyompal/tpu_models/blob/master/scripts/inference/soft_nms.pyx.

Последние мысли

TPU действительно мощные. Конечно, нам действительно нужна стабильная реализация PyTorch :). Я надеюсь, что TPU упростят глубокое обучение и частично остановят монополию Nvidia, что, надеюсь, сделает облачные вычисления более доступными.