Предыдущий ‹‹ Введение в PyTorch (5/7)

PyTorch имеет два примитива для работы с данными: torch.utils.data.DataLoader и torch.utils.data.Dataset. Dataset хранит образцы и соответствующие им метки, а DataLoader оборачивает набор данных в итерацию.

%matplotlib inline
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt

PyTorch предлагает библиотеки для конкретной предметной области, такие как TorchText, TorchVision и TorchAudio, каждая из которых включает наборы данных. В этом уроке мы будем использовать набор данных TorchVision.

Модуль torchvision.datasets содержит объекты Dataset для многих реальных данных машинного зрения, таких как CIFAR и COCO. В этом уроке мы будем использовать набор данных FashionMNIST. Каждый набор данных TorchVision включает два аргумента: transform и target_transform для изменения образцов и меток соответственно.

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

Мы передаем Набор данных в качестве аргумента DataLoader. Это оборачивает итерацию нашего набора данных и поддерживает автоматическую пакетную обработку, выборку, перетасовку и многопроцессную загрузку данных. Здесь мы определяем размер пакета 64, т. е. каждый элемент в итерируемом загрузчике данных будет возвращать пакет из 64 объектов и меток.

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break
    
# Display sample data
figure = plt.figure(figsize=(10, 8))
cols, rows = 5, 5
for i in range(1, cols * rows + 1):
    idx = torch.randint(len(test_data), size=(1,)).item()
    img, label = test_data[idx]
    figure.add_subplot(rows, cols, i)
    plt.title(label)
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
Shape of X [N, C, H, W]:  torch.Size([64, 1, 28, 28])
Shape of y:  torch.Size([64]) torch.int64

Создание моделей

Чтобы определить нейронную сеть в PyTorch, мы создаем класс, наследуемый от nn.Module. Мы определяем уровни сети в функции __init__ и указываем, как данные будут проходить через сеть, в функции forward. Для ускорения работы нейронной сети мы перемещаем ее на графический процессор, если он доступен.

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)
Using cuda device
NeuralNetwork(
  (flatten): Flatten()
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
    (5): ReLU()
  )
)

Оптимизируйте параметры модели

Для обучения модели нам нужна функция потерь и оптимизатор. Мы будем использовать nn.CrossEntropyLoss для потерь и Стохастический градиентный спуск для оптимизации.

loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-3
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

В одном цикле обучения модель делает прогнозы по набору обучающих данных (подаваемых в него пакетами) и осуществляет обратное распространение ошибки прогноза для корректировки параметров модели.

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)

Мы также можем проверить производительность модели на тестовом наборе данных, чтобы убедиться, что она обучается.

def test(dataloader, model):
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

Процесс обучения проводится в несколько итераций (эпох). В течение каждой эпохи модель изучает параметры, чтобы делать более точные прогнозы. Мы печатаем точность и потери модели в каждую эпоху; нам бы хотелось, чтобы точность увеличивалась, а потери уменьшались с каждой эпохой.

epochs = 15
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model)
print("Done!")
Epoch 1
-------------------------------
loss: 2.295450  [    0/60000]
loss: 2.293073  [ 6400/60000]
loss: 2.278504  [12800/60000]
loss: 2.282501  [19200/60000]
loss: 2.273211  [25600/60000]
loss: 2.258452  [32000/60000]
loss: 2.248237  [38400/60000]
loss: 2.228594  [44800/60000]
loss: 2.240276  [51200/60000]
loss: 2.221318  [57600/60000]
Test Error: 
 Accuracy: 51.8%, Avg loss: 0.034745 

Epoch 2
-------------------------------
loss: 2.212354  [    0/60000]
loss: 2.207739  [ 6400/60000]
loss: 2.160400  [12800/60000]
loss: 2.176181  [19200/60000]
loss: 2.168270  [25600/60000]
loss: 2.146453  [32000/60000]
loss: 2.119934  [38400/60000]
loss: 2.083791  [44800/60000]
loss: 2.126453  [51200/60000]
loss: 2.077550  [57600/60000]
Test Error: 
 Accuracy: 53.2%, Avg loss: 0.032452 

