JAX помогает вам писать короткие, простые и молниеносно быстрые вычисления. В отличие от TensorFlow и PyTorch, JAX использует функциональный стиль программирования. То есть функции должны быть чистыми: побочные эффекты не допускаются. На практике мы по-прежнему хотим повторно использовать компоненты, такие как слои нейронной сети. Haiku от DeepMind объединяет мир функционального программирования JAX с объектно-ориентированным программированием (ООП). Короче говоря, Haiku позволяет вам смешивать и сочетать модули — в стиле ООП — и затем преобразовывать их в функциональную программу.

Кроме того, Haiku поставляется с большой библиотекой компонентов нейросети: от простых сверток и модулей внимания до полностью обученных моделей ResNet. И, наконец, мой личный фаворит: Haiku поставляется с генератором псевдослучайных чисел (PRNG) итератором ключей (haiku.PRNGSequence), который облегчает задачу разделения ключей.

Рабочий процесс

✨✨Три простых шага к просветлению✨✨

transform: Сначала напишите функцию, содержащую модули Haiku. Затем вы очищаете свою функцию, украшая ее haiku.transform следующим образом:

import haiku as hk

@hk.transform
def forward(x):
 neural_net = hk.nets.MLP([300, 100, 10])
 return neural_net(x)

Здесь мы создали обычную нейронную сеть (многослойный персептрон или MLP) и использовали x для прогнозирования. Технически forward теперь является преобразованным экземпляром, содержащим две чистые функции: Transformed.init и Transformed.apply.

🔢init: Ваша нейронная сеть в forward имеет множество параметров. Haiku нужно сделать один проход через вашу функцию, чтобы отследить и инициализировать параметры. Вы делаете это, вызывая метод init вашей преобразованной функции с ключом генератора псевдослучайных чисел.

import jax

key_seq = hk.PRNGSequence(42)
params = forward.init(next(key_seq), x)

Это вернет pytree params с начальными весами и смещениями вашей очищенной функции. Один набор для каждого слоя. Обратите внимание, что мы использовали next(key_seq) для генерации нового ключа из ключевой последовательности генератора псевдослучайных чисел Haiku.

🔨apply: этот метод является полным аналогом вашей оригинальной функции forward. Теперь, когда у вас есть сопутствующий params, вы можете вызвать функцию для пакета x следующим образом:

logits = forward.apply(params, next(key_seq), x)

Готово!

Для дальнейшего чтения я рекомендую документацию Основы Haiku для краткого введения в Haiku.