Примечание редактора. Люк Мец будет спикером на ODSC East 2022 с 19 по 21 апреля. Обязательно ознакомьтесь с его выступлением «Learned Optimizers» здесь!

Поскольку модели машинного обучения продолжают расти, затраты и время на обучение таких моделей становятся все более громоздкими. Эти растущие затраты усложняют как обучение моделей на новых данных, так и проведение исследований для улучшения будущих версий этих моделей. Традиционно для обучения таких моделей используются разработанные вручную алгоритмы оптимизации, такие как стохастический градиентный спуск, или более сложные алгоритмы, такие как Адам. Этот пост посвящен опытным оптимизаторам, которые вместо того, чтобы полагаться на эту поддержку, изучают процедуру оптимизации, которая лучше всего соответствует цели, что приводит к более быстрой оптимизации!

Я считаю, что оптимизация обучения и, в более общем плане, метаобучение — это естественный следующий шаг к революции глубокого обучения. Точно так же изученные функции и вычисления заменили разработанные вручную функции, изученные алгоритмы и вычисления заменят разработанные вручную.

Я лично весьма воодушевлен этим направлением исследований, и последние несколько лет я занимаюсь этим направлением, обучая все больше и больше опытных оптимизаторов.

В этом посте я хотел дать представление о том, как выглядит обучение и использование обученного оптимизатора. Мы будем использовать новый пакет с открытым исходным кодом нашей команды learnedoptimization, который позволяет изучать оптимизатор и метаобучающие исследования в JAX — новой платформе машинного обучения Google.

Этот пост краток. Для более подробного ознакомления, включая начало работы в коллабах и примеры, смотрите нашу документацию.

Обучение опытного оптимизатора с помощью `learned_optimization`

Для начала давайте сначала установим Learn_optimization и импортируем модули, которые нам понадобятся.

!pip установить git+https://github.com/google/learned_optimization.git

import jax
import jax.numpy as jnp
import matplotlib.pylab as plt
import numpy as onp
import tqdm # For fancy progress bars
from colabtools import adhoc_import
from learned_optimization.tasks.fixed import conv
from learned_optimization.tasks import base as tasks_base
from learned_optimization import eval_training
from learned_optimization.optimizers import base as opt_base
from learned_optimization.learned_optimizers import adafac_mlp_lopt
from learned_optimization.outer_trainers import truncated_pes
from learned_optimization.outer_trainers import lopt_truncated_step
from learned_optimization.outer_trainers import truncated_grad
from learned_optimization.outer_trainers import gradient_learner
from learned_optimization.outer_trainers import truncation_schedule

Целевая задача

При создании обученного оптимизатора мы должны определить некоторую задачу или набор задач, с которыми этот обученный оптимизатор должен работать хорошо. Эта задача содержит архитектуру модели, набор данных и функцию потерь. Это может представлять любую проблему оптимизации, но в этом конкретном примере мы сосредоточимся на задачах оптимизации нейронной сети.

Задача, с которой мы будем работать, — это небольшой коннет, обученный на Cifar10. Эта сеть работает с пакетами изображений Cifar10, размер которых изменен до 16 × 16, и использует небольшую сеть из трех скрытых слоев для прогнозирования. Этот коннет уже реализован в `learned_optimization` и находится здесь.

Мы выбрали эту задачу, так как она чрезвычайно мала, и, следовательно, с ней можно быстро поэкспериментировать.

task = conv.Conv_Cifar10_16_32x64x64()

Мы можем инициализировать веса нейронной сети, выбрать пакет данных, вычислить потери или использовать JAX для вычисления градиентов этого.

key = jax.random.PRNGKey(0)
weights = task.init(key)
batch = next(task.datasets.train)
print("loss:", task.loss(weights, key, batch))
grad = jax.grad(task.loss)(weights, key, batch)
print("Gradient shapes:")
jax.tree_map(lambda x: x.shape, grad)
loss: 2.3102112
Gradient shapes:
{'conv2_d': {'b': (32,), 'w': (3, 3, 3, 32)},
 'conv2_d_1': {'b': (64,), 'w': (3, 3, 32, 64)},
 'conv2_d_2': {'b': (64,), 'w': (3, 3, 64, 64)},
 'linear': {'b': (10,), 'w': (64, 10)}}

Оптимизаторы

Прежде чем говорить об опытных оптимизаторах, давайте представим более стандартный, разработанный вручную интерфейс оптимизатора. Это функции, которые принимают значения градиента и создают некоторое новое состояние, содержащее новые значения параметров. Чтобы продемонстрировать, мы можем создать оптимизатор SGD и использовать его, чтобы сделать один шаг (`opt.update`) с искусственным градиентом.

