Попробуйте модель классификации изображений с несбалансированным набором данных и улучшите ее точность с помощью методов увеличения данных.

Мы создадим модель классификации изображений из минимального и несбалансированного набора данных, а затем воспользуемся методами увеличения данных, чтобы сбалансировать и сравнить результаты.

Набор данных

Наш набор данных содержит 200 изображений цветов и 20 изображений птиц в соотношении 1:10. Чтобы сформировать этот набор данных, мы использовали методы загрузки URL-адресов изображений через поисковую систему Google®, как описано шаг за шагом в последней статье.

Загрузка и проверка изображений

Как только у нас будет список URL-адресов для каждой категории в наших CSV-файлах, мы запустим код для загрузки фотографий и создания нашего набора данных.
Путь «data/data_aug» является для нас базовым каталогом. В этом каталоге мы поместили два CSV-файла категорий; давайте выполним код для создания подкаталогов, куда будут загружены изображения, и проверим их.

%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai.vision import *
classes = ['birds','flowers']
path = Path('data/data_aug')
#creating folders
for folder in classes:
   dest = path/folder
   dest.mkdir(parents=True, exist_ok=True)
#downloading images from urls in csv
file = 'urls_'+'birds''.csv'
dest = path/'birds'
download_images(path/file, dest, max_pics=20)
file = 'urls_'+'flowers''.csv'
dest = path/'flowers'
download_images(path/file, dest, max_pics=200)
#verifying images
for c in classes:
   print(c)
   verify_images(path/c, delete=True, max_size=500)

Создание и визуализация набора данных

После того, как фотографии загружены и проверены в каталогах, соответствующих каждой категории, мы можем создать fast.ai DataFrame, чтобы иметь возможность помещать в него изображения с тегами и начинать визуализацию и работу с ними. Мы оставляем за собой 20% для подтверждения набора.

np.random.seed(7)
data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2, ds_tfms=get_transforms(), size=224, num_workers=4).normalize(imagenet_stats)

И запустите код, чтобы отобразить случайный пакет из 3 строк:

data_gt.show_batch(rows=3, figsize=(7,8))

Модель классификации

В этом упражнении мы будем использовать свёрточную сеть формата resnet34¹.
В «моделях» есть набор предопределенных сетевых архитектур, включающих различные структуры и сложности².

learn_gt = cnn_learner(data_gt, models.resnet34, metrics=error_rate)
learn_gt.fit_one_cycle(4)

learn_gt.save('gt_stage-1')
learn_gt.load('gt_stage-1')
learn_gt.unfreeze()
learn_gt.lr_find()

learn_gt.fit_one_cycle(2, max_lr=slice(1e-5,1e-2))
learn_gt.save('gt_stage-2')
learn_gt.load('gt_stage-2')

Результаты в матрице путаницы

interp = ClassificationInterpretation.from_learner(learn_gt)
interp.plot_confusion_matrix()

Как мы видим, модель очень неэффективна в прогнозировании класса
Менее представленный класс (птицы), где существует много путаницы с проверочным набором.

Увеличение данных

Мы должны создать 180 новых изображений в категории птиц. Для этого мы пройдемся по каждому из реальных изображений, создав десять дополнительных изображений для каждого, используя метод apply_tfms fast.ai из fast.ai.

path = Path('data/data_aug')
path_hr = path/'birds'
il = ImageList.from_folder(path_hr)
tfms = get_transforms(max_rotate=25)
def data_aug_one(ex_img,prox,qnt):
   for lop in range(0,qnt):
      image_name = str(prox).zfill(8) +'.jpg'
      dest = path_hr/image_name
      prox = prox + 1
      new_img = open_image(ex_img)
      new_img_fin = new_img.apply_tfms(tfms[0], new_img, xtra={tfms[1][0].tfm: {"size": 224}}, size=224)
      new_img_fin.save(dest)
prox = 20
qnt = 10
for imagen in il.items:
   data_aug_one(imagen,prox,qnt)
   prox = prox + qnt

Если мы визуализируем любое из исходных изображений и их десять новых изображений, мы обнаружим такие вещи:

Подробно о работе функции преобразования apply_tfms и всех ее возможностях можно прочитать здесь.

Та же модель со сбалансированными данными

Мы создали новую модель со сбалансированными наборами данных

np.random.seed(7)
tfms = get_transforms()
data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2, ds_tfms=tfms, size=224, num_workers=4).normalize(imagenet_stats)
learn = cnn_learner(data, models.resnet34, metrics=error_rate)
learn.fit_one_cycle(4)

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

Хорошо, теперь давайте представим новый набор изображений

Мы взяли двадцать новых изображений, по десять из каждой категории, и видим, что точность, хотя и значительно улучшенная по сравнению с исходной версией, по-прежнему вызывает некоторые проблемы при категоризации. Из 20 новых изображений модель предсказала 18 правильных и два неверных:

path = Path('data/data_aug_test')
defaults.device = torch.device('cpu')
img = open_image(path/'bird01.jpg')
img

pred_class,pred_idx,outputs = learn.predict(img)
pred_class

Категория птицы

for test in range(1,11):
   image_name = 'bird'+str(test).zfill(2)+'.jpg'
   img = open_image(path/image_name)
   pred_class,pred_idx,outputs = learn.predict(img)
   print ('For image ' + image_name + ' predicted class: ');
   print (pred_class)
   image_name = 'flower'+str(test).zfill(2)+'.jpg'
   img = open_image(path/image_name)
   pred_class,pred_idx,outputs = learn.predict(img)
   print ('For image ' + image_name + ' predicted class: ');
   print (pred_class)

Изображение flower07.jpg предсказано как птица, а изображение bird08.jpg предсказано как цветок.

Здесь изображения путаницы:

Резюме

Как мы могли видеть, разница в обучении одной и той же модели на сбалансированном наборе данных и несбалансированном фундаментальна; однако, если набор данных минимален, возможно, достигнутой нами точности недостаточно.

В предыдущей статье мы создали очень похожую модель из 200 изображений каждой категории. Когда мы представили ему те же 20 новых тестовых изображений, у него была такая же путаница на тех же фотографиях.

Мы использовали метод apply_tfms из fast.ai для преобразования нескольких исходных изображений птиц, создав набор данных в 10 раз больше.

Источники и ссылки

[1] — https://towardsdatascience.com/an-overview-of-resnet-and-its-variants-5281e2f56035

[2] — https://medium.com/@14prakash/understanding-and-implementing-architectures-of-resnet-and-resnext-for-state-of-the-art-image-cf51669e1624



В упомянутом курсе fast.ai эта техника вдохновлена: Франциско Ингамом и Джереми Ховардом / [Адриан Роузброк]

https://course.fast.ai, урок 2, Jupyter Notebook: урок2-скачать