Обучение нейронной сети в десять раз быстрее с помощью Jax на TPU

В наши дни все крутые ребята, кажется, в восторге от JAX. Deepmind широко использует его для своих исследований и даже строит на его основе собственную экосистему. Борис Дайма и его команда в кратчайшие сроки построили DALL·E Mini с использованием JAX и TPU. Его определенно стоит проверить на Hugging Face, где вы уже найдете более 5000 моделей, написанных на JAX. Но что такое JAX и почему он такой особенный? Согласно их веб-сайту, JAX предлагает автоматическую дифференциацию, векторизацию и своевременную компиляцию как для графических процессоров, так и для TPU посредством компонуемых преобразований. Звучит сложно? Не волнуйтесь, в этом посте мы проведем для вас экскурсию и покажем, как работает JAX, чем он отличается от Tensorflow/Pytorch и почему мы считаем его очень интересным фреймворком.

Что такое ДЖАКС?

JAX — это высокопроизводительная платформа для числовых вычислений и машинного обучения от Google Research, которая работает очень быстро на GPU и TPU, не беспокоясь о низкоуровневых деталях. Цель JAX заключалась в создании фреймворка, сочетающего высокую производительность с выразительностью и простотой использования Python, чтобы исследователи могли экспериментировать с новыми моделями и методами без необходимости высокооптимизированных низкоуровневых реализаций C/C++. Это достигается за счет использования компилятора Google XLA (ускоренной линейной алгебры) для создания эффективного машинного кода, а не использования предварительно скомпилированных ядер. Одна из замечательных особенностей JAX заключается в том, что он не зависит от ускорителей, а это означает, что один и тот же код Python может эффективно работать как на GPU, так и на TPU.

JAX работает через составные преобразования функций, это означает, что JAX берет функцию и создает новую функцию, которая интерпретируется по-разному, и что несколько преобразований могут быть объединены в цепочку. Например, автоматическое дифференцирование — это преобразование, которое генерирует производную функции, в то время как автоматическая векторизация берет функцию, которая работает с одной точкой данных, и преобразует ее в функцию, которая работает с пакетом точек данных. Благодаря этим преобразованиям JAX позволяет программисту оставаться в мире Python высокого уровня и позволяет компилятору выполнять тяжелую работу, создавая высокоэффективный код, необходимый для обучения сложных моделей. Мы рассмотрим эти преобразования и применим их в примере, где мы строим простой многослойный персептрон.

Чем он отличается от Tensorflow и Pytorch?

JAX — это среда, ориентированная на компилятор, что означает, что компилятор отвечает за преобразование функций Python в эффективный машинный код. С другой стороны, Tensorflow и Pytorch имеют предварительно скомпилированные ядра GPU и TPU для каждой операции. Во время выполнения программы TensorFlow каждая операция отправляется отдельно. Хотя сами операции очень хорошо оптимизированы, их объединение требует большого количества операций с памятью, что приводит к снижению производительности. Компилятор XLA может генерировать код для всей функции. Он может использовать всю эту информацию для объединения операций и экономии тонны операций с памятью и, таким образом, генерировать в целом более быстрый код.

JAX также легче, чем Tensorflow и Pytorch, потому что нет необходимости реализовывать каждую операцию, функцию или модель отдельно. Вместо этого JAX реализует NumPy API с более простыми и более низкоуровневыми операциями, которые можно использовать в качестве строительных блоков и объединять компилятором в сложные модели и функции.

Компилятор, гораздо более мощный, чем вы думаете

Дизайн, ориентированный на компилятор, намного мощнее, чем может показаться на первый взгляд. С компилятором больше нет необходимости реализовывать низкоуровневый код ускорителя. Это позволяет исследователям значительно повысить свою производительность и открывает двери для экспериментов с новыми архитектурами моделей. Исследователи даже могут экспериментировать с GPU и TPU без необходимости переписывать свой код. Но как это работает?

JAX компилируется не напрямую в машинный код, а в промежуточное представление, независимое от высокоуровневого кода Python и машинного кода. Компилятор разделен на внешний интерфейс, который компилирует функции Python в IR, и бэкенд, который компилирует IR в машинный код для конкретной платформы. Этот дизайн не нов, примером компилятора, который также следует этому дизайну, является LLVM. Существуют интерфейсы как для C, так и для Rust, которые переводят высокоуровневый код в LLVM IR. Затем серверная часть может генерировать машинный код для различных поддерживаемых типов машин, независимо от того, был ли исходный код написан на C или Rust.

Это очень важно, потому что благодаря такому гибкому дизайну можно создать новый ускоритель, написать для него бэкенд XLA, и ваш JAX-код, который ранее выполнялся на GPU/TPU, может выполняться на новом ускорителе. С другой стороны, вы также можете создать структуру на другом языке программирования, который компилируется в JAX IR, и вы можете использовать графические процессоры и TPU благодаря XLA.

