Набор инструментов визуализации функций с открытым исходным кодом для нейронных сетей в PyTorch

Настройка сцены

Пару недель назад я выступал с докладом на Hopperx1 London, организованном AnitaB.org в рамках London Tech Week. Колода слайдов доступна здесь.

Я получил такой положительный отзыв после выступления, что решил написать немного более длинную версию доклада, чтобы раньше познакомить FlashTorch с миром :)

Пакет доступен для установки через pip. Исходный код можно найти в репозитории GitHub. Вы также можете поиграть с ним в этом блокноте, размещенном в Google Colab, без необходимости устанавливать что-либо!

Но сначала я кратко ознакомлюсь с историей визуализации функций, чтобы вы могли лучше понять что и почему.

Введение в визуализацию функций

Визуализация функций - это активная область исследований, цель которых - понять, как нейронные сети воспринимают изображения, исследуя способы, с помощью которых мы можем смотреть их глазами. Он возник и развился в ответ на растущее желание сделать нейронные сети более понятными для людей.

Самые ранние работы включают анализ того, на что нейронные сети обращают внимание во входных изображениях. Например, карты значимости классов для конкретных изображений визуализируют области во входном изображении, которые вносят наибольший вклад в соответствующий выходной, путем вычисления градиента выходных данных класса с учетом к входному изображению через обратное распространение (подробнее о картах значимости позже в этом посте).

Еще одно направление техники визуализации признаков - максимизация активации. Это позволяет нам итеративно обновлять входное изображение (изначально созданное с некоторым случайным шумом) для создания изображения, которое максимально активирует целевой нейрон. Он дает некоторую интуицию относительно того, как отдельные нейроны реагируют на входные данные. Это метод так называемого Deep Dream, который популяризировал Google.

Это был огромный шаг вперед, но у него был недостаток в том, что он не дает достаточного понимания того, как работает вся сеть, поскольку нейроны не работают изолированно. Это привело к попытке визуализировать взаимодействия между нейронами. Olah et al. продемонстрировал арифметические свойства пространства активации путем добавления или интерполяции между двумя нейронами.

Затем Olah at al. пошел дальше, чтобы определить более значимую единицу визуализации, проанализировав количество выстрелов каждого нейтрона внутри скрытого слоя при заданном вводе. Визуализация группы нейронов, которые сильно активируются вместе показала, что существуют группы нейронов, ответственных за улавливание таких понятий, как висячие уши, пушистые ноги и трава.

Одной из последних разработок в этой области является Атлас активации (Картер и др., 2019). В этом исследовании авторы рассмотрели основной недостаток визуализации активации фильтров, который дает лишь ограниченное представление о том, как сеть реагирует на одиночный ввод. Чтобы увидеть полную картину того, как сеть воспринимает множество объектов и как эти объекты связаны друг с другом в мировоззрении сети, они разработали способ создания глобальной карты, видимой сквозь глаз сети , показывая общие комбинации нейронов.

Мотивация FlashTorch

Когда я открыл для себя мир визуализации функций, меня сразу же привлекли его возможности для создания более интерпретируемых и объяснимых нейронных сетей. Затем я быстро понял, что не существует инструмента, позволяющего легко применить эти методы к нейронным сетям, которые я построил в PyTorch.

Поэтому я решил создать один - FlashTorch, который теперь доступен для установки через pip! Первый метод визуализации функций, который я реализовал, - это карты значимости.

Ниже мы рассмотрим более подробно, что такое карты значимости, а также то, как использовать FlashTorch для их реализации в нейронных сетях.

Карты значимости

Заметность в визуальном восприятии человека - это субъективное качество, которое выделяет определенные объекты в поле зрения и привлекает наше внимание. Карты значимости в компьютерном зрении могут указывать на наиболее заметные области на изображениях.

Метод создания карт значимости из сверточных нейронных сетей (CNN) был впервые представлен в 2013 году в статье Глубокие сверточные сети: визуализация моделей классификации изображений и карт значимости. Авторы сообщили, что, вычисляя градиенты целевого класса по отношению к входному изображению, мы можем визуализировать области во входном изображении, которые имеют влияние на значение прогноза. этого класса.

Карты значимости с использованием FlashTorch

Без лишних слов, давайте использовать FlashTorch и сами визуализировать карты значимости!

FlashTorch поставляется с некоторыми utils функциями, которые также немного упрощают обработку данных. Мы собираемся использовать это изображение great grey owl в качестве примера.