opt = opt_base.SGD(0.1)
params = jnp.ones([3])
opt_state = opt.init(params)
grads = jnp.ones([3])
new_opt_state = opt.update(opt_state, grads)
opt.get_params(new_opt_state)

Обученные оптимизаторы

Обученный оптимизатор — это параметрический оптимизатор, а именно оптимизатор, который является функцией некоторого набора параметров. Можно инициализировать веса этого обученного оптимизатора и использовать эти веса для получения экземпляра оптимизатора, с помощью которого можно выполнять обновления.

Как и в случае с нейронными сетями, существует семейство различных обучаемых оптимизаторов, которые мы можем использовать. Архитектура обученного оптимизатора, которую мы будем использовать в этом посте, была представлена ​​в разделе Практические компромиссы между памятью, вычислениями и производительностью в обученных оптимизаторах. Он состоит из небольшой нейронной сети, которая применяется к каждому параметру.

lopt = adafac_mlp_lopt.AdafacMLPLOpt(hidden_size=32)

Мы можем случайным образом инициализировать набор весов и посмотреть на их структуру. Во-первых, есть только небольшое количество обучаемых параметров — всего 242. Мы также можем видеть, что большинство этих весов параметризуют нейронную сеть — w0 отображает от 39 функций до скрытого размера 4, w1 отображает от 4 до 4, а w2 карты на выходе.

lopt_weights = lopt.init(jax.random.PRNGKey(0))
shapes = jax.tree_map(lambda x: x.shape, lopt_weights)
num_params = sum(map(onp.prod, jax.tree_leaves(shapes)))
print("Total params:", num_params)
print("====")
shapes
Total params: 242
====
{'adafactor_decays': (3,),
 'momentum_decays': (3,),
 'nn': {'~': {'b0': (32,),
   'b1': (32,),
   'b2': (2,),
   'w0': (39, 32),
   'w1': (32, 32),
   'w2': (32, 2)}},
 'rms_decays': (1,)}

Мы можем использовать эти веса для создания экземпляра оптимизатора следующим образом.

opt = lopt.opt_fn(lopt_weights)

Затем этот оптимизатор можно использовать, как и раньше (но на этот раз с дополнительным аргументом потерь для обновления).

opt_state = opt.init(params, num_steps=10)
grads = jnp.ones([3])
loss = 1.0
new_opt_state = opt.update(opt_state, loss=loss, grad=grads)
new_params = opt.get_params(new_opt_state)

Однако в этот момент lopt_weights инициализируются совершенно случайным образом! Это не сделает хорошего оптимизатора без их обучения.

Базовые показатели

Прежде чем обучать нашего обученного оптимизатора, давайте сначала запустим некоторые базовые показатели.

Наша цель — попытаться обучить эту маленькую сеть быстрее, чем задачи, разработанные вручную. Таким образом, мы сравним с обучением этой задачи коннета, когда Адам ищет несколько разных скоростей обучения и для каждой скорости обучения 5 различных случайных инициализаций.

Чтобы упростить эту задачу, мы воспользуемся модулем Learn_optimization.eval_training и, в частности, функцией single_task_training_curves, которая выполняет итерацию в течение num_steps, каждый шаг вычисляет градиент и применяет предоставленный оптимизатор. В дополнение к обучению эта функция также оценивает производительность модели во время обучения.

key = jax.random.PRNGKey(0)
curves_for_lr = {}
for lr in [1e-4, 3e-4, 1e-3, 3e-3, 5e-3, 7e-3, 1e-2, 2e-2, 3e-2]:
  opt = opt_base.Adam(lr)
  curves_for_lr[lr] = []
  print(lr)
  for s in range(5):
    key1, key = jax.random.split(key)
    curves = eval_training.single_task_training_curves(task, opt,
                                                      num_steps=200, key=key1, eval_every=5,
                                                      eval_batches=10, last_eval_batches=30)
    curves_for_lr[lr].append(curves)

Теперь мы можем построить результаты. Слева мы видим кривые обучения для каждой скорости обучения. Справа мы видим среднюю достигнутую производительность (оранжевый) и производительность в конце обучения (синий) в зависимости от скорости обучения.

fig, axs = plt.subplots(1,2, figsize=(15, 5))
for lr, curves in curves_for_lr.items():
  x = curves[0]["eval/xs"]
  y = onp.mean([c["eval/train/loss"] for c in curves], axis=0)
  axs[0].plot(x, y)
 
