Время адаптивных вычислений (ACT) в нейронных сетях [1/3]

Часть 1: АКТ в РНС

Есть интересная малоизвестная тема Adaptive Computing Time (ACT) в нейронных сетях. Это применимо к различным типам нейронных сетей (RNN, ResNet, Transformer), и вы можете использовать эту довольно общую идею и в другом месте.

Общая идея состоит в том, что некоторые сложные данные могут потребовать большего количества вычислений для получения окончательного результата, в то время как некоторые простые или неважные данные могут потребовать меньше. И вы можете динамически решать, как долго обрабатывать данные, обучая нейронную сеть автоматически адаптироваться к данным.

Оригинальная работа Алекса Грейвса:
Время адаптивных вычислений для рекуррентных нейронных сетей,
https://arxiv.org/abs/1603.08983

Кстати, я бы порекомендовал вам следить за статьями Алекса Грейвса, он является автором многих других интересных вещей, например многомерные RNN, CTC, Grid-LSTM, он работал над NTM и DNC и так далее.

Итак, начнем.

№1. РНН с АКТ

Классическая RNN обновляет свое состояние sₜ для каждого элемента последовательности xₜ за один шаг:

Итак, граф вычислений RNN выглядит так:

Adaptive Computing Time (ACT) изменяет стандартную настройку, позволяя сети выполнять переменное количество переходов между состояниями и вычислять переменное количество выходов на каждом шаге ввода:

Теперь каждое состояние и каждый вывод расширяются до переменного числа промежуточных обновлений / выводов, и результирующий вывод будет взвешенной суммой этих промежуточных выводов. Это чем-то напоминает механизм мягкого внимания.

Модифицированный граф вычислений выглядит следующим образом (имейте в виду, h здесь не является скрытым состоянием модуля, h - это выходной сигнал останавливающегося модуля, подробнее об этом ниже):

Интересно, сколько обновлений выполнять на каждом этапе?

Мы вводим блок остановки (по сути, сигмоидальный блок), роль которого состоит в том, чтобы решить, следует ли нам остановить вычисления или продолжить.

Обработка останавливается, когда сумма выходных сигналов блока остановки (h) становится близкой к 1,0 (фактически, 1,0 минус эпсилон, где эпсилон - небольшой значение, выбранное в статье равным 0,01). Затем последний h заменяется остатком, чтобы сумма (hᵢ) была равна 1.0. Эта процедура дает нам вероятности остановки (pᵢ) промежуточных шагов. Затем мы определяем обновления среднего поля для состояний и выходов:

Так, например, пусть выходы блока остановки равны 0,1, 0,3, 0,4 и 0,4 на четырех последовательных шагах. Сумма этих h больше, чем 1-эпсилон, поэтому мы создаем список весов: [0,1, 0,3, 0,4, 0,2] (последнее значение было усечено) и вычисляем конечное состояние и выводим как взвешенную сумму с этими весами. :

s = 0.1*s¹ + 0.3*s² + 0.4*s³ + 0.2*s⁴

y = 0.1*y¹ + 0.3*y² + 0.4*y³ + 0.2*y⁴

Нам нужно ограничить время вычислений, иначе у сети будет стимул обрабатывать данные как можно дольше. Итак, мы добавляем стоимость размышления к нашей функции потерь. Стоимость размышления напоминает общее вычисление во время последовательности. Он добавляется к общей потере с весом τ гиперпараметром штрафа по времени.

Стоимость ponder в некоторых точках прерывается, но на практике мы просто игнорируем это, поэтому градиент стоимости ponder относительно остановленных активаций очевиден, и сеть может быть дифференцирована как обычно и обучена с использованием обратного распространения.

Идея была протестирована на наборе задач, специально разработанных так, чтобы было сложно без такого механизма: определение четности последовательности двоичных чисел (последовательность передается как вектор, все элементы сразу, а не как последовательность элементов, что обычно для обучения RNN это важно), решение тестов логических задач, добавление элементов последовательности, дающих совокупные суммы, и, наконец, сортировка последовательностей от 2 до 15 чисел.

Сети ACT дают значительно лучшие результаты, чем базовые (те же RNN или LSTM, но без ACT), но требуют тщательной настройки гиперпараметра штрафа времени (авторы использовали поиск по сетке).

Еще один интересный эксперимент - предсказание персонажей в Википедии. Он раскрывает интригующий образец распределения размышлений при обработке последовательности. Сети предсказания символов, обученные с помощью ACT, постоянно делают паузу в пробелах между словами и делают паузу дольше на «граничных» символах, таких как запятые и точки.

Итак, выглядит интересно.

# 1б. Повторить-РНН

Затем на ICLR 2018 появилась интересная статья:

Сравнение фиксированного и адаптивного времени вычислений для рекуррентных нейронных сетей Даниэля Фоджо, Виктора Кампоса, Ксавьера Джиро-и-Нието,
https://arxiv.org/abs/1803.08165

Идея проста.

Что, если важнее всего не динамически прогнозируемое количество шагов, а возможность выполнить несколько шагов для каждого элемента?

Авторы разработали простой базовый план под названием Repeat-RNN, который выполняет фиксированное (›1) количество шагов для каждого элемента. Это своего рода абляция оригинального ACT, где количество шагов теперь является гиперпараметром.

Этот параметр эквивалентен расширению входной последовательности, повторение каждого элемента фиксированное количество раз плюс пометка его дополнительным логическим флагом, чтобы отличить новый входной токен от дублированного.

Неожиданным результатом стало то, что Repeat-RNN смогла решать задачи, используя меньшее количество шагов SGD и меньшее количество повторов, чем ACT. Удивительно, но Repeat-RNN работал так же хорошо или лучше, чем ACT в рассмотренных
двух задачах (четность и сложение).

И ACT, и Repeat-RNN требуют настройки гиперпараметров (штраф по времени для первого и количество повторений для второго). И то, и другое зависит от задачи, но количество повторений гораздо более интуитивно понятно, чем штраф по времени.

По-прежнему остается открытым вопрос, почему многократное обновление состояния каждого входного токена перед переходом к следующему улучшает обучаемость сети.

Одна из возможностей, упоминаемых авторами, - это связь между LSTM и остаточными сетями, а также недавний вывод о том, что последние выполняют итеративную оценку, предполагают, что повторяющиеся входные данные в RNN могут стимулировать итеративную оценку характеристик и создавать более эффективные модели.

Другая возможная причина заключается в том, что увеличенное количество применяемых нелинейностей позволяет сети моделировать более сложные функции с тем же количеством параметров.

… продолжение следует…

Часть 2 находится здесь.
Часть 3 находится здесь.