Затем мы применим к изображению некоторые преобразования, чтобы сделать его форму, тип и значения подходящими в качестве входных данных для CNN.

Я собираюсь использовать AlexNet, который был предварительно обучен с ImageNet набором данных классификации для этой визуализации. Фактически, FlashTorch поддерживает все модели, которые поставляются с torchvision из коробки, поэтому я рекомендую вам попробовать и другие модели!

КлассBackprop - это ядро ​​ для создания карт значимости.

При создании он принимает модель Backprop(model) и регистрирует пользовательские привязки к интересующим слоям в сети, чтобы мы могли извлечь промежуточные градиенты из вычислительного графа для визуализации. . Эти промежуточные градиенты доступны не сразу из-за того, как PyTorch разработан. FlashTorch разбирается в этом за вас :)

И последнее, что нам нужно перед вычислением градиентов - индекс целевого класса.

Напомним, что нас интересуют градиенты целевого класса по отношению к входному изображению. Однако модель предварительно обучена с использованием набора данных ImageNet, и поэтому ее прогноз предоставляется как распределение вероятностей для 1000 классов. Мы хотим точно определить значение целевого класса (в нашем случае great grey owl) из этих 1000 значений, чтобы избежать ненужных вычислений и сосредоточиться только на взаимосвязи между входным изображением и целевым классом.

Для этого я также реализовал класс ImageNetIndex. Если вы не хотите загружать весь набор данных, а просто хотите узнать индексы классов на основе имен классов, это удобный инструмент. Если вы дадите ему имя класса, он найдет соответствующий индекс класса target_class = imagenet['great grey owl']. Если вы действительно хотите загрузить набор данных, используйте класс ImageNet, представленный в последней версии torchvision==0.3.0.

Теперь у нас есть входное изображение и индекс целевого класса (24), поэтому мы готовы вычислять градиенты!

Эти две строки являются ключевыми:

gradients = backprop.calculate_gradients(input_, target_class)

max_gradients = backprop.calculate_gradients(input_, target_class, take_max=True)

По умолчанию градиенты рассчитываются для каждого цветового канала, поэтому его форма будет такой же, как у входного изображения - в нашем случае (3, 224, 224). Иногда легче визуализировать градиенты, если мы берем максимальное количество градиентов по цветовым каналам. Мы можем сделать это, передав take_max=True в вызов метода. Форма градиентов будет (1, 224, 224).

Наконец, давайте визуализируем, что у нас есть!

Мы понимаем, что пиксели в области, где находится животное, оказывают сильнейшее влияние на значение прогноза.

Но это немного шумно ... сигнал распространяется, и он мало что говорит нам о восприятии совы нейронной сетью.

Есть ли способ улучшить это?

На помощь приходит управляемая обратная связь

Ответ положительный!

В статье Стремление к простоте: вся сверточная сеть авторы представили изобретательный способ уменьшения шума при вычислении градиентов.

По сути, в управляемом обратном распространении нейроны, не влияющие или не влияющие на прогнозируемое значение целевого класса, маскируются и игнорируются. Таким образом мы можем предотвратить поток градиентов через такие нейроны, что приведет к снижению шума.

Вы можете использовать управляемое обратное распространение в FlashTorch, передав guided=True в вызов метода calculate_gradients, например:

Давайте визуализируем управляемые градиенты.

Разница разительная!

Теперь мы ясно видим, что сеть обращает внимание на запавшие глаза и круглую голову совы. Эти характеристики «убедили» сеть классифицировать объект как great grey owl.

Но он не всегда фокусируется на глазах или головах ...

Как видите, сеть научилась сосредотачиваться на чертах, которые в значительной степени соответствуют тому, что мы считаем наиболее отличительными качествами этих птиц.

Приложения визуализации функций

С помощью визуализации функций мы не только можем лучше понять, что нейронная сеть узнала об объектах, но и лучше подготовлены, чтобы:

  • Диагностируйте, что в сети происходит не так и почему
  • Выявление и устранение ошибок в алгоритмах
  • Сделайте шаг вперед, глядя только на точность
  • Понять, почему сеть ведет себя именно так
  • Выяснить механизмы того, как нейронные сети обучаются

Воспользуйтесь FlashTorch сегодня!

Если у вас есть проекты, которые используют CNN в PyTorch, FlashTorch может помочь вам сделать ваши проекты более интерпретируемыми и объяснимыми.

Пожалуйста, дайте мне знать, что вы думаете, если воспользуетесь им! Буду очень признателен за конструктивные комментарии, отзывы и предложения 🙏

Спасибо, удачного кодирования!