Реализация RepVGG в PyTorch

Сделайте свою CNN ›в 100 раз быстрее

Привет!! Сегодня мы увидим, как реализовать RepVGG в PyTorch, предложенный в RepVGG: снова сделать ConvNets в стиле VGG великими.

Код находится здесь, интерактивную версию этой статьи можно скачать здесь.

Давайте начнем!

В документе предложена новая архитектура, которую можно настроить после обучения, чтобы она работала быстрее на современном оборудовании. И под более быстрым я подразумеваю быстрое освещение, эта идея была использована в модели Apple MobileOne.

Модели с одной и несколькими ветвями

Многие недавние модели используют мультиветвь, когда входные данные проходят через разные уровни, а затем каким-то образом агрегируются (обычно с добавлением).

Это здорово, потому что делает многоотраслевую модель неявным ансамблем многочисленных более мелких моделей. В частности, модель можно интерпретировать как ансамбль из 2^n моделей, поскольку каждый блок разветвляет поток на два пути.

К сожалению, модели с несколькими ветвями потребляют больше памяти и работают медленнее, чем модели с одной ветвью. Давайте создадим классический ResNetBlock, чтобы понять, почему (посмотрите мою статью о ResNet в PyTorch).

Хранение residual двойного потребления памяти. Это также показано на следующем изображении из статьи

Авторы заметили, что многоветвевая архитектура полезна только во время обучения. Таким образом, если у нас будет способ удалить его во время тестирования, мы сможем улучшить скорость модели и потребление памяти.

От нескольких филиалов к одному филиалу

Рассмотрим следующую ситуацию: у вас есть две ветки, состоящие из двух 3x3 конв.

torch.Size([1, 8, 5, 5])

Теперь мы можем создать одну конверсию, назовем ее conv_fused, чтобы conv_fused(x) = conv1(x) + conv2(x). Очень легко, мы можем просто суммировать weights и bias двух конв! Таким образом, нам нужно запустить только один conv вместо двух.

Посмотрим, насколько это быстрее!

conv1(x) + conv2(x) tooks 0.000421s
conv_fused(x) tooks 0.000215s

Почти на 50% меньше (имейте в виду, что это очень наивный тест, позже мы увидим лучший результат)

Fuse Conv и BatchNorm

В современных сетевых архитектурах BatchNorm используется в качестве уровня регуляризации после блока свертки. Мы можем захотеть объединить их вместе, поэтому создайте conv так, чтобы conv_fused(x) = batchnorm(conv(x)). Идея состоит в том, чтобы изменить веса conv, чтобы включить смещение и масштабирование из BatchNorm.

Бумага объясняет это следующим образом:

Код следующий:

Посмотрим, работает ли это

да, мы объединили слои Conv2d и BatchNorm2d. Также есть статья от PyTorch об этом

Итак, наша цель — объединить все ветки в одну конверсию, сделав сеть быстрее!

Автор предложил новый тип блока, названный RepVGG. Подобно ResNet, у него есть ярлык, но также есть соединение с идентификацией (или лучше ответвление).

В ПиТорч:

Репараметризация

У нас есть один 3x3 conv->bn, один 1x1 conv-bn и (иногда) один batchnorm (ветвь идентификации). Мы хотим объединить их вместе, чтобы создать одно единственное conv_fused, такое что conv_fused = 3x3conv-bn(x) + 1x1conv-bn(x) + bn(x) или, если у нас нет идентификационной связи, conv_fused = 3x3conv-bn(x) + 1x1conv-bn(x).

Пойдем шаг за шагом. Чтобы создать conv_fused, нам нужно:

  • объединить 3x3conv-bn(x) в один 3x3conv
  • 1x1conv-bn(x), затем преобразуйте его в 3x3conv
  • преобразовать идентификатор bn в 3x3conv
  • добавить все три 3x3convs

Подытожено изображением ниже:

Первый шаг прост, мы можем использовать get_fused_bn_to_conv_state_dict на RepVGGBlock.block (основной 3x3 conv-bn).

Второй шаг аналогичен, get_fused_bn_to_conv_state_dict на RepVGGBlock.shortcut (1x1 conv-bn). Затем мы дополняем каждое ядро ​​слитого 1x1 1 в каждом измерении, создавая 3x3.

Личность bn сложнее. Нам нужно создать 3x3 conv, который будет действовать как функция идентичности, а затем использовать get_fused_bn_to_conv_state_dict, чтобы объединить его с идентичностью bn. Это можно сделать, имея 1 в центре соответствующего ядра для этого соответствующего канала.

Напомним, что вес конв — это тензор in_channels, out_channels, kernel_h, kernel_w. Если мы хотим создать идентификатор, такой как conv(x) = x, нам нужно иметь один единственный 1 для этого канала.

Например:

https://gist.github.com/c499d53431d243e9fc811f394f95aa05

torch.Size([2, 2, 3, 3])
Parameter containing:
tensor([[[[0., 0., 0.],
          [0., 1., 0.],
          [0., 0., 0.]],
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],

        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
         [[0., 0., 0.],
          [0., 1., 0.],
          [0., 0., 0.]]]], requires_grad=True)

Видите ли, мы создали Conv, которая действует как функция идентификации.

Теперь, собрав все воедино, этот шаг формально называется репараметризацией.

Наконец, давайте определим RepVGGFastBlock. Это только составлено conv + relu

и добавьте метод to_fast к RepVGGBlock, чтобы быстро создать правильный RepVGGFastBlock

РепВГГ

Давайте определим RepVGGStage (набор блоков) и RepVGG с помощью удобного метода switch_to_fast, который будет переключаться на быстрый блок на месте:

Давайте проверим это!

Я создал бенчмарк внутри benchmark.py, запустив модель на моей gtx 1080ti с разными размерами пакетов, и вот результат:

Модель имеет два слоя на этап, четыре этапа и ширину 64, 128, 256, 512.

В своей статье они масштабируют эти значения на некоторую величину (называемую a и b) и используют сгруппированные конверсии. Поскольку нас больше интересует часть репараметризации, я их пропускаю.

Да, так что в основном модель репараметризации находится в другом масштабе времени по сравнению с ванильной. Ух ты!

Позвольте мне скопировать и вставить фрейм данных, который я использовал для хранения теста

Вы можете видеть, что модель по умолчанию (многоветвевая) занимает 1.45с для batch_size=128, в то время как параметризованная (быстрая) занимает только 0.0134с. Это 108x 🚀🚀🚀

.

Выводы

Выводы В этой статье мы шаг за шагом рассмотрели, как создать RepVGG; молниеносно быстрая модель, использующая умную технику репараметризации.

Этот метод может быть перенесен и на другую архитектуру.

Спасибо, что прочитали это!

👉 Реализация SegFormer в PyTorch

👉 Реализация ConvNext в PyTorch

Франческо