У меня есть два обучающих скрипта на Python. Один использует API Pytorch для обучения классификации, а другой использует Fast-ai. Использование Fast-ai дает гораздо лучшие результаты.
Результаты обучения следующие.
Fastai
epoch train_loss valid_loss accuracy time
0 0.205338 2.318084 0.466482 23:02
1 0.182328 0.041315 0.993334 22:51
2 0.112462 0.064061 0.988932 22:47
3 0.052034 0.044727 0.986920 22:45
4 0.178388 0.081247 0.980883 22:45
5 0.009298 0.011817 0.996730 22:44
6 0.004008 0.003211 0.999748 22:43
Using Pytorch
Epoch [1/10], train_loss : 31.0000 , val_loss : 1.6594, accuracy: 0.3568
Epoch [2/10], train_loss : 7.0000 , val_loss : 1.7065, accuracy: 0.3723
Epoch [3/10], train_loss : 4.0000 , val_loss : 1.6878, accuracy: 0.3889
Epoch [4/10], train_loss : 3.0000 , val_loss : 1.7054, accuracy: 0.4066
Epoch [5/10], train_loss : 2.0000 , val_loss : 1.7154, accuracy: 0.4106
Epoch [6/10], train_loss : 2.0000 , val_loss : 1.7232, accuracy: 0.4144
Epoch [7/10], train_loss : 2.0000 , val_loss : 1.7125, accuracy: 0.4295
Epoch [8/10], train_loss : 1.0000 , val_loss : 1.7372, accuracy: 0.4343
Epoch [9/10], train_loss : 1.0000 , val_loss : 1.6871, accuracy: 0.4441
Epoch [10/10], train_loss : 1.0000 , val_loss : 1.7384, accuracy: 0.4552
Использование Pytorch не сходится. Я использовал ту же сеть (Wideresnet22), и обе обучались с нуля без предварительно обученной модели.
Сеть находится здесь.
Обучение использованию Pytorch находится здесь.
Использование Fastai происходит следующим образом.
from fastai.basic_data import DataBunch
from fastai.train import Learner
from fastai.metrics import accuracy
#DataBunch takes data and internall create data loader
data = DataBunch.create(train_ds, valid_ds, bs=batch_size, path='./data')
#Learner uses Adam as default for learning
learner = Learner(data, model, loss_func=F.cross_entropy, metrics=[accuracy])
#Gradient is clipped
learner.clip = 0.1
#learner finds its learning rate
learner.lr_find()
learner.recorder.plot()
#Weight decay helps to lower down weight. Learn in https://towardsdatascience.com/
learner.fit_one_cycle(5, 5e-3, wd=1e-4)
Что может быть не так в моем алгоритме обучения с использованием Pytorch?