JAX — это функциональная вычислительная библиотека, разработанная Google, она предоставила некоторые примитивные строительные блоки для векторных вычислений и автоматического дифференцирования.
Самый простой способ объяснить обучение модели нейронной сети состоит в том, что оно нацелено на поиск параметров модели, чтобы минимизировать потери между фактическим значением и прогнозом модели (также известная как функция прямой подачи). Затем обучение формулируется как задача оптимизации. Один из распространенных способов ее решения — получить градиент функции потерь. Ядро JAX обеспечивает автоматическое дифференцирование любой функции jax по jax.grad
. Это делает JAX идеально подходящим для обучения нейронных сетей.
Чтобы использовать автоматическое дифференцирование JAX, мы должны:
(1) преобразовать функцию потерь в чистую функцию JAX без сохранения состояния,
(2) отделить параметры модели от класса объекта.
Это можно легко сделать с помощью библиотеки Python Haiku. Ядром библиотеки Haiku являются hk.Module
и hk.transform
.
Помните, что наша цель — минимизировать функцию потерь, вычислив ее градиент. Для этого функция потерь должна быть чистой функцией JAX.
В случае простой задачи регрессии функция потерь может быть просто средним квадратом разницы прогноза и фактического значения: jnp.mean(jnp.square(pred — y))
. JAX уже предоставляет функции mean
и square
так же, как и numpy
. Однако прогноз pred
вычисляется функцией прямой подачи модели, которую необходимо преобразовать в чистую функцию JAX.
Функция прямой подачи определена вunroll_net
, о которой я расскажу подробнее позже. И преобразование выполняется model = hk.transform(unroll_net)
, где он превращает unroll_net
в объект (model
), который содержит init
и apply
чистые методы JAX. Эти 2 метода будут отделять параметры модели от модели нейронной сети с отслеживанием состояния.
Метод model.init
случайным образом инициализирует параметры модели фиктивными даннымиsample_x
. Он не будет хранить параметры, а передаст их как возвращаемое значение params
для дальнейшего обновления позже при обучении модели.
Метод model.apply
выполнит вычисление прямой подачи, передав параметры инициализированной модели params
и каждую точку данных x
. Вот код:
Преобразование в чистую функцию JAX:
model = hk.transform(unroll_net)
Инициализировать параметры модели (params
):
# Initialize model parameters with dummy data sample_x rng = jax.random.PRNGKey(428) params = model.init(rng, sample_x)
Расчет прямой подачи модели:
pred, _ = model.apply(params, None, x)
Окончательная функция потерь:
def loss(params, x, y): pred, _ = model.apply(params, None, x) return jnp.mean(jnp.square(pred - y))
Поскольку model.apply
и jnp.mean/jnp.square
являются чистыми функциями JAX, мы затем вычисляем градиент этой чистой функции потерь JAX. jax.value_and_grad(loss)
вернет функцию, которая принимает те же аргументы, что и функция loss
, но также вернет градиент (grads
) и значение потерь (l
).
l, grads = jax.value_and_grad(loss)(params, x, y)
Градиент используется для обновления параметров модели (params
) оптимизатором ADAM opt = optax.adam(1e-4)
.
Инициализируйте параметры оптимизатора:
opt_state = opt.init(params)
Обновление параметров модели и параметров оптимизатора
grads, opt_state = opt.update(grads, opt_state) params = optax.apply_updates(params, grads)
Последний цикл обучения будет повторять вычисление градиента и потерь и обновлять параметры модели для каждой партии точек данных.
Теперь мы понимаем, как в целом выполняется обучение модели. Затем давайте углубимся в определение модели, начав с линейных слоев нейронной сети для классификации. Простая функция прямой подачи — просто передать точку данных (x
) в последовательности (mlp
) 3 линейных слоев (hk.Linear
).
def forward_fn(vec: jnp.ndarray) -> jnp.ndarray: mlp = hk.Sequential([ hk.Flatten(), hk.Linear(300), jax.nn.relu, hk.Linear(100), jax.nn.relu, hk.Linear(NUM_CLASSES), ]) return mlp(x)
Помимо базового линейного слоя (hk.Linear
), в Haiku уже определено множество других предопределенных модулей, таких как модель LSTM. Однако подача вперед более сложна. Вот определение функции прямой подачи модели (unroll_net
):
def unroll_net(seqs: jnp.ndarray): """Unrolls an LSTM over seqs, mapping each output to a scalar.""" # seqs is [T, B, F]. core = hk.LSTM(32) batch_size = seqs.shape[1] outs, state = hk.dynamic_unroll(core, seqs, core.initial_state(batch_size)) return hk.BatchApply(hk.Linear(1))(outs), state
Основным сетевым модулем является hk.LSTM(32)
, который имеет 32 скрытых состояния. Поскольку RNN/LSTM содержит скрытые состояния памяти, мы не можем просто передать данные в сетевой модуль, как это было сделано в линейном случае. Это должно быть выполнено hk.dynamic_unroll
. Он вызывает core
для каждого элемента входной последовательности (seqs
) в цикле, перенося состояние.
Хотя в Haiku есть много предопределенных общих модулей, иногда нам нужно будет написать их самим, если они не подходят для нашего приложения. В таком случае пользовательский модуль можно определить в Haiku, расширив classhk.Module
, который аналогичен модулю PyTorch, расширяющему torch.nn.Module
.
Нам нужно реализовать только 2 метода класса: __init__
и __call__
, которые на самом деле являются прямой функцией в PyTorch. Ниже приведен пример, показанный на Haiku GitHub:
class MyLinear(hk.Module): def __init__(self, output_size, name=None): super().__init__(name=name) self.output_size = output_size def __call__(self, x): j, k = x.shape[-1], self.output_size w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j)) w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init) b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros) return jnp.dot(x, w) + b
Чтобы отделить состояние от модуля, каждый именованный параметр модели должен быть получен через hk.get_parameter
в проходе прямой подачи (__call__
). Причина отказа от использования свойств объекта для хранения параметров модели заключается в том, что функция может быть преобразована в чистую функцию JAX с использованием hk.transform
, как мы объясняли в начале.
В папке Haiku Git есть много примеров кода. Я могу легко изменить haiku_lstms.ipynb, чтобы создать предсказатель временных рядов для данных котировок акций.