Нейронные сети доминируют в области машинного обучения уже около десяти лет, и с каждым годом они становятся лучше. Учитывая важность медицинских данных и их широкое использование, вполне естественно видеть, что NN применяются для решения биомедицинских проблем. В этой статье мы собираемся использовать MobileNetV3 для решения проблемы классификации в медицинской сфере. Будьте на связи!

Примечание: код для этой проблемы доступен через этот блокнот Google Colab, это ядро ​​Kaggle и это репозиторий GitHub. Эта статья представляет собой краткое изложение того, что происходит за кодами. Пошаговый процесс описан в блокноте.

Вступление

По данным Википедии:

Пневмония - это воспалительное состояние легких, в первую очередь поражающее небольшие воздушные мешочки, известные как альвеолы. Симптомы обычно включают сочетание продуктивного или сухого кашля, боли в груди, лихорадки и затрудненного дыхания. Степень тяжести состояния варьируется.

Один из способов, который используют врачи, чтобы определить, страдают ли пациенты пневмонией или нет, - это анализировать их рентгеновские снимки. Используя информацию на этих изображениях, они могут увидеть, как пневмония повлияла на пациента, насколько она серьезна и какие части легких поражены ею больше всего. Вы можете увидеть пример на рисунке 1.

Набор данных

Набор данных, который мы используем для решения этой проблемы, называется Рентгеновские снимки грудной клетки (пневмония), первоначально представленный в этой статье. Он состоит из около 5 863 высококачественных рентгеновских снимков легких здоровых и больных пациентов, разного пола и возрастных групп, с указанием бактериальной или вирусной причины болезни для больных. Он разделен на 3 разных набора: обучение, тестирование и проверка. Подгруппа поездов имеет несбалансированное распределение, в ней преобладают пациенты с пневмонией (см. Рисунок 2). Набор данных также состоит из изображений разных размеров, что требует предварительной обработки изменения размера. (см. рисунок 3).

Предварительная обработка

Как я упоминал выше, размеры изображений различаются как по ширине, так и по высоте, без каких-либо уникальных соотношений сторон. Чтобы исправить это, мы должны выбрать базовый размер, до которого будут изменены все изображения, чтобы мы могли кормить сеть изображениями точно такого же размера. Взглянув на рисунок 3, мы можем сделать вывод, что использование размера, который меньше, чем у большинства изображений, но при этом позволяет сохранить детали изображений, является очень важным шагом. Чем больше размер, тем больше деталей у нас есть. Учитывая это, мы по-прежнему не можем сделать размеры входного изображения больше, чем заметная часть нашего набора данных. Следовательно, я выбрал 256x256. Одна из причин выбора этого заключается в том, что нам нужно выбрать размер изображения, кратный 32, из-за архитектуры MobileNetV3. Вы можете увидеть образец набора поездов с измененными размерами на рисунке 4.

Поскольку набор данных уже предварительно обработан и контролируется качество, дальнейшие манипуляции с данными не требуются. Рентгеновские изображения не имеют стандартного протокола размера, освещения, контраста, и дальнейшие манипуляции с данными могут внести в модель определенное неявное смещение. Мы используем увеличение данных на этапе обучения, чтобы сделать модель устойчивой к подобным вариациям во входной области.

Модель

Архитектура

Модель, которую мы используем в этом проекте, - MobileNetV3. MobileNetV3 - это чрезвычайно легкая CNN, подходящая для использования в сотовых телефонах с ограниченными ресурсами, что делает ее очень привлекательной для задач, не имеющих хорошей предварительно подготовленной базовой линии и зависящих от предметной области. Кроме того, мы используем такие ресурсы, как Google Colab или Kaggle, у которых есть ограничения по времени. Поэтому использование небольшой, но мощной CNN имеет еще больший смысл. Мы будем использовать мобильную сеть, предварительно обученную в сети изображений, и настроим ее последние слои. MobileNetV3, который мы используем, является уменьшенной версией, потому что мы стараемся минимизировать количество параметров, насколько это возможно. Вы можете увидеть архитектуру модели на рисунке 5.

