Глубокое обучение с несбалансированными данными о классе

Классовый дисбаланс:

В машинном обучении мы иногда имеем дело с очень хорошими данными, такими как модные данные MNIST или данные CIFAR-10, где примеры каждого класса в наборе данных хорошо сбалансированы. Что произойдет, если в задаче классификации распределение примеров по известным классам смещено или искажено? Такие проблемы с серьезным или небольшим смещением в наборе данных являются обычными, и сегодня мы обсудим подход к обработке таких несбалансированных данных по классам. Давайте рассмотрим крайний случай несбалансированного набора данных писем и построим классификатор для обнаружения спамовых писем. Поскольку спам-сообщения встречаются относительно реже, предположим, что 5% всех писем - это спам. Если мы просто напишем простой однострочный код как -

def detectspam(mail-data):
 return ‘not spam’ 

Это даст нам правильный ответ в 95% случаев, и даже если это чрезмерное преувеличение, проблема у вас есть. Что наиболее важно, обучение любой модели с этими данными приведет к высокому достоверному предсказанию обычных писем, и из-за крайне малого количества спамовых писем в обучающих данных модель, скорее всего, не научится правильно предсказывать спам-письма. Вот почему точность, отзывчивость, оценка F1, кривые ROC / AUC - важные показатели, которые действительно рассказывают нам историю. Как вы уже догадались, один из способов уменьшить эту проблему - выполнить выборку, чтобы сбалансировать набор данных, чтобы классы были сбалансированы. Есть несколько других способов решения проблемы несбалансированности классов в машинном обучении, и Джейсон Браунли подготовил отличный всеобъемлющий обзор, проверьте его здесь.

Классовый дисбаланс в компьютерном зрении:

В случае проблемы компьютерного зрения эта проблема дисбаланса классов может быть более критичной, и здесь мы обсуждаем, как авторы подошли к задачам обнаружения объектов, которые приводят к развитию потери фокуса. В случае алгоритмов типа Fast R-CNN, сначала мы запускаем изображение через ConvNet для получения карты характеристик, а затем выполняется предложение региона (обычно около 2К регионов) на карте характеристик с высоким разрешением. Это двухступенчатые детекторы, и когда была представлена ​​статья о Focal Loss, возник интригующий вопрос: может ли одноступенчатый детектор, такой как YOLO или SSD, иметь такую ​​же точность, как двухступенчатые детекторы? Одноступенчатые детекторы были быстрыми, но точность в то время составляла около 10–40% от двухкаскадных детекторов. Авторы предположили, что несбалансированность классов во время обучения является основным препятствием, которое не позволяет одноступенчатым детекторам достичь той же точности, что и двухступенчатые детекторы.

Пример такого классового дисбаланса показан на рисунке 1, который не требует пояснений, который взят из самой презентации первоначальных авторов. Они обнаружили, что одноступенчатые детекторы работают лучше, когда имеется большее количество ограничивающих рамок, покрывающих пространство возможных объектов. Но этот подход вызвал серьезную проблему, поскольку данные переднего плана и фоновые данные распределяются неравномерно. Например, если мы рассмотрим 20000 ограничивающих рамок, в большинстве случаев 7–10 из них будут фактически содержать любую информацию об объекте, а остальные будут содержать фон, и в большинстве случаев их будет легко классифицировать, но они неинформативны. Здесь авторы обнаружили, что функция потерь (например, кросс-энтропия) является основной причиной того, что простые примеры будут отвлекать обучение. Ниже представлено наглядное изображение.

Несмотря на то, что неправильно классифицированные образцы подвергаются большему штрафу (красная стрелка на рис.1), чем правильные (зеленая стрелка), в настройках обнаружения плотных объектов из-за несбалансированного размера выборки функция потерь перегружена фоном (простые образцы ). Focal Loss решает эту проблему и спроектирован таким образом, чтобы уменьшить потери («уменьшение веса») для простых примеров, и, таким образом, сеть может сосредоточиться на обучении жестких примеров. Ниже приводится определение потери очага -

В потерях фокуса есть модулирующий фактор, умноженный на потерю кросс-энтропии. Если образец неправильно классифицирован, p (который представляет собой оценочную вероятность модели для класса с меткой y = 1) будет низким, а коэффициент модуляции близок к 1, и потери не затронуты. При p → 1 коэффициент модуляции приближается к 0, и потери для хорошо классифицированных примеров взвешиваются с понижением. Влияние параметра γ показано на графике ниже -

