Пробуем новый фреймворк машинного обучения
В ноябре 2018 года на GitHub Google появилась новая библиотека: JAX. В 2020 году произошел скачок популярности, и ходят слухи, что гуглеры предпочитают его TensorFlow v2. Функции, написанные на JAX, выполняются в системах с несколькими GPU или TPU без неудобных вспомогательных библиотек и без перемещения данных из памяти устройства.
Как разработчик, обычно использующий AutoKeras и Transformers для решения задач нейронных сетей, я с осторожным оптимизмом смотрю на JAX. Сам по себе он слишком низкоуровневый, ближе к NumPy. Если вы попробуете одну из фреймворков, которая добавляет поддержку нейронной сети, вы обнаружите растущее сообщество, но не так много примеров или потоков StackOverflow, как PyTorch и TensorFlow.
Для людей, которые верят в будущее JAX, этот пробел - возможность построить следующую большую вещь!
В этом проекте я попытался решить задачу классификации цветов из Kaggle, используя четыре фреймворка JAX: Flax / Linen. , Haiku, Objax и Elegy. Эти сети принадлежат Google, DeepMind и Poets-AI.
Нейронные сети
Если вы не знакомы с циклом обучения, каждый шаг обучения сети включает в себя:
- ввод данных обучения
- сделать прогноз с помощью сети
- используйте потерю, чтобы сравнить прогноз и реальность (эта функция может просто измерить разницу или потерю веса в зависимости от серьезности ошибки)
- используйте оптимизатор для исправления модели (существуют разные стратегии в зависимости от стадии обучения, количества потерь и т. д.)
Этот процесс повторяется для каждой партии и каждой эпохи обучения.
Лен (и лен)
Flax - это высокопроизводительная библиотека нейронных сетей для JAX, разработанная для обеспечения гибкости.
Я считаю, что Flax был первым общедоступным JAX-фреймворком. После того, как я начал писать о своем проекте в Твиттере, разработчик порекомендовал перейти на их бета-версию API под названием Linen. Я взял их пример с Imagenet и внес несколько небольших изменений, чтобы запустить его в CoLab с набором данных Imagenette. Набор данных Imagenet слишком велик для CoLab и имеет много классов (1000), поэтому с 10-классными Imagenette и их дополнительными проектами (Imagewoof) легче поиграть.
Затем я могу внести еще несколько изменений для загрузки задачи цветов в формате TFRecord, представленном на Kaggle.
Вы можете задаться вопросом, эй, если JAX является альтернативой TensorFlow, почему я все еще импортирую TensorFlow и загружаю данные в форматах TFRecord / tf.data.Dataset?
- В этом полустандартном формате доступно множество наборов данных / тестов (есть некоторые досадные различия, но что вы можете сделать)
- Наборы данных можно разделить на пакеты и управлять ими, поэтому мы не пытаемся загружать полный набор данных в оперативную память за один раз.
- В наборах данных изображений есть инструменты для обрезки, зеркального отображения, раскрашивания, маскирования и т. Д., Чтобы расширить данные обучения и сделать модель более гибкой.
Эти инструменты существуют и в экосистеме PyTorch, но JAX ближе к Google.
Хайку
Haiku - это простая библиотека нейронных сетей для JAX, которая позволяет пользователям использовать знакомые объектно-ориентированные модели программирования, обеспечивая при этом полный доступ к преобразованиям чистых функций JAX.
Пример ImageNet от Haiku показывает, как он включает ResNet101 в качестве одного из стандартных строительных блоков. Я успешно адаптировал их пример для Imagenette и набора данных классификации цветов.
Одно особенно странное отличие Haiku (или, по крайней мере, способ настройки этого примера) состоит в том, что мне нужно было установить количество тренировок, test , и примеры проверки, вместо передачи отдельных наборов данных или установки% разделения.
Набор для проверки должен быть кратен размеру оценочного пакета, а Imagenette v0.1.0 имеет только 500 тестовых изображений, поэтому я продолжал получать ошибки на этапе eval после того, как этап поезда успешно использовал тот же код.
Это поднимает еще один вопрос об ошибках! Обсуждение JAX / NVIDIA объясняет, что JIT-компилятор JAX должен принимать хороший код, а не молча давать сбой или вести себя странно, когда вы ошибаетесь. В моем коротком опыте работы с этими фреймворками это означало длинные стеки ошибок, в которых говорилось, что левая сторона не соответствует правой части или pmap получил аргумент ранга _, где зрелая структура могла бы сказать: вы указали неправильное количество классов или ожидаемых размеров: [устройства, размер_пакета, высота, ширина, цвета], чтобы избежать путаницы / ужаса.
Objax
Objax разработан исследователями для исследователей с упором на простоту и понятность.
Objax - это новый фреймворк, который я впервые попробовал. Я смог прочитать их пример Imagenet и напрямую перенести его на задачу о цветах, не практикуясь на Imagenette.
Был озадаченный шаг, когда мне нужно было транспонировать изображения из CHW в HWC [height, width, color_channels] order, и ответы StackOverflow заставили меня задействовать TensorFlow. Я не понимаю, где я могу добавить в свой код стандартные TensorFlow, NumPy или другие операции без потери производительности, которую обещает JAX. А пока я рад, что код работает.
Элегия
Elegy - это фреймворк нейронных сетей, основанный на Jax, вдохновленный Керасом.
Новейшая из стаи, творение poets-ai. Первоначально эта публикация пропустила Elegy, но с тех пор репозиторий был переработан с добавлением новых примеров и модулей ResNet. Их пример Imagenet было легко приспособить к блокноту Imagenette. Здесь мы видим:
- как измеряется убыток (
SparseCategoricalCrossentropy
) - создание оптимизатора для конкретной проблемы
- оболочка
tfds2jax_generator
для устранения неудобств между наборами данных TensorFlow / пакетной обработкой и циклами JAX.
Дополнительные библиотеки, использующие JAX в нейронных сетях, seq2seq, вероятностное программирование и т. Д .: https://news.ycombinator.com/item?id=22814870
Обновления?
Эта статья написана в октябре 2020 года. Последние рекомендуемые библиотеки и руководства можно найти на https://github.com/n2cholas/awesome-jax или на моей странице github.com/mapmeld/use-this-now/blob/main. /README.md#jax-tutorials