Если этот подход на основе компилятора работает намного лучше, чем предварительно скомпилированные ядра, почему Tensorflow и Pytorch не использовали его с самого начала? Ответ довольно прост: очень сложно разработать хороший числовой компилятор. Благодаря автоматическому дифференцированию, векторизации и jit-компиляции JAX имеет в своем арсенале несколько действительно мощных инструментов. Однако JAX также не является серебряной пулей, все эти преимущества имеют небольшую цену, вам нужно изучить несколько новых приемов и концепций, связанных с функциональным программированием.

Замечание по функциональному программированию.

JAX не может преобразовывать любую функцию Python, он может преобразовывать только чистые функции. Чистая функция может быть определена как функция, которая зависит только от своих входных данных, что означает, что для данного входного значения x она всегда будет возвращать один и тот же результат y и что она не производит никаких побочных эффектов, таких как операции ввода-вывода или изменение глобальных переменных. . Динамизм Python означает, что поведение функции меняется в зависимости от типов ее входных данных, и JAX хочет использовать этот динамизм, преобразовывая функции во время выполнения. В начале преобразования JAX проверяет, что делает функция для набора заданных входных данных, и преобразует функцию на основе этой информации. Под капотом JAX отслеживает функцию, как и интерпретатор Python. Разрешая использовать только чистые функции, своевременное преобразование функций становится намного проще и быстрее.

Представьте, что трассировщик должен иметь дело с побочными эффектами, такими как ввод-вывод, это означает, что может произойти неожиданное поведение, например, пользователь введет неверные данные, что значительно усложняет создание эффективного кода, особенно когда в игре есть ускорители. Глобальные переменные могут изменяться между двумя вызовами функций и, таким образом, полностью менять поведение функции, в которой они используются, что делает преобразованную функцию недействительной. Если вас интересуют компиляторы и мельчайшие детали того, как работает трассировка JAX, мы рекомендуем вам ознакомиться с документацией для получения более подробной информации о ее внутренней работе.

Единственный недостаток JAX заключается в том, что он не может проверить, является ли функция чистой функцией. Программист должен убедиться, что он пишет чистые функции, иначе JAX преобразует функцию с неожиданным поведением.

Представление объектов с состоянием с помощью Pytrees

Работа с чистыми функциями также влияет на то, как используются структуры данных. В других фреймворках модели машинного обучения часто представляются в виде состояний, однако это противоречит парадигме функционального программирования, поскольку это мутация глобального состояния. Чтобы решить эту проблему, JAX вводит pytrees, древовидные структуры, построенные из контейнероподобных объектов Python. Контейнероподобные классы можно зарегистрировать в реестре pytree, который по умолчанию содержит списки, кортежи и словари. Pytree могут содержать другие pytree, а классы, не зарегистрированные в реестре pytree, считаются листьями. Листья можно рассматривать как неизменяемые входные данные для чистой функции. Для каждого класса в реестре pytree есть функция, которая преобразует pytree в кортеж с его дочерними элементами и необязательными метаданными, а также функция, которая преобразует дочерние элементы и метаданные обратно в контейнероподобный тип. Эти функции можно использовать для обновления модели или любых других объектов с состоянием, которые вы используете.

Давайте изменим код!

Прежде чем мы углубимся в наш пример MLP, мы покажем наиболее важные преобразования в JAX.

Автоматическая дифференциация

Первое преобразование — это автоматическое дифференцирование, когда мы берем функцию Python в качестве входных данных и возвращаем функцию, представляющую градиент этой функции. Отличительной особенностью автодиффа в JAX является то, что он может различать функции Python, которые используют, и контейнеры Python, условные операторы, циклы и т. д. В следующем примере мы создаем функцию, представляющую градиент функции tanh. Поскольку преобразования JAX являются составными, мы можем использовать n вложенных вызовов функции grad для преобразования для вычисления n-й производной.

Автоматическая дифференциация JAX — это мощный и обширный инструмент. Если вы хотите узнать больше о том, как она работает, мы рекомендуем вам прочитать Поваренную книгу JAX Autodiff.

Автоматическая векторизация

При обучении модели вы обычно распространяете набор обучающих выборок по вашей модели. Таким образом, при реализации модели вам придется думать о своей функции прогнозирования как о функции, которая принимает пакет выборок и возвращает прогноз для каждой выборки. Однако это может значительно усложнить реализацию, а также снизить читабельность функции по сравнению с функцией, которая будет работать с одним образцом. Наступает второе преобразование: автоматическая векторизация. Мы пишем нашу функцию так, как если бы мы обрабатывали только один образец, тогда vmap преобразует его в векторизованную версию.

В начале vmap может быть немного сложным, особенно при работе с большими измерениями, но это действительно мощная трансформация. Мы рекомендуем вам ознакомиться с некоторыми примерами в документации, чтобы полностью понять его потенциал.

Своевременная компиляция

Третье преобразование функции — это своевременная компиляция. Цель этого преобразования — повысить производительность, распараллелить код и запустить его на ускорителе. JAX компилируется не напрямую в машинный код, а в промежуточное представление. Это промежуточное представление не зависит от кода Python и машинного кода ускорителя. Затем компилятор XLA возьмет промежуточное представление и скомпилирует его в эффективный машинный код.

