Это краткое изложение недавней (и очень удобочитаемой) статьи, которая мне понравилась, под названием MixMatch: целостный подход к полу-контролируемому обучению, в которой представлен новый алгоритм для обучения моделей на небольшом количестве помеченных данных и большом количестве немаркированных данных, что намного точнее, чем другие подходы. Я обильно перефразирую авторов статьи, что, на мой взгляд, нормально - резюмирую их работу. Что касается кода, я начал с этого репо и внес изменения, чтобы приблизить его к работе. Репо - рабочий продукт. По состоянию на 18 мая у него есть работающие функции DataLoader и Loss, но я еще не проводил эксперименты или тщательно проверял правильность. Авторский код тензорного потока находится здесь.

Мотивация

Полу-контролируемое обучение (SSL) основано на использовании информации из немаркированных данных для обучения модели. Большинство статей, на которые ссылается MixMatch, как правило, используют пакеты как помеченных, так и немаркированных данных в одном обучающем прогоне, в отличие от предварительного обучения на задаче языкового моделирования, а затем точной настройки на парадигме помеченных данных ULMFit, Elmo и BERT. Недавняя работа пытается научиться добавлять термин к функции потерь, который поощряет:

(а) уверенные прогнозы по немаркированным данным. Даже если мы не знаем ярлык для примера, мы знаем, что он вряд ли будет наполовину в одном классе и наполовину в другом. Псевдо-маркировка, которая берет самые надежные прогнозы модели для немаркированных данных и предполагает, что наибольшая прогнозируемая вероятность является правильной меткой, является одной из форм этого.

(b) регуляризация согласованности: поощряйте одинаковые прогнозы для слегка измененных входных данных. Прогнозы не должны сильно измениться, если исходные данные меняются незначительно. Это идея другой работы Гудфеллоу «Виртуальное состязательное обучение».

MixMatch также использует Mixup, (TODO: ссылка) метод, при котором мы обучаем модель на комбинациях примеров. Например, мы скармливаем изображению наполовину кошку и наполовину собаку (либо накладывая изображения, либо помещая их рядом друг с другом на изображении) и ожидаем, что модель выдаст p (собака) =. 5, p (кошка) = .5 и вероятность 0 для остальных классов.

Ни один из этих методов не является взаимоисключающим, и волшебство бумаги MixMatch состоит в том, что они находят разумный способ использовать их все!

MixMatch

Сложность: потери рассчитываются отдельно для помеченных и немаркированных примеров, но они потенциально могут быть перепутаны!

(1) Обозначьте угадывание для немаркированных примеров:

  • Увеличьте каждый немаркированный пример k = 2 раза, усредните прогнозы каждого расширенного примера
  • Увеличьте прогнозируемые вероятности классов с помощью функции повышения резкости. Полученное распределение является нашим предположением для меток в этом примере.
def sharpen(preds: np.NDArray, temp=.5) -> np.NDArray:
    # as temp goes to 0 this starts to look more like a one-hot distribution
    numerators = preds ** (1/temp)
    return numerators / (numerators.sum(axis=1))
    
ub = [augment_fn(X_unlabeled) for _ in range(K)]  # K augmentations
qb = sharpen(sum(map(model, ub)) / K, T)  # labeled guesses

(2) Смешивание и совпадение

Мы начинаем с трех пакетов: X набор помеченных примеров, U набор немаркированных примеров и предполагаемых ярлыков, и W=Shuffle(concat([X, U]). Мы используем W, чтобы получить партнеров по смешиванию для каждого примера. Вот код:

def lin_comb(a, b, frac_a): return (frac_a * a) + (1 - frac_a) * b
C = np.concatenate
Ux = C(ub, axis=0)
Uy = C([qb for _ in range(K)], axis=0)
indices = shuffle(np.arange(len(xb) + len(Ux)))
Wx = C([Ux, xb], axis=0)[indices]  # mixup partners
Wy = C([qb, y], axis=0)[indices] # mixup partners
  • Примените смешение к помеченным данным и записям из Ux and W [: len (X)], получив X’mix-up примеров, где по крайней мере 50% примера взято из помеченной точки данных.
def mixup(x1, x2, y1, y2, alpha):
    """Modified to make sure that majority of example comes from x1."""                  beta = np.random.beta(alpha, -alpha, x1.shape[0])
    beta = np.maximum(beta, 1 - beta)  
    return lin_comb(x1, x2, beta), lin_comb(y1, y2, beta)
X, p = mixup(xb, Wx[:n_labeled], y, Wy[:n_labeled], alpha)  # Labeled batch
U, q = mixup(Ux, Wx[n_labeled:], Uy, Wy[n_labeled:], alpha)  # Unlabeled batch
return  C([X, U], axis=1), C([p, q], axis=1), n_labeled # need to keep track of where labeled examples are for loss fn
  • Сделайте помеченную партию, содержащую помеченные примеры, смешанные с потенциально немаркированными примерами. Убедитесь, что помеченный пример должен иметь более 50% веса при усреднении этикеток и изображений.
  • Сделайте немаркированный пакет, получая смешанные примеры, где по крайней мере 50% примера взяты из немаркированной точки данных.

(3) Потери - это взвешенная комбинация квадратов потерь L2 (MSE) для в основном немаркированных U', CrossEntropy для X'

Здесь мой код (и понимание) ломается: цель, связанная с примером в X', перепутана, но CrossEntropyLoss ожидает постоянных целей.

Примечание: автор упоминает, что «мы не распространяем градиенты через предполагаемые метки». Что это обозначает? Зачем вообще создавать U', если это не влияет на потери? Я интерпретирую то, что вы относитесь к предполагаемым меткам как к реальным меткам и просто поддерживаете потерю через ту часть вычислительного графа, которая сделала предсказания. Поэтому, если вы реализуете MixMatch для тензоров, вам нужно убедиться, что автоградиция отключена, прежде чем делать qb.

Результаты

Эксперименты проводятся с использованием широкой модели Resnet на cifar-10 и SHVN, при этом моделям предлагается различное количество этикеток. На рисунке ниже мы видим, что MixMatch почти достигает эталонного теста «полностью контролируемого» (обучение на 5000 помеченных изображениях, со значительно меньшим количеством помеченных примеров и значительно превосходит другие методы SSL:

Также кажется, что ни один из методов, которые использует MixMatch, не является лишним, попытка удалить любой шаг значительно ухудшает показатели, а немаркированные данные, очевидно, помогают!

Дальнейшие действия:

Попробуйте воспроизвести производительность cifar-10 с помощью WideResnet.