В предыдущих статьях мы видели различные методы переноса нейронного стиля, сначала с определенным стилем и изображениями контента, затем с определенным стилем и произвольными изображениями контента. В этой статье мы будем реализовывать исследовательскую работу «Передача произвольного стиля в режиме реального времени с адаптивной нормализацией экземпляра», в которой описывается метод, выполняющий передачу стиля с изображениями произвольного стиля и контента, что означает, что переобучение не требуется для изображений разных стилей. .

Настройка обучения

Основным компонентом этого подхода является слой Adaptive Instance Normalization (AdaIN). Учебную установку можно увидеть на следующей диаграмме из исследовательской работы.

  1. Мы пропускаем изображения стиля и контента через кодировщик VGG для извлечения функций.
  2. Мы передаем функции стиля и контента через слой AdaIN, который перенастраивает функции контента, чтобы они имели то же среднее значение и стандартное отклонение, что и функции стиля.
  3. Затем мы передаем выровненные элементы содержимого обратно через декодер для создания стилизованного изображения.
  4. Извлеченные функции из кодировщика VGG также используются для вычисления потерь контента.

Набор данных

Этот подход включает в себя набор данных изображений контента, а также набор данных изображений стилей, поскольку нам нужно иметь возможность использовать как контент artbitrary, так и изображения произвольного стиля.

Изображения контента взяты из набора данных MS COCO, который использовался в предыдущих статьях. Набор стилевых изображений состоит из картин, взятых с WikiArt, и их можно найти здесь.

Предпринимаются следующие шаги предварительной обработки.

  1. Изображения передаются через модель парами: одно изображение содержимого и одно изображение стиля.
  2. Меньший размер каждого изображения изменяется до 512 с сохранением соотношения сторон.
  3. Случайная обрезка 256 на 256 пикселей берется как из содержимого, так и из изображений стиля. Это используется для обучения.

Код для этого можно найти ниже и в файле dataset.py в репозитории.

class AdaINDataset:

    def __init__(self, content_path, style_path, batch_size) -> None:

        self.T = transforms.Compose([
            transforms.Resize(512),
            transforms.RandomCrop((256, 256), padding=(20, 20)),
            transforms.ToTensor(),
        ])

        self.content_folder = ImageFolder(content_path, transform=self.T)
        self.style_folder = ImageFolder(style_path, transform=self.T)

        self.content_loader = DataLoader(self.content_folder, batch_size, shuffle=True)
        self.style_loader = DataLoader(self.style_folder, batch_size, shuffle=True)

Кодировщик VGG

При таком подходе к передаче стиля у кодировщика есть две основные функции.

  1. Извлеките функции для использования в слое AdaIN. Для этого шага необходимы объекты из слоя relu4_1.
  2. Извлеките признаки для вычисления потерь. Для этого потребуются признаки из слоев relu1_1, relu2_1, relu3_1 и relu4_1.

Поэтому мы создадим кодировщик для извлечения признаков из этих 4 слоев. Следующий код можно найти в файле model.py в репозитории.

import torch
from torch import nn

# vgg encoder model
class VGGEncoder(nn.Module):

    def __init__(self, weight_path=None) -> None:
        super(VGGEncoder, self).__init__()
        # layers to extract features from
        self.feature_layers = [3, 10, 17, 30]

        # creating model and adding first layer
        self.model = nn.ModuleList()
        self.model.append(nn.Conv2d(in_channels=3, out_channels=3, kernel_size=(1, 1)))
        self.model.append(nn.ReflectionPad2d((1, 1, 1, 1)))

        # parameters for remaining layers
        # (in_channels, out_channels, kernel_size)
        params = [
            (3, 64, (3, 3)),
            (64, 64, (3, 3)),
            (64, 128, (3, 3)),
            (128, 128, (3, 3)),
            (128, 256, (3, 3)),
            (256, 256, (3, 3)),
            (256, 256, (3, 3)),
            (256, 256, (3, 3)),
            (256, 512, (3, 3)),
            (512, 512, (3, 3)),
            (512, 512, (3, 3)),
            (512, 512, (3, 3)),
            (512, 512, (3, 3)),
            (512, 512, (3, 3)),
            (512, 512, (3, 3)),
            (512, 512, (3, 3))
        ]

        # adding layers to model
        # also adding a maxpool layer whenever number of out_channels changes
        for i, param in enumerate(params[:-1]):
            self.model.append(nn.Conv2d(in_channels=param[0], out_channels=param[1], kernel_size=param[2]))
            self.model.append(nn.ReLU())
            if params[i + 1][0] != params[i + 1][1]:
                self.model.append(nn.MaxPool2d(kernel_size=(2,2), stride=(2,2), padding=(0,0), ceil_mode=True))
            self.model.append(nn.ReflectionPad2d((1, 1, 1, 1)))

        # inserting extra maxpool layer based on vgg architecture
        self.model.insert(40, nn.MaxPool2d(kernel_size=(2,2), stride=(2,2), padding=(0,0), ceil_mode=True))

        # adding last layer
        self.model.append(nn.Conv2d(in_channels=params[-1][0], out_channels=params[-1][1], kernel_size=params[-1][2]))
        self.model.append(nn.ReLU())

        # loading pretrained model weights if path is provided
        if weight_path:
            self.model.load_state_dict(torch.load(weight_path, map_location="cpu"))
            
        # encoder is fully pretrained, so no weights need to be adjusted
        # gradients need not be accumulated
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, x):
        # list to store activations from layers in self.feature_layers
        activations = []
        for i, layer in enumerate(self.model[:31]):
            x = layer(x)
            if i in self.feature_layers:
                activations.append(x)
        return activations