Не всегда легко решить, когда и какой код вам следует скомпилировать, чтобы оптимально использовать компилятор, мы рекомендуем вам ознакомиться с документацией. Позже в этом блоге мы немного углубимся в разработку компилятора и в то, почему это делает JAX такой мощной структурой.

Обучение MLP за 5 минут с использованием TPU

Теперь, когда мы узнали о самых важных преобразованиях, мы готовы применить эти знания на практике. Мы реализуем MLP с нуля для классификации изображений MNIST и очень быстро обучаем его на TPU. Наша нейронная сеть будет иметь входной слой из 728 входных переменных, за которым следуют два скрытых слоя с 512 и 256 нейронами соответственно и выходной слой с узлом для каждого класса.

Инициализация модели

Первое, что нам нужно сделать, это создать структуру, которая представляет нашу модель. В качестве входных данных нашей функции инициализации у нас есть список с количеством узлов в каждом слое нашей нейронной сети. У нас есть входной слой, равный количеству пикселей изображения, за которым следуют два скрытых слоя с 512 и 256 нейронами соответственно, и выходной слой, равный количеству классов. Мы используем массивы JAX numpy для инициализации модели на ускорителе, избегая ручного копирования этих данных.

Обратите внимание, что генерация случайных чисел немного отличается от numpy. Мы хотим иметь возможность генерировать случайные числа на параллельных ускорителях, и нам нужен генератор случайных чисел, который хорошо работает с парадигмой функционального программирования. Алгоритм Numpy для генерации случайных чисел не очень подходит для этих целей. Ознакомьтесь с примечаниями к дизайну и документацией JAX для получения дополнительной информации.

Прогноз

Наш следующий шаг — написать функцию прогнозирования, которая будет назначать метки пакету изображений. Мы будем использовать автоматическую векторизацию, чтобы преобразовать функцию, которая принимает одно изображение в качестве входных данных и выводит метку, в функцию, которая предсказывает метки для пакета входных данных. Написание функции прогнозирования не очень сложно, мы проходим через скрытые слои сети и применяем веса и смещения посредством умножения матриц и сложения векторов, а также применяем функцию активации RELU. В конце мы вычисляем выходную метку с помощью функции RealSoftMax. Когда у нас есть функция для маркировки одного изображения, мы можем преобразовать его с помощью vmap, чтобы он мог обрабатывать пакет входных данных.

Функция потери

Функция потерь берет пакет изображений и вычисляет среднюю абсолютную ошибку. Мы вызываем наши пакетные прогнозы и вычисляем метку для каждого изображения, сравниваем ее с метками истинности, закодированными одним горячим кодом, и вычисляем среднее количество ошибок.

Функция обновления

Теперь, когда у нас есть функция прогнозирования и потерь, мы реализуем функцию обновления для итеративного обновления нашей модели на каждом этапе обучения. Наша функция обновления принимает пакет изображений и их метки истинности вместе с текущей моделью и скоростью обучения. Мы рассчитываем как значение потерь, так и значение их градиента. Мы обновляем модель, используя скорость обучения и градиенты потерь. Поскольку мы хотим скомпилировать эту функцию, мы должны преобразовать обновленную модель в pytree. Мы также возвращаем значение потери для контроля точности.

Теперь, когда у нас есть функция обновления, мы собираемся скомпилировать ее, чтобы она могла работать на TPU и значительно повысить ее производительность. Вложенные функции, вызываемые в функции обновления, также будут скомпилированы и оптимизированы. Причина, по которой мы применяем преобразование компиляции только при обновлении, а не к каждой функции отдельно, заключается в том, что мы хотим предоставить компилятору как можно больше информации для работы, чтобы он мог максимально оптимизировать код.

Обучение нашей модели

Мы можем определить функцию точности (и, возможно, другие показатели) и создать цикл обучения, используя нашу функцию обновления и исходную модель в качестве входных данных. Теперь мы готовы обучить нашу модель с помощью TPU или GPU.

Заключение

Фух, сегодня мы многому научились. Сначала мы начали описывать JAX как фреймворк с составными преобразованиями функций. Четыре основных преобразования — это автоматическая векторизация, автоматическое распараллеливание на нескольких ускорителях, автоматическое дифференцирование функций Python и функции JIT-компиляции для их запуска на ускорителях. Мы углубились во внутреннюю работу JAX и узнали, как он может создавать такие эффективные функции, которые работают как на GPU, так и на TPU, путем компиляции в IR, который затем преобразуется в вызовы XLA. Этот подход позволяет исследователям экспериментировать с новыми методами машинного обучения, не беспокоясь о низкоуровневой высокооптимизированной версии своего кода. Мы надеемся, что инженеры-программисты также взволнованы тем, что новые библиотеки могут быть построены поверх JAX, а потенциальные ускорители могут быть быстро приняты.