Извлечение знаний - это распространенный способ обучения сжатых моделей путем переноса знаний, извлеченных из большой модели, в меньшую модель. Сегодня мы рассмотрим использование дистилляции знаний для обучения модели, которая выявляет пневмонию с помощью рентгеновских снимков грудной клетки.

Какой в ​​этом смысл?

Допустим, вы работаете в сфере, где конфиденциальность чрезвычайно важна, например в сфере здравоохранения. Может случиться так, что мы не сможем отправить данные о пациентах в облако, где находятся все наши модные графические процессоры. Это создает потребность в обучении модели, которую можно загрузить и запустить на маломощной машине.

Итак ... с чем мы работаем?

Давайте воспользуемся набором данных Рентгеновские снимки грудной клетки от Kaggle. Задача - выявить пневмонию на рентгенограмме грудной клетки.

Пневмония - это легочная инфекция, вызывающая кашель, жар, озноб и затрудненное дыхание. Это вызвано иммунным ответом на какую-то инфекцию легких, вызванную вирусом или бактериями. Текущий вирус COVID-19 может вызвать пневмонию.

По сути, в ваших легких есть воздушные мешочки, в которых кислород и углекислый газ обмениваются, чтобы вы могли дышать. Когда эти воздушные мешочки заражены вирусом или бактериями, ваше тело вырабатывает иммунный ответ, восполняя эту область жидкостью.

Большинство людей могут вылечиться от этого, но у некоторых это может привести к смерти из-за дыхательной недостаточности.

Пневмония убивает гораздо больше людей в развивающихся странах. В то время как в США в 2017 году от пневмонии умерло 50 000 человек, во всем мире от пневмонии умерло 3 миллиона человек.

Ограниченный доступ к инфраструктуре здравоохранения мотивирует потребность в технологиях для снижения затрат и повышения эффективности.

Давайте посмотрим на данные

Имея это в виду, на рентгеновском снимке грудной клетки мы будем искать мутные области, которые указывают на наличие жидкости в легких.

Вот нормальный рентген грудной клетки

Вот рентген грудной клетки пациента с пневмонией.

Не так очевидно, почему одно сканирование исправно, а другое заражено. Судя по тому, что я исследовал по этому поводу, врачи ищут белые сгустки по периферии легких.

Итак, давайте смоделируем это!

Если вы хотите сразу перейти к коду, перейдите сюда.

Самое простое, что мы можем сделать здесь, - это просто добавить это в предварительно обученную сверточную модель ResNet и посмотреть, как далеко мы можем продвинуться.

Мы будем использовать PyTorch и PyTorch lightning для построения и обучения моделей.

PyTorch Lightning - это библиотека, которая позволит нам модулировать наш код, чтобы мы могли разделить биты, которые являются общими практически для всех задач классификации изображений, и биты, специфичные для задач дистилляции изображений.

Давайте начнем с создания общего класса BaseImageClassificationTask, который позаботится обо всех скучных задачах классификации изображений, таких как настройка оптимизаторов и загрузка наборов данных. Смотрите код здесь и загрузку набора данных здесь.

Теперь давайте создадим простой ImageClassificationTask, который может использовать любую модель классификации изображений PyTorch и вычислить потерю перекрестной энтропии. Это позволяет нам подключать любую модель PyTorch, которая может использовать изображение и выводить прогноз.

Волшебным образом (не совсем) теперь мы можем запустить цикл обучения. PyTorch Lightning позаботится о выборке из загрузчиков данных и обратном распространении потерь.

Вот результаты тренировки с ResNet-18 после 40 эпох:

Окончательная точность набора тестов: 91%

Насколько «маленькой» можно сделать эту модель?

Помните, что первоначальная цель заключалась в создании моделей, которые можно было бы загрузить и запустить на маломощных машинах. В этом случае давайте создадим простую трехуровневую CNN в качестве модели студента.

Мы можем измерить размер этой модели двумя способами:

  1. Размер модели, который соответствует количеству параметров.
  2. Скорость модели, которая обычно выражается в количестве слоев.

Размер

Модель ResNet-18 имеет 11,7 млн ​​параметров, тогда как трехуровневая CNN имеет 277000 параметров. Это уменьшение параметров модели на 97,5%.

Скорость

Вывод процессора с ResNet-18 занимает 45 мс, тогда как трехуровневый CNN занимает 3 мс. Это 15-кратное увеличение скорости вывода.

Действительно ли нам нужен учитель?

Первый вопрос, который мы должны задать: действительно ли нам нужна модель учителя? Давайте наивно возьмем модель нашего ученика и обучим ее с ImageClassificationTask, как мы это делали с моделью ResNet.

Вот результаты после 40 эпох:

Точность тестового набора: 72%

Дистилляция

Теперь давайте создадим наш ImageClassificationDistillationTask класс.

Единственная значимая разница между ImageClassificationTask и ImageClassificationDistillationTask заключается в том, как вычисляются окончательные потери при обучении, а также в некоторых гиперпараметрах для настройки потерь.

1. Начиная с сети обученных учителей и сети неподготовленных студентов
(мы уже сделали это с помощью ResNet-18 выше)

2. Вперед пройти через модель учителя и получить логиты
Убедитесь, что вы перевели модель учителя в тестовый режим, чтобы мы не собирали без необходимости градиенты.

3. Вычислите окончательную потерю как потерю при перегонке + потерю классификации

4. Обратные потери через студенческую модель

Как работает функция потерь?

Функция потерь представляет собой взвешенную сумму двух вещей:

  • Нормальная потеря классификации, обозначаемая как student_target_loss в сущности.
  • Перекрестная потеря энтропии между логитами учеников и учителями, в сущности обозначаемая как distillation_loss.

Потеря обычно выражается в литературе следующим образом:

Перекрестная потеря энтропии между учеником и учителем - главное нововведение. Интуитивно это обучает ученика неуверенности учителя. Это также обычно называют потерями при перегонке. Интуитивно это делается для того, чтобы научить ученика тому, как учитель «думает». В дополнение к обучению ученика метке наземной истины, мы также обучаем ученика неопределенности метки, которую усвоил учитель.

Если учитель прогнозирует 51% пневмонии и 49% не пневмонию, мы также хотим, чтобы учащийся был в равной степени неуверен.

Это мотивирует необходимость двух параметров для корректировки поведения этой потери:

  • Альфа: сколько веса мы придаем потере ученика и учителя по сравнению с обычной потерей классификации.
  • Температура: насколько мы масштабируем неопределенность модели учителя.

Альфа

Параметр альфа контролирует вес потерь при перегонке. Значение альфа, равное 1, означает, что мы учитываем только потери при перегонке, а значение альфа, равное 0, означает, что мы полностью игнорируем потери при перегонке.

Температура

Температура - более интересный параметр, который определяет степень «неопределенности» прогнозов учителя.

Вот пример модели, которая выводит 3 класса:

Вот как прогнозы масштабируются с различными значениями температуры, чтобы масштабировать неопределенность этих прогнозов.

Назначение температурного параметра - контролировать степень неопределенности прогнозов учителя.

Какие гиперпараметры работают лучше всего?

Вот окончательные результаты для трехуровневой модели CNN студента с различными настройками гиперпараметров:

Наилучшим показателем на сегодняшний день является альфа = 0,25, температура = 1, что составляет 86% на тестовой выборке. Это улучшение по сравнению с исходными 72%, когда мы только обучили модель студента с нуля, без дистилляции.

Вот окончательные результаты:

В итоге

Нам удалось обучить модель, которая на 97,5% меньше и в 15 раз быстрее, чем ResNet-18, и примерно на 5% хуже, чем модель учителя.