В этом посте мы рассмотрим, как написать модель корреляции текста с помощью рекуррентной нейронной сети (RNN) в Tensorflow. Весь код, обсуждаемый в этом посте, также доступен в виде самостоятельных и подробных блокнотов Jupyter. Этот пост можно напрямую использовать для реализации модели сиамской сети для классификации следующего высказывания.

Моделирование корреляции текста

Начнем с простой идеи. У нас есть список пар чисел (скаляров) (c, r). У нас также есть флаг (или метка истинность) для каждой пары, который равен 1 или 0.flag=1 означает, что пара коррелирована, а flag=0 означает, что пара не . Теперь мы хотим изучить функцию f (c, r), которая возвращает 1, если пара (c, r) коррелирована , и 0 в противном случае. Как мы можем узнать f?

Что ж, мы можем получить оценку, изучив корреляцию вес M. Мы хотим, чтобы на выходе был 0 или 1, поэтому мы можем дополнительно применить сигмоид. Таким образом:

Хорошо, как мы можем узнать f, когда c и r являются (d -мерными) векторами ? Ну, теперь M будет dxd матрицей. Таким образом:

Теперь рассмотрим ту же проблему, когда у нас есть список пар предложений (context, response). Теперь мы хотим узнать, связаны ли context и response. Как мы можем узнать f, когда входные данные являются предложениями, а не векторами?

Кажется, ответ лежит где-то в преобразовании текста в вектор. Давайте внимательно посмотрим на ввод. Ввод - это предложение, которое представляет собой последовательность слов (токенов). Таким образом, если бы мы могли преобразовать эту последовательность слов в вектор фиксированной длины, обсуждаемая выше модель корреляции сработала бы.

Мы знаем, как преобразовать слово в вектор с помощью word2vec. Как теперь преобразовать последовательность векторов в один вектор? Что ж, мы можем их усреднить. Но эта модель не узнает, какие слова важны во входных данных для изучения корреляции! Итак, какое решение?

Введите рекуррентную нейронную сеть (RNN).

RNN обрабатывает одиночный ввод (текущий вектор слова) для обновления своего внутреннего состояния и возвращает вывод для каждого введенного ввода. Конечное состояние или результат можно рассматривать как вектор для всего предложения. Эта модель также известна как Encoder и используется в нейронном машинном переводе (NMT). Вот несколько отличных постов о RNN и его вариантах.

Двойной кодировщик или сиамская сеть еще больше упрощает эту концепцию за счет использования одного и того же RNN (что означает одинаковые веса RNN) для преобразования context в вектор c и response в вектор r.

Для полноты мы должны уточнить, что мы не будем фактически использовать ванильные RNN, поскольку они не подходят для изучения долгосрочных зависимостей. Мы будем использовать LSTM, хотя можно использовать и GRU.

Покажите мне код!

Итак, у нас есть некоторое представление о том, как можно смоделировать корреляцию текста с помощью RNN. Теперь давайте посмотрим на конкретный код. При кодировании в тензорном потоке я считаю, что написание кода в Jupyter Notebooks - лучшее использование моего времени. Убедившись, что код в блокноте работает так, как задумано, я переношу его в класс или функцию python. Таким образом, весь код здесь доступен в гораздо более подробном и самостоятельном виде в виде блокнотов. Не стесняйтесь загружать ноутбук, вносить в него изменения и запускать в своей системе.

  1. Подготовка данных: Блокнот

Будем работать с Ubuntu Dialog Corpus v2. Блокнот описывает (мучительно) подробно, как получить и подготовить данные. Если шаги, упомянутые в записной книжке, кажутся медленными, вы можете использовать свой любимый метод для предварительной обработки данных. В конце мы хотим:

  • Учебные файлы: train.context, train.response и train.flag
  • Примеры файлов проверки: valid.context, valid.response и valid.flag
  • Файл словарного запаса: vocab.txt (по одному слову в строке), созданный из train.context и train.response. Первое слово - UNK. Все остальные слова встречаются не менее 50 раз.

2. Кормление данных: Блокнот.

Мы будем использовать официально рекомендованный API набора данных для подачи текстовых данных в нашу модель тензорного потока. Мы создаем итератор набора данных, который возвращает следующие поля:

  • init: инициализатор для итератора
  • контекст: [batch_size,?]; Пакет контекстных предложений
  • len_context: [batch_size, 1]; Длина каждого контекста
  • ответ: [размер_пачки,?]; Пакет ответных предложений
  • флаг: [размер_пакета, 1]; Значения флагов 1.0 или 0.0

3. Модель корреляции текста с использованием сиамской сети: Блокнот

Итак, мы разобрались с частью кода, связанной с подготовкой данных и итератором. Перейдем к основной модели !

  • Вложения слов. Нам нужна переменная, которая возвращает вектор для каждого индекса слова. Итак, если размер нашего словаря равен V, а размер вектора d, мы можем создать эту переменную следующим образом:
  • Получение векторов слов для каждого слова в контексте: Теперь мы хотим получить векторы слов для каждого слова в контекстном предложении. Однако помните, что этот контекст теперь представляет собой набор предложений.
  • Вектор контекста. Затем мы будем использовать RNN для преобразования контекста в вектор фиксированной длины.
  • Вектор ответа: мы будем повторно использовать веса RNN из контекста.
  • Веса корреляции M: Мы создаем веса корреляции и далее вычисляем логиты. Мы будем напрямую использовать логиты в tf.nn.sigmoid_cross_entropy_with_logits
  • Шаг обучения: мы также применяем градиентную обрезку.
  • Выполните 100 шагов:

4. Полная модель: Блокнот.

До сих пор мы видели, как мы можем подготовить данные, написать итератор набора данных и написать модель обучения в Tensorflow.

Чтобы проверить, обучается ли ваша модель, вы сравните потерю с удерживаемым набором. В данном случае наш валидационный набор.

Рекомендуется создавать разные модели для обучения и проверки. Этого можно добиться, создав отдельный график и сеанс для обучения и проверки. Например, мы можем создать обучающую модель следующим образом:

Аналогичным образом создайте модель проверки, создав отдельный график и сеанс следующим образом:

Теперь мы можем оценивать модель каждые k шагов (здесь k = 10), загружая сохраненный сеанс из контрольной точки в модель проверки.