Epoch 3
-------------------------------
loss: 2.082280  [    0/60000]
loss: 2.068733  [ 6400/60000]
loss: 1.965958  [12800/60000]
loss: 1.997126  [19200/60000]
loss: 2.002057  [25600/60000]
loss: 1.967370  [32000/60000]
loss: 1.910595  [38400/60000]
loss: 1.849006  [44800/60000]
loss: 1.944741  [51200/60000]
loss: 1.861265  [57600/60000]
Test Error: 
 Accuracy: 51.6%, Avg loss: 0.028937 

Epoch 4
-------------------------------
loss: 1.872628  [    0/60000]
loss: 1.844543  [ 6400/60000]
loss: 1.710179  [12800/60000]
loss: 1.779804  [19200/60000]
loss: 1.737971  [25600/60000]
loss: 1.746953  [32000/60000]
loss: 1.624768  [38400/60000]
loss: 1.575720  [44800/60000]
loss: 1.742827  [51200/60000]
loss: 1.653375  [57600/60000]
Test Error: 
 Accuracy: 58.4%, Avg loss: 0.025570 

Epoch 5
-------------------------------
loss: 1.662315  [    0/60000]
loss: 1.636235  [ 6400/60000]
loss: 1.508407  [12800/60000]
loss: 1.606842  [19200/60000]
loss: 1.560728  [25600/60000]
loss: 1.606024  [32000/60000]
loss: 1.426900  [38400/60000]
loss: 1.406240  [44800/60000]
loss: 1.619918  [51200/60000]
loss: 1.521326  [57600/60000]
Test Error: 
 Accuracy: 61.2%, Avg loss: 0.023459 

Epoch 6
-------------------------------
loss: 1.527535  [    0/60000]
loss: 1.511209  [ 6400/60000]
loss: 1.377129  [12800/60000]
loss: 1.494889  [19200/60000]
loss: 1.457990  [25600/60000]
loss: 1.502333  [32000/60000]
loss: 1.291539  [38400/60000]
loss: 1.285098  [44800/60000]
loss: 1.484891  [51200/60000]
loss: 1.414015  [57600/60000]
Test Error: 
 Accuracy: 62.2%, Avg loss: 0.021480 

Epoch 7
-------------------------------
loss: 1.376779  [    0/60000]
loss: 1.384830  [ 6400/60000]
loss: 1.230116  [12800/60000]
loss: 1.382574  [19200/60000]
loss: 1.255630  [25600/60000]
loss: 1.396211  [32000/60000]
loss: 1.157718  [38400/60000]
loss: 1.186382  [44800/60000]
loss: 1.340606  [51200/60000]
loss: 1.321607  [57600/60000]
Test Error: 
 Accuracy: 62.8%, Avg loss: 0.019737 

Epoch 8
-------------------------------
loss: 1.243344  [    0/60000]
loss: 1.279124  [ 6400/60000]
loss: 1.121769  [12800/60000]
loss: 1.293069  [19200/60000]
loss: 1.128232  [25600/60000]
loss: 1.315465  [32000/60000]
loss: 1.069528  [38400/60000]
loss: 1.123324  [44800/60000]
loss: 1.243827  [51200/60000]
loss: 1.255190  [57600/60000]
Test Error: 
 Accuracy: 63.4%, Avg loss: 0.018518 

Epoch 9
-------------------------------
loss: 1.154148  [    0/60000]
loss: 1.205280  [ 6400/60000]
loss: 1.046463  [12800/60000]
loss: 1.229866  [19200/60000]
loss: 1.048813  [25600/60000]
loss: 1.254785  [32000/60000]
loss: 1.010614  [38400/60000]
loss: 1.077114  [44800/60000]
loss: 1.176766  [51200/60000]
loss: 1.206567  [57600/60000]
Test Error: 
 Accuracy: 64.3%, Avg loss: 0.017640 

