Пробуем новый фреймворк машинного обучения

В ноябре 2018 года на GitHub Google появилась новая библиотека: JAX. В 2020 году произошел скачок популярности, и ходят слухи, что гуглеры предпочитают его TensorFlow v2. Функции, написанные на JAX, выполняются в системах с несколькими GPU или TPU без неудобных вспомогательных библиотек и без перемещения данных из памяти устройства.

Как разработчик, обычно использующий AutoKeras и Transformers для решения задач нейронных сетей, я с осторожным оптимизмом смотрю на JAX. Сам по себе он слишком низкоуровневый, ближе к NumPy. Если вы попробуете одну из фреймворков, которая добавляет поддержку нейронной сети, вы обнаружите растущее сообщество, но не так много примеров или потоков StackOverflow, как PyTorch и TensorFlow.

Для людей, которые верят в будущее JAX, этот пробел - возможность построить следующую большую вещь!
В этом проекте я попытался решить задачу классификации цветов из Kaggle, используя четыре фреймворка JAX: Flax / Linen. , Haiku, Objax и Elegy. Эти сети принадлежат Google, DeepMind и Poets-AI.

Нейронные сети

Если вы не знакомы с циклом обучения, каждый шаг обучения сети включает в себя:

  • ввод данных обучения
  • сделать прогноз с помощью сети
  • используйте потерю, чтобы сравнить прогноз и реальность (эта функция может просто измерить разницу или потерю веса в зависимости от серьезности ошибки)
  • используйте оптимизатор для исправления модели (существуют разные стратегии в зависимости от стадии обучения, количества потерь и т. д.)

Этот процесс повторяется для каждой партии и каждой эпохи обучения.

Лен (и лен)

Flax - это высокопроизводительная библиотека нейронных сетей для JAX, разработанная для обеспечения гибкости.

Я считаю, что Flax был первым общедоступным JAX-фреймворком. После того, как я начал писать о своем проекте в Твиттере, разработчик порекомендовал перейти на их бета-версию API под названием Linen. Я взял их пример с Imagenet и внес несколько небольших изменений, чтобы запустить его в CoLab с набором данных Imagenette. Набор данных Imagenet слишком велик для CoLab и имеет много классов (1000), поэтому с 10-классными Imagenette и их дополнительными проектами (Imagewoof) легче поиграть.

Затем я могу внести еще несколько изменений для загрузки задачи цветов в формате TFRecord, представленном на Kaggle.
Вы можете задаться вопросом, эй, если JAX является альтернативой TensorFlow, почему я все еще импортирую TensorFlow и загружаю данные в форматах TFRecord / tf.data.Dataset?

  • В этом полустандартном формате доступно множество наборов данных / тестов (есть некоторые досадные различия, но что вы можете сделать)
  • Наборы данных можно разделить на пакеты и управлять ими, поэтому мы не пытаемся загружать полный набор данных в оперативную память за один раз.
  • В наборах данных изображений есть инструменты для обрезки, зеркального отображения, раскрашивания, маскирования и т. Д., Чтобы расширить данные обучения и сделать модель более гибкой.

Эти инструменты существуют и в экосистеме PyTorch, но JAX ближе к Google.

Хайку

Haiku - это простая библиотека нейронных сетей для JAX, которая позволяет пользователям использовать знакомые объектно-ориентированные модели программирования, обеспечивая при этом полный доступ к преобразованиям чистых функций JAX.

Пример ImageNet от Haiku показывает, как он включает ResNet101 в качестве одного из стандартных строительных блоков. Я успешно адаптировал их пример для Imagenette и набора данных классификации цветов.
Одно особенно странное отличие Haiku (или, по крайней мере, способ настройки этого примера) состоит в том, что мне нужно было установить количество тренировок, test , и примеры проверки, вместо передачи отдельных наборов данных или установки% разделения.

Набор для проверки должен быть кратен размеру оценочного пакета, а Imagenette v0.1.0 имеет только 500 тестовых изображений, поэтому я продолжал получать ошибки на этапе eval после того, как этап поезда успешно использовал тот же код.

Это поднимает еще один вопрос об ошибках! Обсуждение JAX / NVIDIA объясняет, что JIT-компилятор JAX должен принимать хороший код, а не молча давать сбой или вести себя странно, когда вы ошибаетесь. В моем коротком опыте работы с этими фреймворками это означало длинные стеки ошибок, в которых говорилось, что левая сторона не соответствует правой части или pmap получил аргумент ранга _, где зрелая структура могла бы сказать: вы указали неправильное количество классов или ожидаемых размеров: [устройства, размер_пакета, высота, ширина, цвета], чтобы избежать путаницы / ужаса.

Objax

Objax разработан исследователями для исследователей с упором на простоту и понятность.

Objax - это новый фреймворк, который я впервые попробовал. Я смог прочитать их пример Imagenet и напрямую перенести его на задачу о цветах, не практикуясь на Imagenette.
Был озадаченный шаг, когда мне нужно было транспонировать изображения из CHW в HWC [height, width, color_channels] order, и ответы StackOverflow заставили меня задействовать TensorFlow. Я не понимаю, где я могу добавить в свой код стандартные TensorFlow, NumPy или другие операции без потери производительности, которую обещает JAX. А пока я рад, что код работает.

Элегия

Elegy - это фреймворк нейронных сетей, основанный на Jax, вдохновленный Керасом.

Новейшая из стаи, творение poets-ai. Первоначально эта публикация пропустила Elegy, но с тех пор репозиторий был переработан с добавлением новых примеров и модулей ResNet. Их пример Imagenet было легко приспособить к блокноту Imagenette. Здесь мы видим:

  • как измеряется убыток (SparseCategoricalCrossentropy)
  • создание оптимизатора для конкретной проблемы
  • оболочка tfds2jax_generator для устранения неудобств между наборами данных TensorFlow / пакетной обработкой и циклами JAX.

Дополнительные библиотеки, использующие JAX в нейронных сетях, seq2seq, вероятностное программирование и т. Д .: https://news.ycombinator.com/item?id=22814870

Обновления?

Эта статья написана в октябре 2020 года. Последние рекомендуемые библиотеки и руководства можно найти на https://github.com/n2cholas/awesome-jax или на моей странице github.com/mapmeld/use-this-now/blob/main. /README.md#jax-tutorials