Декодер — это просто перевернутая версия кодера, заменяющая объединяющие слои слоями повышающей дискретизации.

# vgg decoder model
class VGGDecoder(nn.Module):

    def __init__(self, encoder, weight_path=None) -> None:
        super(VGGDecoder, self).__init__()
        self.decoder = nn.ModuleList()
        # reflection of encoder, so we iterate in reverse order
        for layer in encoder.model[:31][::-1][:-2]:
            # switch location of reflection pad and relu
            if isinstance(layer, nn.ReflectionPad2d):
                self.decoder.append(nn.ReLU())
            elif isinstance(layer, nn.ReLU):
                self.decoder.append(nn.ReflectionPad2d((1, 1, 1, 1)))
            elif isinstance(layer, nn.MaxPool2d):
                self.decoder.append(nn.Upsample(scale_factor=2, mode='nearest'))
            elif isinstance(layer, nn.Conv2d):
                layer = nn.Conv2d(
                    in_channels=layer.out_channels,
                    out_channels=layer.in_channels,
                    kernel_size=layer.kernel_size
                )
                self.decoder.append(layer)

        if weight_path:
            self.decoder.load_state_dict(torch.load(weight_path, map_location="cpu"))
            
    def forward(self, x):
        for layer in self.decoder:
            x = layer(x)
        return x

Адаптивная нормализация экземпляра (AdaIN)

Теперь о слое AdaIN между кодером и декодером, который составляет основу этого метода передачи стиля.

Формулу для AdaIN можно увидеть ниже, где x представляет функции контента, а y представляет функции стиля;

Шаги следующие;

  1. Вычислите среднее значение по каналам и стандартное отклонение как для содержания, так и для стилевых характеристик.
  2. Нормализуйте функции контента, чтобы иметь 0 среднего и 1 стандартное отклонение.
  3. Перестроены нормализованные функции контента, чтобы они имели то же среднее значение и стандартное отклонение, что и функции стиля.

Это выглядит следующим образом в коде;

def AdaIN_realign(style, content):
    # flatten images while retaining batchs and channels, compute mean and std
    B, C = content.shape[0], content.shape[1]
    content_mean = content.view(B, C, -1).mean(dim=2)
    content_std = content.view(B, C, -1).std(dim=2)
    style_mean = style.view(B, C, -1).mean(dim=2)
    style_std = style.view(B, C, -1).std(dim=2)

    # reshape mean and std to perform normalization and realignment
    content_mean, content_std = content_mean.view(B, C, 1, 1), content_std.view(B, C, 1, 1)
    style_mean, style_std = style_mean.view(B, C, 1, 1), style_std.view(B, C, 1, 1)

    content_mean, content_std = content_mean.expand(content.size()), content_std.expand(content.size())
    style_mean, style_std = style_mean.expand(style.size()), style_std.expand(style.size())

    # normalize content features. Small constant added to avoid zero division.
    normalized = (content - content_mean) / (content_std + 0.00001)
    # realign normalized content features with style mean and std
    realigned = (normalized * style_std) + style_mean
    return realigned

Функции потерь

Как и в случае с другими подходами к передаче нейронного стиля, функция потерь состоит из двух частей; потеря содержания и потеря стиля.

Потеря контента

Потеря контента — это просто среднеквадратичная потеря ошибки между элементами контента с повторным выравниванием (выходные данные слоя AdaIN перед прохождением через декодер) и извлеченными функциями из стилизованного изображения, как показано на диаграмме обучения. установка выше.

Потеря стиля

Следующие шаги описывают расчет потери стиля.

  1. Изображение стиля и сгенерированное изображение передаются через кодировщик для извлечения признаков.
  2. Для каждого рассчитывается среднее значение по каналу и стандартное отклонение.
  3. Среднеквадратическая ошибка между средним значением и стандартным отклонением берется и суммируется. Это потеря стиля.

Код для потери содержимого и стиля можно увидеть ниже. Это в файле loss.py в репозитории.

from torch import nn