axs[0].set_xlabel("training iteration")
axs[0].set_ylabel("training loss")

xs = []
ys = []
ys_mean = []
for lr, curves in curves_for_lr.items():
  last_value = onp.mean([c["eval/train/loss"] for c in curves], axis=0)[-1]
  xs.append(lr)
  ys.append(last_value)
  mean_value = onp.mean([c["eval/train/loss"] for c in curves])
  ys_mean.append(mean_value)
axs[1].semilogx(xs, ys, "o-", label="last loss")
axs[1].semilogx(xs, ys_mean, "o-", label="mean loss")
axs[1].set_xlabel("learning rate")
axs[1].set_ylabel("training loss")
axs[1].legend()

Из этого видно, что скорость обучения ~ 2e-3 — это примерно лучшее, что мы можем сделать, и мы можем достичь минимального значения потерь около ~ 1,75.

Обучение обученного оптимизатора

Обучение обученного оптимизатора влечет за собой повторное обучение внутренней проблемы (нашей маленькой сети) снова и снова. На каждой итерации мы оцениваем некоторый «метаградиент» — направление перемещения весов обученного оптимизатора, чтобы улучшить способность этого обученного оптимизатора оптимизировать эту задачу. Как и в случае со стандартным обучением на основе градиента, мы затем немного движемся в этом направлении и повторяем это снова и снова.

Learned_optimization поддерживает ряд различных способов оценки этого градиента, которые варьируются от конечной разности, вычисления градиентов с обратным распространением до более сложных методов, таких как Стратегии постоянного развития (PES) для оценки градиентов.

В этом примере мы будем использовать оценщик градиента PES, поскольку было продемонстрировано, что он хорошо работает для обучения обученных оптимизаторов. PES работает, пытаясь улучшить средние потери, которые обученный оптимизатор получает в ходе обучения.

Оценщики градиента Learn_optimization работают с объектами, называемыми TruncatedStep. Они инкапсулируют все детали, связанные с изученными оптимизаторами, и предоставляют простой интерфейс, так что одни и те же оценщики градиента могут использоваться для различных типов метаобучаемых систем, а не только для изученных оптимизаторов.

А пока мы создадим этот объект `TruncatedStep` для опытных оптимизаторов. На этом шаге мы указываем расписание усечения или то, сколько времени мы хотим, чтобы каждая внутренняя проблема занимала. Мы запустили наши базовые линии для 200 итераций, поэтому здесь мы будем использовать ту же длину.

Обучение обученного оптимизатора стоит дорого. Чтобы ускорить вычисления, мы используем векторизацию. В частности, мы используем наш обученный оптимизатор для обучения нескольких консетей одновременно с использованием аппаратного ускорителя. Мы указываем это с помощью аргумента `num_tasks`.

max_length = 200
trunc_sched = truncation_schedule.ConstantTruncationSchedule(max_length)
truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
    tasks_base.single_task_to_family(task), lopt, trunc_sched, num_tasks=32, random_initial_iteration_offset=max_length)

Наконец, мы можем построить оценщик градиента.

grad_estimator = truncated_pes.TruncatedPES(truncated_step=truncated_step,
trunc_length=10)

Далее мы указываем, как использовать эти градиенты для обновления весов обученного оптимизатора. Для этого мы будем использовать оптимизатор Адама с обрезанными градиентами. Эта комбинация оказалась успешной в прошлом, хотя скорость обучения часто необходимо искать.

outer_learning_rate = 3e-3
theta_opt = opt_base.GradientClipOptimizer(opt_base.Adam(outer_learning_rate))

Наконец, класс SingleMachineGradientLearner использует этот оценщик градиента и внешний оптимизатор (Адам).

gradient_estimators = [grad_estimator]
outer_trainer = gradient_learner.SingleMachineGradientLearner(
    lopt, gradient_estimators, theta_opt)

Наконец-то мы можем обучить веса обученного оптимизатора! Во-первых, мы инициализируем начальное состояние external_trainer.

# Initialize weights of learned optimizer + state of workers.
key = jax.random.PRNGKey(int(onp.random.randint(0, int(2**30))))
outer_trainer_state = outer_trainer.init(key)
all_losses = []
losses = []

Затем мы итерации обучаем веса обученного оптимизатора. Каждый шаг `outer_trainer.update` выполняет одно развертывание (длиной 20) и вычисляет оценку градиента для каждого отдельного экземпляра внутренней задачи, усредняет мета-градиенты и применяет Адама для обновления весов обученного оптимизатора.