Цитата из статьи -

Коэффициент модуляции снижает вклад потерь из простых примеров и расширяет диапазон, в котором пример принимает низкие потери.

Чтобы понять это, мы сравним потери кросс-энтропии (CE) и фокальные потери, используя приведенное выше определение с γ = 2. Рассмотрим истинное значение 1,0 и 3 прогнозных значения: 0,90 (близко), 0,95 (очень близко), 0,20 (далеко). от истины). Давайте посмотрим на значения потерь ниже, используя TensorFlow:

CE loss when pred is close to true:  0.10536041
CE loss when pred is very close to true:  0.051293183
CE loss when pred is far from true:  1.6094373


focal loss when pred is close to true:  0.0010536041110754007
focal loss when pred is very close to true:  0.00012823295779526255
focal loss when pred is far from true:  1.0300399017333985

Здесь мы видим, что по сравнению с потерями CE, модулирующий фактор в потерях фокуса играет важную роль. Когда предсказание близко к истине, потеря наказывается сильнее, чем когда она далека. Важно отметить, что при прогнозе 0,90 потери в фокусе будут 0,01 × CE потери, но при прогнозе 0,95 потери в фокусе будут примерно 0,002 × CE потери. Теперь мы видим, как потеря фокуса снижает вклад потерь из простых примеров и расширяет диапазон, в котором пример получает низкие потери. Это также видно из рис. 3. Теперь мы воспользуемся несбалансированным набором данных реального мира и посмотрим, как работают фокусные потери.

Мошенничество с кредитными картами: набор данных о несбалансированности классов:

Описание набора данных: Здесь я рассмотрел чрезвычайно несбалансированный по классам набор данных, доступный в Kaggle, и набор данных содержит транзакции, совершенные с помощью кредитных карт в сентябре 2013 года европейскими держателями карт. Давайте использовать панд -

Этот набор данных представляет транзакции, которые произошли за два дня, и у нас есть 284 807 транзакций. Функции V1, V2,… V28 являются основными компонентами, полученными с помощью PCA (исходные функции не предоставляются из-за проблем с конфиденциальностью), и единственными функциями, которые не были преобразованы с помощью PCA, являются «Время» и «Сумма». Функция «Время» содержит секунды, прошедшие между каждой транзакцией и первой транзакцией в наборе данных, а функция «Сумма» - это сумма транзакции. Функция «Класс» - это переменная ответа, которая принимает значение 1 в случае мошенничества и 0 в противном случае.

Несбалансированность классов. Давайте изобразим распределение признака «Класс», который сообщает нам, сколько транзакций являются настоящими и поддельными. Как показано на рисунке 4 выше, подавляющее количество транзакций является реальным. Давайте получим числа с помощью этого простого фрагмента кода -

print (‘real cases:‘, len(credit_df[credit_df[‘Class’]==0]))
print (‘fraud cases: ‘, len(credit_df[credit_df[‘Class’]==1]))

>>> real cases:  284315
    fraud cases:  492

Таким образом, коэффициент дисбаланса классов составляет примерно 1: 578, так что для 578 реальных транзакций у нас есть один случай мошенничества. Сначала давайте воспользуемся простой нейронной сетью с кросс-энтропийной потерей для прогнозирования мошенничества и реальных транзакций. Но перед этим небольшое исследование показывает, что функции Количество и Время не масштабируются, тогда как другие функции V1, V2… и т. Д. Масштабируются. Здесь мы можем использовать StandardScaler / RobustScaler для масштабирования этих функций, и, поскольку RobustScaler устойчив к выбросам, я выбрал этот метод стандартизации.

Теперь давайте выберем функции и метку, как показано ниже -

X_labels = credit_df.drop([‘Class’], axis=1)
y_labels = credit_df[‘Class’]
X_labels = X_labels.to_numpy(dtype=np.float64)
y_labels = y_labels.to_numpy(dtype=np.float64)