class AdaINLoss:

    def __init__(self, enc, style_weight) -> None:
        self.mse = nn.MSELoss()
        self.loss_network = enc
        self.style_weight = style_weight

    # a simple mse for the content loss
    def content_loss(self, realigned_content, pred_feature_last):
        return self.mse(realigned_content, pred_feature_last)

    # mse of the channel-wise mean and std for style loss
    def style_loss(self, style_features, pred_features):
        style_loss = 0
        for s_ft, p_ft in zip(style_features, pred_features):
            B, C = s_ft.shape[0], s_ft.shape[1]
            s_ft, p_ft = s_ft.view(B, C, -1), p_ft.view(B, C, -1)
            mean_loss = self.mse(s_ft.mean(dim=2), p_ft.mean(dim=2))
            std_loss = self.mse(s_ft.std(dim=2), p_ft.std(dim=2))
            style_loss += (mean_loss + std_loss)
        return style_loss


    def calculate_loss(self, style_features, pred_img, realigned_content):
        pred_features = self.loss_network(pred_img)
        content_loss = self.content_loss(realigned_content, pred_features[-1])
        style_loss = self.style_loss(style_features, pred_features)
        return content_loss + (style_loss * self.style_weight)

Тренировочный цикл

Цикл обучения аналогичен предыдущим статьям о переносе нейронного стиля. Код можно увидеть ниже и в файле train.py в репозитории.

from model import VGGEncoder, VGGDecoder, AdaIN_realign
from loss import AdaINLoss
from dataset import AdaINDataset

from torch import optim
import torch
from torchvision import transforms
from tqdm import tqdm
from PIL import Image


class TrainAdaIN:

    def __init__(self, epochs, style_weight, lr, batch_size, content_path, style_path,
                test_content, test_style, dev, enc_weight_path=None, dec_weight_path=None, show_test_output=False) -> None:
        self.dev = dev
        self.style_weight = style_weight
        self.batch_size = batch_size
        self.lr = lr
        self.epochs = epochs
        self.show_test_output = show_test_output

        self.load_weights(enc_weight_path, dec_weight_path)

        self.adain_loss = AdaINLoss(self.enc, self.style_weight)

        self.dataset = AdaINDataset(content_path, style_path, batch_size)
        self.test_images_init(test_content, test_style)


    # load pretrained weights
    def load_weights(self, enc_weight_path, dec_weight_path):
        if enc_weight_path:
            self.enc = VGGEncoder(weight_path=enc_weight_path).to(self.dev)
        else:
            self.enc = VGGEncoder().to(self.dev)

        if dec_weight_path:
            self.dec = VGGDecoder(self.enc, weight_path=dec_weight_path).to(self.dev)
        else:
            self.dec = VGGDecoder(encoder=self.enc).to(self.dev)


    # initialize test images to show training progree
    def test_images_init(self, content_path, style_path):
        self.content_test = Image.open(content_path)
        self.style_test = Image.open(style_path).resize(self.content_test.size)

        T = transforms.Compose([
            transforms.Resize(512),
            transforms.ToTensor()
        ])
        self.to_pil = transforms.ToPILImage()

        self.content_test, self.style_test = T(self.content_test), T(self.style_test)
        self.content_test, self.style_test = self.content_test.unsqueeze(0), self.style_test.unsqueeze(0)
        self.content_test, self.style_test = self.content_test.to(self.dev), self.style_test.to(self.dev)

        self.content_test = self.enc(self.content_test)
        self.style_test = self.enc(self.style_test)
        self.realigned_content_test = AdaIN_realign(self.style_test[-1], self.content_test[-1])


    # save weights every few iterations
    def save_checkpoint(self, save_path = "weights/dec.pth"):
        torch.save(self.dec.state_dict(), save_path)

    # show test outputs every few iterations
    def show_output(self):
        test_pred_img = self.dec(self.realigned_content_test)[0].clip(0, 1)
        test_pred_img = self.to_pil(test_pred_img)
        test_pred_img.save("output.jpg")


    def train(self):
        opt = optim.Adam(self.dec.parameters(), lr=self.lr)

        for e in range(1, self.epochs + 1):
            loop = tqdm(
                enumerate(zip(self.dataset.content_loader, self.dataset.style_loader)), 
                total=len(self.dataset.style_loader), 
                leave=False, 
                position=0
            )
            loop.set_description(f"Epoch - {e} | ")
            for i, ((content, _), (style, _)) in loop:
                content, style = content.to(self.dev), style.to(self.dev)
                opt.zero_grad()
                self.dec.train()

                content_features = self.enc(content)
                style_features = self.enc(style)
                realigned_content = AdaIN_realign(style_features[-1], content_features[-1])
                pred_img = self.dec(realigned_content).clip(0,1)

                loss = self.adain_loss.calculate_loss(style_features, pred_img, realigned_content)
                loss.backward()
                opt.step()

                loop.set_postfix(loss=loss.item())

                if i % 10 == 0:
                    self.save_checkpoint()
                    if self.show_test_output:
                        self.dec.eval()
                        self.show_output()

Полученные результаты

Если вы хотите сэкономить время на тренировке, предварительно тренированные веса можно найти здесь. Вот пример нескольких стилизованных изображений.