Разбираем реализацию HuggingFace ViT

Vision Transformer (ViT) стал важной вехой в развитии компьютерного зрения. ViT бросает вызов общепринятому мнению о том, что изображения лучше всего обрабатываются с помощью сверточных слоев, доказывая, что механизмы внимания, основанные на последовательностях, могут эффективно улавливать сложные закономерности, контекст и семантику, присутствующие в изображениях. Разбивая изображения на управляемые фрагменты и используя самообладание, ViT фиксирует как локальные, так и глобальные взаимосвязи, что позволяет ему преуспеть в различных задачах машинного зрения, от классификации изображений до обнаружения объектов и за его пределами. В этой статье мы разберем, как работает ViT для классификации.

Введение

Основная идея ViT — рассматривать изображение как последовательность фрагментов фиксированного размера, которые затем сглаживаются и преобразуются в одномерные векторы. Эти патчи впоследствии обрабатываются кодировщиком-трансформером, который позволяет модели фиксировать глобальный контекст и зависимости по всему изображению. Разделяя изображение на фрагменты, ViT эффективно снижает вычислительную сложность обработки больших изображений, сохраняя при этом способность моделировать сложные пространственные взаимодействия.

Прежде всего, мы импортируем модель ViT для классификации из библиотеки трансформеров обнимающих лиц:

from transformers import ViTForImageClassification
import torch
import numpy as np

model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

patch16–224 указывает, что модель принимает изображения размером 224x224, а ширина и высота каждого патча составляет 16 пикселей.

Вот как выглядит архитектура модели:

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): PatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0): ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key)…