JAX — это функциональная вычислительная библиотека, разработанная Google, она предоставила некоторые примитивные строительные блоки для векторных вычислений и автоматического дифференцирования.

Самый простой способ объяснить обучение модели нейронной сети состоит в том, что оно нацелено на поиск параметров модели, чтобы минимизировать потери между фактическим значением и прогнозом модели (также известная как функция прямой подачи). Затем обучение формулируется как задача оптимизации. Один из распространенных способов ее решения — получить градиент функции потерь. Ядро JAX обеспечивает автоматическое дифференцирование любой функции jax по jax.grad. Это делает JAX идеально подходящим для обучения нейронных сетей.

Чтобы использовать автоматическое дифференцирование JAX, мы должны:

(1) преобразовать функцию потерь в чистую функцию JAX без сохранения состояния,

(2) отделить параметры модели от класса объекта.

Это можно легко сделать с помощью библиотеки Python Haiku. Ядром библиотеки Haiku являются hk.Module и hk.transform.

Помните, что наша цель — минимизировать функцию потерь, вычислив ее градиент. Для этого функция потерь должна быть чистой функцией JAX.

В случае простой задачи регрессии функция потерь может быть просто средним квадратом разницы прогноза и фактического значения: jnp.mean(jnp.square(pred — y)). JAX уже предоставляет функции mean и square так же, как и numpy. Однако прогноз pred вычисляется функцией прямой подачи модели, которую необходимо преобразовать в чистую функцию JAX.

Функция прямой подачи определена вunroll_net, о которой я расскажу подробнее позже. И преобразование выполняется model = hk.transform(unroll_net), где он превращает unroll_net в объект (model), который содержит init и apply чистые методы JAX. Эти 2 метода будут отделять параметры модели от модели нейронной сети с отслеживанием состояния.

Метод model.init случайным образом инициализирует параметры модели фиктивными даннымиsample_x. Он не будет хранить параметры, а передаст их как возвращаемое значение params для дальнейшего обновления позже при обучении модели.

Метод model.apply выполнит вычисление прямой подачи, передав параметры инициализированной модели params и каждую точку данных x. Вот код:

Преобразование в чистую функцию JAX:

model = hk.transform(unroll_net)

Инициализировать параметры модели (params):

# Initialize model parameters with dummy data sample_x
rng = jax.random.PRNGKey(428)
params = model.init(rng, sample_x)

Расчет прямой подачи модели:

pred, _ = model.apply(params, None, x)

Окончательная функция потерь:

def loss(params, x, y):
  pred, _ = model.apply(params, None, x)
  return jnp.mean(jnp.square(pred - y))

Поскольку model.apply и jnp.mean/jnp.square являются чистыми функциями JAX, мы затем вычисляем градиент этой чистой функции потерь JAX. jax.value_and_grad(loss) вернет функцию, которая принимает те же аргументы, что и функция loss, но также вернет градиент (grads) и значение потерь (l).

l, grads = jax.value_and_grad(loss)(params, x, y)

Градиент используется для обновления параметров модели (params) оптимизатором ADAM opt = optax.adam(1e-4).

Инициализируйте параметры оптимизатора:

opt_state = opt.init(params)

Обновление параметров модели и параметров оптимизатора

grads, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, grads)

Последний цикл обучения будет повторять вычисление градиента и потерь и обновлять параметры модели для каждой партии точек данных.

Теперь мы понимаем, как в целом выполняется обучение модели. Затем давайте углубимся в определение модели, начав с линейных слоев нейронной сети для классификации. Простая функция прямой подачи — просто передать точку данных (x) в последовательности (mlp) 3 линейных слоев (hk.Linear).

def forward_fn(vec: jnp.ndarray) -> jnp.ndarray:  
    mlp = hk.Sequential([      
               hk.Flatten(),      
               hk.Linear(300), jax.nn.relu,      
               hk.Linear(100), jax.nn.relu,  
               hk.Linear(NUM_CLASSES),  ])  
    return mlp(x)

Помимо базового линейного слоя (hk.Linear), в Haiku уже определено множество других предопределенных модулей, таких как модель LSTM. Однако подача вперед более сложна. Вот определение функции прямой подачи модели (unroll_net):

def unroll_net(seqs: jnp.ndarray):
  """Unrolls an LSTM over seqs, mapping each output to a scalar."""
  # seqs is [T, B, F].
  core = hk.LSTM(32)
  batch_size = seqs.shape[1]
  outs, state = 
hk.dynamic_unroll(core, seqs, core.initial_state(batch_size))
  return hk.BatchApply(hk.Linear(1))(outs), state

Основным сетевым модулем является hk.LSTM(32), который имеет 32 скрытых состояния. Поскольку RNN/LSTM содержит скрытые состояния памяти, мы не можем просто передать данные в сетевой модуль, как это было сделано в линейном случае. Это должно быть выполнено hk.dynamic_unroll. Он вызывает core для каждого элемента входной последовательности (seqs) в цикле, перенося состояние.

Хотя в Haiku есть много предопределенных общих модулей, иногда нам нужно будет написать их самим, если они не подходят для нашего приложения. В таком случае пользовательский модуль можно определить в Haiku, расширив classhk.Module, который аналогичен модулю PyTorch, расширяющему torch.nn.Module.

Нам нужно реализовать только 2 метода класса: __init__ и __call__, которые на самом деле являются прямой функцией в PyTorch. Ниже приведен пример, показанный на Haiku GitHub:

class MyLinear(hk.Module):
def __init__(self, output_size, name=None):
    super().__init__(name=name)
    self.output_size = output_size
def __call__(self, x):
    j, k = x.shape[-1], self.output_size
    w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))
    w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
    b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros)
    return jnp.dot(x, w) + b

Чтобы отделить состояние от модуля, каждый именованный параметр модели должен быть получен через hk.get_parameter в проходе прямой подачи (__call__). Причина отказа от использования свойств объекта для хранения параметров модели заключается в том, что функция может быть преобразована в чистую функцию JAX с использованием hk.transform, как мы объясняли в начале.

В папке Haiku Git есть много примеров кода. Я могу легко изменить haiku_lstms.ipynb, чтобы создать предсказатель временных рядов для данных котировок акций.