y_lab_cat = tf.keras.utils.to_categorical(y_labels, num_classes=2, dtype=’float32')

Для разделения поезда и теста мы используем стратификацию, чтобы сохранить соотношение меток -

x_train, x_test, y_train, y_test = train_test_split(X_labels, y_lab_cat, test_size=0.3, stratify=y_lab_cat, shuffle=True)

Теперь мы построим простую модель нейронной сети с 3 плотными слоями -

def simple_model():
   input_data = Input(shape=(x_train.shape[1], ))
   x = Dense(64)(input_data)
   x = Activation(activations.relu)(x)
   x = Dense(32)(x)
   x = Activation(activations.relu)(x)
   x = Dense(2)(x)
   x = Activation(activations.softmax)(x)
   model = Model(inputs=input_data, outputs=x, name=’Simple_Model’)
   return model

Скомпилируйте модель с категориальной кросс-энтропией как потерями -

simple_model.compile(optimizer=Adam(learning_rate=5e-3), loss='categorical_crossentropy', metrics=['acc'])

Обучите модель -

simple_model.fit(x_train, y_train, validation_split=0.2, epochs=5, shuffle=True, batch_size=256)

Чтобы по-настоящему понять производительность модели, нам нужно построить матрицу неточностей вместе с оценками точности, отзыва и F1 -

Из матрицы путаницы и других показателей производительности мы видим, что, как и ожидалось, сеть отлично справляется с классификацией реальных транзакций, но значение отзыва ниже 50% для класса мошенничества. Наша цель - протестировать, ничего не меняя, кроме функции потерь, можем ли мы получить лучшие значения для показателей производительности?

Использование Focal Loss:

Сначала давайте определим фокусные потери с альфа и гаммой как гиперпараметры, и для этого я использовал tfa module, который представляет собой функциональность для TensorFlow, поддерживаемую SIG-аддонами (tfa). В этом модуле среди дополнительных потерь есть реализация Focal Loss, и сначала мы импортируем, как показано ниже -

import tensorflow_addons as tfa
fl = tfa.losses.SigmoidFocalCrossEntropy(alpha, gamma)

Используя это, давайте определим пользовательскую функцию потерь, которая может использоваться в качестве прокси для «Focal Loss» для этой конкретной проблемы с двумя классами:

def focal_loss_custom(alpha, gamma):
   def binary_focal_loss(y_true, y_pred):
      fl = tfa.losses.SigmoidFocalCrossEntropy(alpha=alpha, gamma=gamma)
      y_true_K = K.ones_like(y_true)
      focal_loss = fl(y_true, y_pred)
      return focal_loss
   return binary_focal_loss

Теперь мы просто повторяем описанные выше шаги для определения модели, компиляции и подгонки, но на этот раз с использованием фокальных потерь, как показано ниже -

simple_model.compile(optimizer=Adam(learning_rate=5e-3),       loss=focal_loss_custom(alpha=0.2, gamma=2.0), metrics=[‘acc’])

Для параметров альфа и гамма я только что использовал значения, предложенные в документе (однако проблема в другом), и необходимо проверить другие значения.

simple_model.fit(x_train, y_train, validation_split=0.2, epochs=5, shuffle=True, batch_size=256)

Используя Focal Loss, мы видим улучшение, как показано ниже -

Мы видим, что с использованием «Focal Loss» показатели производительности значительно улучшились, и мы смогли правильно обнаружить больше «мошеннических» транзакций (101/148) по сравнению с предыдущим случаем (69/148).

Здесь, в этом посте, мы обсуждаем фокусные потери и то, как они могут улучшить задачу классификации, когда данные сильно несбалансированы. Чтобы продемонстрировать Focal Loss в действии, мы использовали набор данных о транзакциях с кредитными картами, который сильно смещен в сторону реальных транзакций, и показал, как Focal Loss улучшает эффективность классификации.

Я также хотел бы упомянуть, что в моем исследовании с данными гамма-лучей мы пытаемся классифицировать активные галактические ядра (AGN) от пульсаров (PSR), а гамма-небо в основном населенными AGN. На картинке ниже показан пример такого смоделированного неба. Это также пример несбалансированного набора данных в компьютерном зрении.

Использованная литература:

[1] Исходная бумага с потерей фокуса

[2] Оригинальная презентация с потерей фокуса

[3] Блокнот, использованный в этом сообщении: GitHub