Ради этого поста и для ускорения работы мы метаобучаем только 1000 итераций. Это должно занять около 10 минут на хорошем ускорителе — я запускал это на одном чипе TPUv3. Этого достаточно, чтобы превзойти базовые показатели (как мы увидим), но использование дополнительных вычислений почти всегда повышает производительность.

outer_iterations = 1000
for i in tqdm.trange(outer_iterations):
  key1, key = jax.random.split(key)
  outer_trainer_state, loss, metrics = outer_trainer.update(
      outer_trainer_state, key1)
  losses.append(loss)
  if i % 50 == 0:
    all_losses.append(onp.mean(losses))
    losses = []
    print(all_losses[-1])

Посмотрим, как мы это сделали! На следующем графике показана внешняя итерация (каждая итерация, с которой мы обновляем веса изученного оптимизатора) по сравнению с внешней потерей (измерение производительности изученного оптимизатора — насколько хорошо он оптимизирует).

Убытки снижаются — отлично!

plt.plot(onp.arange(len(all_losses))*50, all_losses,  "o-")
plt.xlabel("outer-iteration updates")
plt.ylabel("average loss of convnet inner-problem (loss from PES)")
plt.ylim(1.7, 2.3)

Оценка обученной модели

Теперь на приведенном выше графике показаны усредненные потери по каждой обучаемой консети. Это несколько абстрактное измерение, и мы действительно хотим видеть, что оно оптимизируется быстрее, чем базовые показатели. Чтобы показать это, мы оценим наш оптимизатор с той же функцией single_task_training_curves, которую мы использовали для наших базовых показателей.

Для этого нам сначала нужно создать экземпляр оптимизатора. Давайте сначала загрузим оптимизатор из оптимизатора весов, который мы нашли в предыдущем мета-обучении.

theta = outer_trainer.get_meta_params(outer_trainer_state)
opt = lopt.opt_fn(theta)

И тогда мы можем запустить трейнер.

key = jax.random.PRNGKey(1)
lopt_curves = []
for i in range(5):
  key1, key = jax.random.split(key)
  lopt_curves.append(eval_training.single_task_training_curves(task, opt,
                                                          num_steps=200, key=key1, eval_every=5,
                                                          eval_batches=20, last_eval_batches=30))

Наконец, мы можем построить результат. Мы видим, что наш обученный оптимизатор работает быстрее и достигает минимума.

fig, ax = plt.subplots(1,1, figsize=(8, 5))
for lr, curves in curves_for_lr.items():
  x = curves[0]["eval/xs"]
  y = onp.mean([c["eval/train/loss"] for c in curves], axis=0)
  ax.plot(x, y)
 
x = lopt_curves[0]["eval/xs"]
y = onp.mean([c["eval/train/loss"] for c in lopt_curves], axis=0)
ax.plot(x,y, color="k")
ax.set_xlabel("training iteration")
ax.set_ylabel("training loss")

Заключение

Я надеюсь, что этот пост дает краткий предварительный обзор того, как обучать опытного оптимизатора!

То, что мы показываем здесь, представляет собой довольно небольшой масштаб, способный работать внутри блокнота Colab (вот тот же пост в форме блокнота; для более быстрого обучения обязательно измените тип среды выполнения на использование GPU/TPU. Еще быстрее, получите экземпляр GCP.)

По мере того, как количество вычислений в мире растет, я с нетерпением жду возможности узнать, что сделают обученные оптимизаторы. Моя исследовательская программа состоит в том, чтобы работать над более универсальными обучаемыми оптимизаторами, обучая их широкому кругу задач. Мы опубликовали некоторые результаты [1][2][3], но работа продолжается! Если это руководство вызвало у вас интерес, попробуйте и обучите своего опытного оптимизатора! Кроме того, ознакомьтесь с моим докладом на предстоящей конференции ODSC East Изученные оптимизаторы.

Исходное сообщение здесь.

Читайте другие статьи по науке о данных на OpenDataScience.com, включая учебные пособия и руководства от начального до продвинутого уровня! Подпишитесь на нашу еженедельную рассылку здесь и получайте последние новости каждый четверг. Вы также можете пройти обучение по науке о данных по запросу, где бы вы ни находились, с нашей платформой Ai+ Training. Подпишитесь также на нашу быстрорастущую публикацию на Medium, ODSC Journal, и узнайте, как стать писателем.