Предыдущий ‹‹ Введение в PyTorch (5/7)
PyTorch имеет два примитива для работы с данными: torch.utils.data.DataLoader и torch.utils.data.Dataset. Dataset хранит образцы и соответствующие им метки, а DataLoader оборачивает набор данных в итерацию.
%matplotlib inline import torch from torch import nn from torch.utils.data import DataLoader from torchvision import datasets from torchvision.transforms import ToTensor, Lambda, Compose import matplotlib.pyplot as plt
PyTorch предлагает библиотеки для конкретной предметной области, такие как TorchText, TorchVision и TorchAudio, каждая из которых включает наборы данных. В этом уроке мы будем использовать набор данных TorchVision.
Модуль torchvision.datasets содержит объекты Dataset для многих реальных данных машинного зрения, таких как CIFAR и COCO. В этом уроке мы будем использовать набор данных FashionMNIST. Каждый набор данных TorchVision включает два аргумента: transform и target_transform для изменения образцов и меток соответственно.
# Download training data from open datasets. training_data = datasets.FashionMNIST( root="data", train=True, download=True, transform=ToTensor(), ) # Download test data from open datasets. test_data = datasets.FashionMNIST( root="data", train=False, download=True, transform=ToTensor(), )
Мы передаем Набор данных в качестве аргумента DataLoader. Это оборачивает итерацию нашего набора данных и поддерживает автоматическую пакетную обработку, выборку, перетасовку и многопроцессную загрузку данных. Здесь мы определяем размер пакета 64, т. е. каждый элемент в итерируемом загрузчике данных будет возвращать пакет из 64 объектов и меток.
batch_size = 64 # Create data loaders. train_dataloader = DataLoader(training_data, batch_size=batch_size) test_dataloader = DataLoader(test_data, batch_size=batch_size) for X, y in test_dataloader: print("Shape of X [N, C, H, W]: ", X.shape) print("Shape of y: ", y.shape, y.dtype) break # Display sample data figure = plt.figure(figsize=(10, 8)) cols, rows = 5, 5 for i in range(1, cols * rows + 1): idx = torch.randint(len(test_data), size=(1,)).item() img, label = test_data[idx] figure.add_subplot(rows, cols, i) plt.title(label) plt.axis("off") plt.imshow(img.squeeze(), cmap="gray") plt.show() Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28]) Shape of y: torch.Size([64]) torch.int64
Создание моделей
Чтобы определить нейронную сеть в PyTorch, мы создаем класс, наследуемый от nn.Module. Мы определяем уровни сети в функции __init__ и указываем, как данные будут проходить через сеть, в функции forward. Для ускорения работы нейронной сети мы перемещаем ее на графический процессор, если он доступен.
# Get cpu or gpu device for training. device = "cuda" if torch.cuda.is_available() else "cpu" print("Using {} device".format(device)) # Define model class NeuralNetwork(nn.Module): def __init__(self): super(NeuralNetwork, self).__init__() self.flatten = nn.Flatten() self.linear_relu_stack = nn.Sequential( nn.Linear(28*28, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10), nn.ReLU() ) def forward(self, x): x = self.flatten(x) logits = self.linear_relu_stack(x) return logits model = NeuralNetwork().to(device) print(model) Using cuda device NeuralNetwork( (flatten): Flatten() (linear_relu_stack): Sequential( (0): Linear(in_features=784, out_features=512, bias=True) (1): ReLU() (2): Linear(in_features=512, out_features=512, bias=True) (3): ReLU() (4): Linear(in_features=512, out_features=10, bias=True) (5): ReLU() ) )
Оптимизируйте параметры модели
Для обучения модели нам нужна функция потерь и оптимизатор. Мы будем использовать nn.CrossEntropyLoss для потерь и Стохастический градиентный спуск для оптимизации.
loss_fn = nn.CrossEntropyLoss() learning_rate = 1e-3 optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
В одном цикле обучения модель делает прогнозы по набору обучающих данных (подаваемых в него пакетами) и осуществляет обратное распространение ошибки прогноза для корректировки параметров модели.
def train(dataloader, model, loss_fn, optimizer): size = len(dataloader.dataset) for batch, (X, y) in enumerate(dataloader): X, y = X.to(device), y.to(device) # Compute prediction error pred = model(X) loss = loss_fn(pred, y) # Backpropagation optimizer.zero_grad() loss.backward() optimizer.step() if batch % 100 == 0: loss, current = loss.item(), batch * len(X)
Мы также можем проверить производительность модели на тестовом наборе данных, чтобы убедиться, что она обучается.
def test(dataloader, model): size = len(dataloader.dataset) model.eval() test_loss, correct = 0, 0 with torch.no_grad(): for X, y in dataloader: X, y = X.to(device), y.to(device) pred = model(X) test_loss += loss_fn(pred, y).item() correct += (pred.argmax(1) == y).type(torch.float).sum().item() test_loss /= size correct /= size print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
Процесс обучения проводится в несколько итераций (эпох). В течение каждой эпохи модель изучает параметры, чтобы делать более точные прогнозы. Мы печатаем точность и потери модели в каждую эпоху; нам бы хотелось, чтобы точность увеличивалась, а потери уменьшались с каждой эпохой.
epochs = 15 for t in range(epochs): print(f"Epoch {t+1}\n-------------------------------") train(train_dataloader, model, loss_fn, optimizer) test(test_dataloader, model) print("Done!") Epoch 1 ------------------------------- loss: 2.295450 [ 0/60000] loss: 2.293073 [ 6400/60000] loss: 2.278504 [12800/60000] loss: 2.282501 [19200/60000] loss: 2.273211 [25600/60000] loss: 2.258452 [32000/60000] loss: 2.248237 [38400/60000] loss: 2.228594 [44800/60000] loss: 2.240276 [51200/60000] loss: 2.221318 [57600/60000] Test Error: Accuracy: 51.8%, Avg loss: 0.034745 Epoch 2 ------------------------------- loss: 2.212354 [ 0/60000] loss: 2.207739 [ 6400/60000] loss: 2.160400 [12800/60000] loss: 2.176181 [19200/60000] loss: 2.168270 [25600/60000] loss: 2.146453 [32000/60000] loss: 2.119934 [38400/60000] loss: 2.083791 [44800/60000] loss: 2.126453 [51200/60000] loss: 2.077550 [57600/60000] Test Error: Accuracy: 53.2%, Avg loss: 0.032452 Epoch 3 ------------------------------- loss: 2.082280 [ 0/60000] loss: 2.068733 [ 6400/60000] loss: 1.965958 [12800/60000] loss: 1.997126 [19200/60000] loss: 2.002057 [25600/60000] loss: 1.967370 [32000/60000] loss: 1.910595 [38400/60000] loss: 1.849006 [44800/60000] loss: 1.944741 [51200/60000] loss: 1.861265 [57600/60000] Test Error: Accuracy: 51.6%, Avg loss: 0.028937 Epoch 4 ------------------------------- loss: 1.872628 [ 0/60000] loss: 1.844543 [ 6400/60000] loss: 1.710179 [12800/60000] loss: 1.779804 [19200/60000] loss: 1.737971 [25600/60000] loss: 1.746953 [32000/60000] loss: 1.624768 [38400/60000] loss: 1.575720 [44800/60000] loss: 1.742827 [51200/60000] loss: 1.653375 [57600/60000] Test Error: Accuracy: 58.4%, Avg loss: 0.025570 Epoch 5 ------------------------------- loss: 1.662315 [ 0/60000] loss: 1.636235 [ 6400/60000] loss: 1.508407 [12800/60000] loss: 1.606842 [19200/60000] loss: 1.560728 [25600/60000] loss: 1.606024 [32000/60000] loss: 1.426900 [38400/60000] loss: 1.406240 [44800/60000] loss: 1.619918 [51200/60000] loss: 1.521326 [57600/60000] Test Error: Accuracy: 61.2%, Avg loss: 0.023459 Epoch 6 ------------------------------- loss: 1.527535 [ 0/60000] loss: 1.511209 [ 6400/60000] loss: 1.377129 [12800/60000] loss: 1.494889 [19200/60000] loss: 1.457990 [25600/60000] loss: 1.502333 [32000/60000] loss: 1.291539 [38400/60000] loss: 1.285098 [44800/60000] loss: 1.484891 [51200/60000] loss: 1.414015 [57600/60000] Test Error: Accuracy: 62.2%, Avg loss: 0.021480 Epoch 7 ------------------------------- loss: 1.376779 [ 0/60000] loss: 1.384830 [ 6400/60000] loss: 1.230116 [12800/60000] loss: 1.382574 [19200/60000] loss: 1.255630 [25600/60000] loss: 1.396211 [32000/60000] loss: 1.157718 [38400/60000] loss: 1.186382 [44800/60000] loss: 1.340606 [51200/60000] loss: 1.321607 [57600/60000] Test Error: Accuracy: 62.8%, Avg loss: 0.019737 Epoch 8 ------------------------------- loss: 1.243344 [ 0/60000] loss: 1.279124 [ 6400/60000] loss: 1.121769 [12800/60000] loss: 1.293069 [19200/60000] loss: 1.128232 [25600/60000] loss: 1.315465 [32000/60000] loss: 1.069528 [38400/60000] loss: 1.123324 [44800/60000] loss: 1.243827 [51200/60000] loss: 1.255190 [57600/60000] Test Error: Accuracy: 63.4%, Avg loss: 0.018518 Epoch 9 ------------------------------- loss: 1.154148 [ 0/60000] loss: 1.205280 [ 6400/60000] loss: 1.046463 [12800/60000] loss: 1.229866 [19200/60000] loss: 1.048813 [25600/60000] loss: 1.254785 [32000/60000] loss: 1.010614 [38400/60000] loss: 1.077114 [44800/60000] loss: 1.176766 [51200/60000] loss: 1.206567 [57600/60000] Test Error: Accuracy: 64.3%, Avg loss: 0.017640 Epoch 10 ------------------------------- loss: 1.090360 [ 0/60000] loss: 1.149150 [ 6400/60000] loss: 0.990786 [12800/60000] loss: 1.183704 [19200/60000] loss: 0.997114 [25600/60000] loss: 1.207199 [32000/60000] loss: 0.967512 [38400/60000] loss: 1.043431 [44800/60000] loss: 1.127000 [51200/60000] loss: 1.169639 [57600/60000] Test Error: Accuracy: 65.3%, Avg loss: 0.016974 Epoch 11 ------------------------------- loss: 1.041194 [ 0/60000] loss: 1.104409 [ 6400/60000] loss: 0.947670 [12800/60000] loss: 1.149421 [19200/60000] loss: 0.960403 [25600/60000] loss: 1.169899 [32000/60000] loss: 0.935149 [38400/60000] loss: 1.018250 [44800/60000] loss: 1.088222 [51200/60000] loss: 1.139813 [57600/60000] Test Error: Accuracy: 66.2%, Avg loss: 0.016446 Epoch 12 ------------------------------- loss: 1.000646 [ 0/60000] loss: 1.067356 [ 6400/60000] loss: 0.912046 [12800/60000] loss: 1.122742 [19200/60000] loss: 0.932827 [25600/60000] loss: 1.138785 [32000/60000] loss: 0.910242 [38400/60000] loss: 0.999010 [44800/60000] loss: 1.056596 [51200/60000] loss: 1.114582 [57600/60000] Test Error: Accuracy: 67.5%, Avg loss: 0.016011 Epoch 13 ------------------------------- loss: 0.966393 [ 0/60000] loss: 1.035691 [ 6400/60000] loss: 0.881672 [12800/60000] loss: 1.100845 [19200/60000] loss: 0.910265 [25600/60000] loss: 1.112597 [32000/60000] loss: 0.889558 [38400/60000] loss: 0.982751 [44800/60000] loss: 1.029199 [51200/60000] loss: 1.092738 [57600/60000] Test Error: Accuracy: 68.5%, Avg loss: 0.015636 Epoch 14 ------------------------------- loss: 0.936334 [ 0/60000] loss: 1.007734 [ 6400/60000] loss: 0.854663 [12800/60000] loss: 1.081601 [19200/60000] loss: 0.890581 [25600/60000] loss: 1.089641 [32000/60000] loss: 0.872057 [38400/60000] loss: 0.969192 [44800/60000] loss: 1.005193 [51200/60000] loss: 1.073098 [57600/60000] Test Error: Accuracy: 69.4%, Avg loss: 0.015304 Epoch 15 ------------------------------- loss: 0.908971 [ 0/60000] loss: 0.982067 [ 6400/60000] loss: 0.830095 [12800/60000] loss: 1.064921 [19200/60000] loss: 0.874204 [25600/60000] loss: 1.069008 [32000/60000] loss: 0.856447 [38400/60000] loss: 0.957340 [44800/60000] loss: 0.983547 [51200/60000] loss: 1.055251 [57600/60000] Test Error: Accuracy: 70.3%, Avg loss: 0.015001 Done!
Точность изначально будет не очень хорошей (это нормально!). Попробуйте запустить цикл на большее количество эпох или увеличить значение learning_rate. Также может случиться так, что выбранная нами конфигурация модели может оказаться неоптимальной для такого рода проблем.
Сохранить модели
Распространенный способ сохранить модель — сериализовать внутренний словарь состояний (содержащий параметры модели).
torch.save(model.state_dict(), "data/model.pth") print("Saved PyTorch Model State to model.pth")
Загрузить модели
Процесс загрузки модели включает в себя воссоздание структуры модели и загрузку в нее словаря состояний.
model = NeuralNetwork() model.load_state_dict(torch.load("data/model.pth"))
Теперь эту модель можно использовать для прогнозирования.
classes = [ "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot", ] model.eval() x, y = test_data[0][0], test_data[0][1] with torch.no_grad(): pred = model(x) predicted, actual = classes[pred[0].argmax(0)], classes[y] print(f'Predicted: "{predicted}", Actual: "{actual}"') Predicted: "Ankle boot", Actual: "Ankle boot"
Поздравляем! Вы завершили руководство для начинающих по PyTorch! Мы надеемся, что это руководство помогло вам начать глубокое изучение PyTorch.