Мы используем основу модели, которая представляет собой все слои перед слоем объединения 7x7, и в качестве следующего шага замораживаем все, кроме последнего слоя узкого места и следующего слоя conv2d в магистрали. Причина, по которой мы замораживаем слои до двух последних, заключается в том, что они изучили богатое общее представление входных изображений и уже обладают обширным обобщаемым пониманием входного пространства, которое может отображать входное пространство в очень сложное нелинейное пространство. . Последние уровни создают окончательные «решающие» особенности, которые обычно гораздо более специфичны для конкретной проблемы. Более того, мы настраиваем последние два уровня, а не сбрасываем их, используя гораздо меньшую скорость обучения для этих незамороженных слоев по сравнению с другими уровнями, которые мы представляем.

После магистрали мы используем слой Global Average Pooling Layer, который сопоставляет карты функций 8x8 с 1x1, за которым следует выпадение и линейный слой с двумя выходами, который использует активацию Softmax. Архитектура окончательной модели описана ниже (с использованием torch-summary):

Как видите, последний инвертированный остаточный слой и следующий за ним слой активации ConvBNA - единственные два слоя исходной магистрали, которые мы настраиваем.

Потеря

Из-за дисбаланса меток в наборе поездов мы используем взвешенную кросс-энтропийную потерю, чтобы получить более сбалансированную модель. Формула для расчета весов потерь:

weight[label] = n_total_samples/ (n_label_samples * num_labels)

Используя эту формулу, мы можем убедиться, что вес образцов класса меньшинства становится больше, чем выборки класса большинства, и, следовательно, мы можем изучить сбалансированное представление входных данных.

Обучение

Для обучения модели определяется класс тренера, который можно увидеть в блокноте. Все параметры тренировки можно найти в блокноте под заголовком «Константы». Примечательной особенностью класса тренера является использование ранней остановки в качестве схемы регуляризации, чтобы избежать переобучения, и использование только модели, которая имеет лучшую производительность на данных проверки, для окончательных тестовых данных.

Кроме того, мы используем оптимизатор AdamW со снижением веса 5e-2. Скорость обучения для линейного слоя выбрана как 1e-2, а скорость обучения для точной настройки последних слоев CNN выбрана как 1e-6. Еще одна важная вещь, на которую следует обратить внимание, это то, что мы используем схему косинусного отжига с теплыми перезапусками, чтобы снизить скорость обучения для обеих групп параметров. Продолжительность циклов также увеличивается после каждого перезапуска, что делает производительность обучения модели стабильной к концу фазы обучения. Вы можете увидеть историю обучения модели на рисунке 6.

Вы можете увидеть производительность модели на тестовом наборе в Таблице 2 и на Рисунке 7.

Как видно из результатов, модель с очень высокой степенью уверенности в обнаружении реальных случаев пневмонии имеет тенденцию классифицировать нормальных людей как больных. Хотя в большинстве случаев это может показаться не очень хорошим, поскольку важность классификации людей с заболеваниями гораздо важнее ошибочной классификации, на самом деле это неплохо в данной ситуации. Как видите, модель имеет очень высокую точность для нормального класса, а это означает, что она с высокой степенью уверенности говорит о том, что пациент здоров. Точность 88% и взвешенный показатель f1 87% показывают возможности этой облегченной модели.

Заключение

В этой статье мы использовали предварительно обученную MobileNetV3, очень легкую CNN, для классификации рентгеновских изображений пациентов с пневмонией и здоровых людей. Путем тонкой настройки последних уровней MobileNetV3 и использования схем ранней остановки и косинусного отжига для повышения способности модели к обучению мы разработали надежную модель, которая имеет очень хорошие характеристики, особенно производительность с поправкой на риски.