В предыдущих статьях мы видели различные методы переноса нейронного стиля, сначала с определенным стилем и изображениями контента, затем с определенным стилем и произвольными изображениями контента. В этой статье мы будем реализовывать исследовательскую работу «Передача произвольного стиля в режиме реального времени с адаптивной нормализацией экземпляра», в которой описывается метод, выполняющий передачу стиля с изображениями произвольного стиля и контента, что означает, что переобучение не требуется для изображений разных стилей. .
Настройка обучения
Основным компонентом этого подхода является слой Adaptive Instance Normalization (AdaIN). Учебную установку можно увидеть на следующей диаграмме из исследовательской работы.
- Мы пропускаем изображения стиля и контента через кодировщик VGG для извлечения функций.
- Мы передаем функции стиля и контента через слой AdaIN, который перенастраивает функции контента, чтобы они имели то же среднее значение и стандартное отклонение, что и функции стиля.
- Затем мы передаем выровненные элементы содержимого обратно через декодер для создания стилизованного изображения.
- Извлеченные функции из кодировщика VGG также используются для вычисления потерь контента.
Набор данных
Этот подход включает в себя набор данных изображений контента, а также набор данных изображений стилей, поскольку нам нужно иметь возможность использовать как контент artbitrary, так и изображения произвольного стиля.
Изображения контента взяты из набора данных MS COCO, который использовался в предыдущих статьях. Набор стилевых изображений состоит из картин, взятых с WikiArt, и их можно найти здесь.
Предпринимаются следующие шаги предварительной обработки.
- Изображения передаются через модель парами: одно изображение содержимого и одно изображение стиля.
- Меньший размер каждого изображения изменяется до 512 с сохранением соотношения сторон.
- Случайная обрезка 256 на 256 пикселей берется как из содержимого, так и из изображений стиля. Это используется для обучения.
Код для этого можно найти ниже и в файле dataset.py в репозитории.
class AdaINDataset: def __init__(self, content_path, style_path, batch_size) -> None: self.T = transforms.Compose([ transforms.Resize(512), transforms.RandomCrop((256, 256), padding=(20, 20)), transforms.ToTensor(), ]) self.content_folder = ImageFolder(content_path, transform=self.T) self.style_folder = ImageFolder(style_path, transform=self.T) self.content_loader = DataLoader(self.content_folder, batch_size, shuffle=True) self.style_loader = DataLoader(self.style_folder, batch_size, shuffle=True)
Кодировщик VGG
При таком подходе к передаче стиля у кодировщика есть две основные функции.
- Извлеките функции для использования в слое AdaIN. Для этого шага необходимы объекты из слоя relu4_1.
- Извлеките признаки для вычисления потерь. Для этого потребуются признаки из слоев relu1_1, relu2_1, relu3_1 и relu4_1.
Поэтому мы создадим кодировщик для извлечения признаков из этих 4 слоев. Следующий код можно найти в файле model.py в репозитории.
import torch from torch import nn # vgg encoder model class VGGEncoder(nn.Module): def __init__(self, weight_path=None) -> None: super(VGGEncoder, self).__init__() # layers to extract features from self.feature_layers = [3, 10, 17, 30] # creating model and adding first layer self.model = nn.ModuleList() self.model.append(nn.Conv2d(in_channels=3, out_channels=3, kernel_size=(1, 1))) self.model.append(nn.ReflectionPad2d((1, 1, 1, 1))) # parameters for remaining layers # (in_channels, out_channels, kernel_size) params = [ (3, 64, (3, 3)), (64, 64, (3, 3)), (64, 128, (3, 3)), (128, 128, (3, 3)), (128, 256, (3, 3)), (256, 256, (3, 3)), (256, 256, (3, 3)), (256, 256, (3, 3)), (256, 512, (3, 3)), (512, 512, (3, 3)), (512, 512, (3, 3)), (512, 512, (3, 3)), (512, 512, (3, 3)), (512, 512, (3, 3)), (512, 512, (3, 3)), (512, 512, (3, 3)) ] # adding layers to model # also adding a maxpool layer whenever number of out_channels changes for i, param in enumerate(params[:-1]): self.model.append(nn.Conv2d(in_channels=param[0], out_channels=param[1], kernel_size=param[2])) self.model.append(nn.ReLU()) if params[i + 1][0] != params[i + 1][1]: self.model.append(nn.MaxPool2d(kernel_size=(2,2), stride=(2,2), padding=(0,0), ceil_mode=True)) self.model.append(nn.ReflectionPad2d((1, 1, 1, 1))) # inserting extra maxpool layer based on vgg architecture self.model.insert(40, nn.MaxPool2d(kernel_size=(2,2), stride=(2,2), padding=(0,0), ceil_mode=True)) # adding last layer self.model.append(nn.Conv2d(in_channels=params[-1][0], out_channels=params[-1][1], kernel_size=params[-1][2])) self.model.append(nn.ReLU()) # loading pretrained model weights if path is provided if weight_path: self.model.load_state_dict(torch.load(weight_path, map_location="cpu")) # encoder is fully pretrained, so no weights need to be adjusted # gradients need not be accumulated for param in self.model.parameters(): param.requires_grad = False def forward(self, x): # list to store activations from layers in self.feature_layers activations = [] for i, layer in enumerate(self.model[:31]): x = layer(x) if i in self.feature_layers: activations.append(x) return activations
Декодер — это просто перевернутая версия кодера, заменяющая объединяющие слои слоями повышающей дискретизации.
# vgg decoder model class VGGDecoder(nn.Module): def __init__(self, encoder, weight_path=None) -> None: super(VGGDecoder, self).__init__() self.decoder = nn.ModuleList() # reflection of encoder, so we iterate in reverse order for layer in encoder.model[:31][::-1][:-2]: # switch location of reflection pad and relu if isinstance(layer, nn.ReflectionPad2d): self.decoder.append(nn.ReLU()) elif isinstance(layer, nn.ReLU): self.decoder.append(nn.ReflectionPad2d((1, 1, 1, 1))) elif isinstance(layer, nn.MaxPool2d): self.decoder.append(nn.Upsample(scale_factor=2, mode='nearest')) elif isinstance(layer, nn.Conv2d): layer = nn.Conv2d( in_channels=layer.out_channels, out_channels=layer.in_channels, kernel_size=layer.kernel_size ) self.decoder.append(layer) if weight_path: self.decoder.load_state_dict(torch.load(weight_path, map_location="cpu")) def forward(self, x): for layer in self.decoder: x = layer(x) return x
Адаптивная нормализация экземпляра (AdaIN)
Теперь о слое AdaIN между кодером и декодером, который составляет основу этого метода передачи стиля.
Формулу для AdaIN можно увидеть ниже, где x представляет функции контента, а y представляет функции стиля;
Шаги следующие;
- Вычислите среднее значение по каналам и стандартное отклонение как для содержания, так и для стилевых характеристик.
- Нормализуйте функции контента, чтобы иметь 0 среднего и 1 стандартное отклонение.
- Перестроены нормализованные функции контента, чтобы они имели то же среднее значение и стандартное отклонение, что и функции стиля.
Это выглядит следующим образом в коде;
def AdaIN_realign(style, content): # flatten images while retaining batchs and channels, compute mean and std B, C = content.shape[0], content.shape[1] content_mean = content.view(B, C, -1).mean(dim=2) content_std = content.view(B, C, -1).std(dim=2) style_mean = style.view(B, C, -1).mean(dim=2) style_std = style.view(B, C, -1).std(dim=2) # reshape mean and std to perform normalization and realignment content_mean, content_std = content_mean.view(B, C, 1, 1), content_std.view(B, C, 1, 1) style_mean, style_std = style_mean.view(B, C, 1, 1), style_std.view(B, C, 1, 1) content_mean, content_std = content_mean.expand(content.size()), content_std.expand(content.size()) style_mean, style_std = style_mean.expand(style.size()), style_std.expand(style.size()) # normalize content features. Small constant added to avoid zero division. normalized = (content - content_mean) / (content_std + 0.00001) # realign normalized content features with style mean and std realigned = (normalized * style_std) + style_mean return realigned
Функции потерь
Как и в случае с другими подходами к передаче нейронного стиля, функция потерь состоит из двух частей; потеря содержания и потеря стиля.
Потеря контента
Потеря контента — это просто среднеквадратичная потеря ошибки между элементами контента с повторным выравниванием (выходные данные слоя AdaIN перед прохождением через декодер) и извлеченными функциями из стилизованного изображения, как показано на диаграмме обучения. установка выше.
Потеря стиля
Следующие шаги описывают расчет потери стиля.
- Изображение стиля и сгенерированное изображение передаются через кодировщик для извлечения признаков.
- Для каждого рассчитывается среднее значение по каналу и стандартное отклонение.
- Среднеквадратическая ошибка между средним значением и стандартным отклонением берется и суммируется. Это потеря стиля.
Код для потери содержимого и стиля можно увидеть ниже. Это в файле loss.py в репозитории.
from torch import nn class AdaINLoss: def __init__(self, enc, style_weight) -> None: self.mse = nn.MSELoss() self.loss_network = enc self.style_weight = style_weight # a simple mse for the content loss def content_loss(self, realigned_content, pred_feature_last): return self.mse(realigned_content, pred_feature_last) # mse of the channel-wise mean and std for style loss def style_loss(self, style_features, pred_features): style_loss = 0 for s_ft, p_ft in zip(style_features, pred_features): B, C = s_ft.shape[0], s_ft.shape[1] s_ft, p_ft = s_ft.view(B, C, -1), p_ft.view(B, C, -1) mean_loss = self.mse(s_ft.mean(dim=2), p_ft.mean(dim=2)) std_loss = self.mse(s_ft.std(dim=2), p_ft.std(dim=2)) style_loss += (mean_loss + std_loss) return style_loss def calculate_loss(self, style_features, pred_img, realigned_content): pred_features = self.loss_network(pred_img) content_loss = self.content_loss(realigned_content, pred_features[-1]) style_loss = self.style_loss(style_features, pred_features) return content_loss + (style_loss * self.style_weight)
Тренировочный цикл
Цикл обучения аналогичен предыдущим статьям о переносе нейронного стиля. Код можно увидеть ниже и в файле train.py в репозитории.
from model import VGGEncoder, VGGDecoder, AdaIN_realign from loss import AdaINLoss from dataset import AdaINDataset from torch import optim import torch from torchvision import transforms from tqdm import tqdm from PIL import Image class TrainAdaIN: def __init__(self, epochs, style_weight, lr, batch_size, content_path, style_path, test_content, test_style, dev, enc_weight_path=None, dec_weight_path=None, show_test_output=False) -> None: self.dev = dev self.style_weight = style_weight self.batch_size = batch_size self.lr = lr self.epochs = epochs self.show_test_output = show_test_output self.load_weights(enc_weight_path, dec_weight_path) self.adain_loss = AdaINLoss(self.enc, self.style_weight) self.dataset = AdaINDataset(content_path, style_path, batch_size) self.test_images_init(test_content, test_style) # load pretrained weights def load_weights(self, enc_weight_path, dec_weight_path): if enc_weight_path: self.enc = VGGEncoder(weight_path=enc_weight_path).to(self.dev) else: self.enc = VGGEncoder().to(self.dev) if dec_weight_path: self.dec = VGGDecoder(self.enc, weight_path=dec_weight_path).to(self.dev) else: self.dec = VGGDecoder(encoder=self.enc).to(self.dev) # initialize test images to show training progree def test_images_init(self, content_path, style_path): self.content_test = Image.open(content_path) self.style_test = Image.open(style_path).resize(self.content_test.size) T = transforms.Compose([ transforms.Resize(512), transforms.ToTensor() ]) self.to_pil = transforms.ToPILImage() self.content_test, self.style_test = T(self.content_test), T(self.style_test) self.content_test, self.style_test = self.content_test.unsqueeze(0), self.style_test.unsqueeze(0) self.content_test, self.style_test = self.content_test.to(self.dev), self.style_test.to(self.dev) self.content_test = self.enc(self.content_test) self.style_test = self.enc(self.style_test) self.realigned_content_test = AdaIN_realign(self.style_test[-1], self.content_test[-1]) # save weights every few iterations def save_checkpoint(self, save_path = "weights/dec.pth"): torch.save(self.dec.state_dict(), save_path) # show test outputs every few iterations def show_output(self): test_pred_img = self.dec(self.realigned_content_test)[0].clip(0, 1) test_pred_img = self.to_pil(test_pred_img) test_pred_img.save("output.jpg") def train(self): opt = optim.Adam(self.dec.parameters(), lr=self.lr) for e in range(1, self.epochs + 1): loop = tqdm( enumerate(zip(self.dataset.content_loader, self.dataset.style_loader)), total=len(self.dataset.style_loader), leave=False, position=0 ) loop.set_description(f"Epoch - {e} | ") for i, ((content, _), (style, _)) in loop: content, style = content.to(self.dev), style.to(self.dev) opt.zero_grad() self.dec.train() content_features = self.enc(content) style_features = self.enc(style) realigned_content = AdaIN_realign(style_features[-1], content_features[-1]) pred_img = self.dec(realigned_content).clip(0,1) loss = self.adain_loss.calculate_loss(style_features, pred_img, realigned_content) loss.backward() opt.step() loop.set_postfix(loss=loss.item()) if i % 10 == 0: self.save_checkpoint() if self.show_test_output: self.dec.eval() self.show_output()
Полученные результаты
Если вы хотите сэкономить время на тренировке, предварительно тренированные веса можно найти здесь. Вот пример нескольких стилизованных изображений.