Начало работы с практическим руководством по графовой нейронной сети

В этом сообщении блога мы рассмотрим подробное руководство по обучению сверточной сети графа (GCN). Учебник содержит краткое объяснение идеи, лежащей в основе GCN, и реализацию построчного обучения в Tensorflow.

Давайте сначала посмотрим на набор данных, который мы будем использовать сегодня. Мы используем набор данных Cora. Это аналог MNIST в графовых нейронных сетях. Набор данных Cora содержит график сети цитирования. Каждый узел графа — это статья, а каждое ребро — ссылка. Есть 2708 узлов и 10556 ребер. Доклады относятся к 7 научной категории. С каждой статьей связан вектор признаков. Вектор признаков представляет собой логический массив из 4_ элементов, каждый из которых указывает, появляется ли в статье соответствующее словарное слово. Набор данных Cora также содержит обучающие, проверочные и тестовые маски (по 2708 логических элементов каждая), используемые для обозначения узлов, находящихся в соответствующих разбиениях.

Задачу можно сформулировать так: при заданном небольшом наборе известных меток (категории 140 статей в деле Cora) мы хотим предсказать метки для других статей в сети цитирования. Такая же схема распространена в задачах классификации узлов графа, таких как обнаружение мошенничества. Интуитивно понятно, что метки узлов могут быть получены из самих признаков узлов и их соединений с другими узлами (известны метки или нет).

Для сравнения, давайте сначала обучим модель, основанную исключительно на функциях узла. Мы знаем, что он не будет работать хорошо, потому что он не принимает во внимание связи между узлами. Тем не менее, мы сначала возьмем это за основу.

Мы используем библиотеку DGL для загрузки набора данных. Библиотека DGL представляет собой набор строительных блоков канонической графовой нейронной сети, которые мы будем использовать позже. Затем мы готовим набор данных для обучения, проверки и тестирования, используя соответствующие маски для функций и меток всего набора данных.

Мы создаем простую модель с двумя полносвязными слоями и выпадающим слоем и обучаем модель на 20 эпох.

Судя по рисунку 1 выше, точность обучения и проверки невелика. Мы оцениваем модель на тестовом наборе данных. Точность еще хуже, всего 17%.

Теперь пришло время изучить графовые сверточные сети. Идея была впервые сформулирована в этой бумаге. Проще говоря, модель учится использовать как функцию узла, так и его соседей для задачи классификации узлов. См. иллюстрацию на следующем рисунке 2.

Это сверточная нейронная сеть с двухслойным графом. Окончательное скрытое представление узла A зависит от предыдущего скрытого представления узлов B, C и узла A itself, которое исходит из первого скрытого слоя в модели. Те, в свою очередь, зависят от слоя 0, который представляет собой необработанные функции узла. Параметры W и b одинаковы внутри слоя и различаются между слоями. Механизм агрегации может быть простой средней или какой-то другой более сложной схемой. Агрегирование обычно нормализуется на основе количества исходящих и входящих ребер исходного и конечного узлов. Это имеет интуитивно понятный смысл, потому что «влияние» узла на принимающий узел должно зависеть от того, сколько исходящих соединений имеет этот узел и сколько входящих соединений имеет принимающий узел. Формула выглядит следующим образом.

N(i) — это множество соседей узла i. cij — термин нормализации. Обычно это квадратный корень из числа входящих соединений для узла i, умноженный на количество исходящих соединений для узла j.

Чем больше слоев мы добавляем, тем дальше модель может собирать информацию. Например, узел D не учитывался бы, если бы в модели был только один слой. Однако будьте осторожны, добавляя слишком много слоев. Бумага показывает, что производительность модели резко ухудшается, когда добавляется слишком много слоев, потому что в этом случае мы, по сути, сглаживаем информацию о метках по всему графику.

Мы используем строительный блок GraphConv из библиотеки DGL для построения нашей модели.

Мы также обучаем его на 20 эпох.

На Рисунке 3 выше точность выглядит намного лучше, даже несмотря на то, что она немного завышена. Точность теста 71%, намного выше, чем у предыдущего 17%. Так что, как и ожидалось, соединения в сети имеют сильную предсказательную силу, особенно если учесть, что на самом деле в обучающем наборе всего 140 узлов.