Epoch 10
-------------------------------
loss: 1.090360  [    0/60000]
loss: 1.149150  [ 6400/60000]
loss: 0.990786  [12800/60000]
loss: 1.183704  [19200/60000]
loss: 0.997114  [25600/60000]
loss: 1.207199  [32000/60000]
loss: 0.967512  [38400/60000]
loss: 1.043431  [44800/60000]
loss: 1.127000  [51200/60000]
loss: 1.169639  [57600/60000]
Test Error: 
 Accuracy: 65.3%, Avg loss: 0.016974 

Epoch 11
-------------------------------
loss: 1.041194  [    0/60000]
loss: 1.104409  [ 6400/60000]
loss: 0.947670  [12800/60000]
loss: 1.149421  [19200/60000]
loss: 0.960403  [25600/60000]
loss: 1.169899  [32000/60000]
loss: 0.935149  [38400/60000]
loss: 1.018250  [44800/60000]
loss: 1.088222  [51200/60000]
loss: 1.139813  [57600/60000]
Test Error: 
 Accuracy: 66.2%, Avg loss: 0.016446 

Epoch 12
-------------------------------
loss: 1.000646  [    0/60000]
loss: 1.067356  [ 6400/60000]
loss: 0.912046  [12800/60000]
loss: 1.122742  [19200/60000]
loss: 0.932827  [25600/60000]
loss: 1.138785  [32000/60000]
loss: 0.910242  [38400/60000]
loss: 0.999010  [44800/60000]
loss: 1.056596  [51200/60000]
loss: 1.114582  [57600/60000]
Test Error: 
 Accuracy: 67.5%, Avg loss: 0.016011 

Epoch 13
-------------------------------
loss: 0.966393  [    0/60000]
loss: 1.035691  [ 6400/60000]
loss: 0.881672  [12800/60000]
loss: 1.100845  [19200/60000]
loss: 0.910265  [25600/60000]
loss: 1.112597  [32000/60000]
loss: 0.889558  [38400/60000]
loss: 0.982751  [44800/60000]
loss: 1.029199  [51200/60000]
loss: 1.092738  [57600/60000]
Test Error: 
 Accuracy: 68.5%, Avg loss: 0.015636 

Epoch 14
-------------------------------
loss: 0.936334  [    0/60000]
loss: 1.007734  [ 6400/60000]
loss: 0.854663  [12800/60000]
loss: 1.081601  [19200/60000]
loss: 0.890581  [25600/60000]
loss: 1.089641  [32000/60000]
loss: 0.872057  [38400/60000]
loss: 0.969192  [44800/60000]
loss: 1.005193  [51200/60000]
loss: 1.073098  [57600/60000]
Test Error: 
 Accuracy: 69.4%, Avg loss: 0.015304 

Epoch 15
-------------------------------
loss: 0.908971  [    0/60000]
loss: 0.982067  [ 6400/60000]
loss: 0.830095  [12800/60000]
loss: 1.064921  [19200/60000]
loss: 0.874204  [25600/60000]
loss: 1.069008  [32000/60000]
loss: 0.856447  [38400/60000]
loss: 0.957340  [44800/60000]
loss: 0.983547  [51200/60000]
loss: 1.055251  [57600/60000]
Test Error: 
 Accuracy: 70.3%, Avg loss: 0.015001 

Done!

Точность изначально будет не очень хорошей (это нормально!). Попробуйте запустить цикл на большее количество эпох или увеличить значение learning_rate. Также может случиться так, что выбранная нами конфигурация модели может оказаться неоптимальной для такого рода проблем.

Сохранить модели

Распространенный способ сохранить модель — сериализовать внутренний словарь состояний (содержащий параметры модели).

torch.save(model.state_dict(), "data/model.pth")
print("Saved PyTorch Model State to model.pth")

Загрузить модели

Процесс загрузки модели включает в себя воссоздание структуры модели и загрузку в нее словаря состояний.

model = NeuralNetwork()
model.load_state_dict(torch.load("data/model.pth"))

Теперь эту модель можно использовать для прогнозирования.

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')
Predicted: "Ankle boot", Actual: "Ankle boot"

Поздравляем! Вы завершили руководство для начинающих по PyTorch! Мы надеемся, что это руководство помогло вам начать глубокое изучение PyTorch.