Попробуйте модель классификации изображений с несбалансированным набором данных и улучшите ее точность с помощью методов увеличения данных.
Мы создадим модель классификации изображений из минимального и несбалансированного набора данных, а затем воспользуемся методами увеличения данных, чтобы сбалансировать и сравнить результаты.
Набор данных
Наш набор данных содержит 200 изображений цветов и 20 изображений птиц в соотношении 1:10. Чтобы сформировать этот набор данных, мы использовали методы загрузки URL-адресов изображений через поисковую систему Google®, как описано шаг за шагом в последней статье.
Загрузка и проверка изображений
Как только у нас будет список URL-адресов для каждой категории в наших CSV-файлах, мы запустим код для загрузки фотографий и создания нашего набора данных.
Путь «data/data_aug» является для нас базовым каталогом. В этом каталоге мы поместили два CSV-файла категорий; давайте выполним код для создания подкаталогов, куда будут загружены изображения, и проверим их.
%reload_ext autoreload %autoreload 2 %matplotlib inline from fastai.vision import * classes = ['birds','flowers'] path = Path('data/data_aug') #creating folders for folder in classes: dest = path/folder dest.mkdir(parents=True, exist_ok=True) #downloading images from urls in csv file = 'urls_'+'birds''.csv' dest = path/'birds' download_images(path/file, dest, max_pics=20) file = 'urls_'+'flowers''.csv' dest = path/'flowers' download_images(path/file, dest, max_pics=200) #verifying images for c in classes: print(c) verify_images(path/c, delete=True, max_size=500)
Создание и визуализация набора данных
После того, как фотографии загружены и проверены в каталогах, соответствующих каждой категории, мы можем создать fast.ai DataFrame, чтобы иметь возможность помещать в него изображения с тегами и начинать визуализацию и работу с ними. Мы оставляем за собой 20% для подтверждения набора.
np.random.seed(7) data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2, ds_tfms=get_transforms(), size=224, num_workers=4).normalize(imagenet_stats)
И запустите код, чтобы отобразить случайный пакет из 3 строк:
data_gt.show_batch(rows=3, figsize=(7,8))
Модель классификации
В этом упражнении мы будем использовать свёрточную сеть формата resnet34¹.
В «моделях» есть набор предопределенных сетевых архитектур, включающих различные структуры и сложности².
learn_gt = cnn_learner(data_gt, models.resnet34, metrics=error_rate) learn_gt.fit_one_cycle(4)
learn_gt.save('gt_stage-1') learn_gt.load('gt_stage-1') learn_gt.unfreeze() learn_gt.lr_find()
learn_gt.fit_one_cycle(2, max_lr=slice(1e-5,1e-2)) learn_gt.save('gt_stage-2') learn_gt.load('gt_stage-2')
Результаты в матрице путаницы
interp = ClassificationInterpretation.from_learner(learn_gt) interp.plot_confusion_matrix()
Как мы видим, модель очень неэффективна в прогнозировании класса
Менее представленный класс (птицы), где существует много путаницы с проверочным набором.
Увеличение данных
Мы должны создать 180 новых изображений в категории птиц. Для этого мы пройдемся по каждому из реальных изображений, создав десять дополнительных изображений для каждого, используя метод apply_tfms fast.ai из fast.ai.
path = Path('data/data_aug') path_hr = path/'birds' il = ImageList.from_folder(path_hr) tfms = get_transforms(max_rotate=25) def data_aug_one(ex_img,prox,qnt): for lop in range(0,qnt): image_name = str(prox).zfill(8) +'.jpg' dest = path_hr/image_name prox = prox + 1 new_img = open_image(ex_img) new_img_fin = new_img.apply_tfms(tfms[0], new_img, xtra={tfms[1][0].tfm: {"size": 224}}, size=224) new_img_fin.save(dest) prox = 20 qnt = 10 for imagen in il.items: data_aug_one(imagen,prox,qnt) prox = prox + qnt
Если мы визуализируем любое из исходных изображений и их десять новых изображений, мы обнаружим такие вещи:
Подробно о работе функции преобразования apply_tfms и всех ее возможностях можно прочитать здесь.
Та же модель со сбалансированными данными
Мы создали новую модель со сбалансированными наборами данных
np.random.seed(7) tfms = get_transforms() data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2, ds_tfms=tfms, size=224, num_workers=4).normalize(imagenet_stats) learn = cnn_learner(data, models.resnet34, metrics=error_rate) learn.fit_one_cycle(4)
interp = ClassificationInterpretation.from_learner(learn) interp.plot_confusion_matrix()
Хорошо, теперь давайте представим новый набор изображений
Мы взяли двадцать новых изображений, по десять из каждой категории, и видим, что точность, хотя и значительно улучшенная по сравнению с исходной версией, по-прежнему вызывает некоторые проблемы при категоризации. Из 20 новых изображений модель предсказала 18 правильных и два неверных:
path = Path('data/data_aug_test') defaults.device = torch.device('cpu') img = open_image(path/'bird01.jpg') img
pred_class,pred_idx,outputs = learn.predict(img) pred_class
Категория птицы
for test in range(1,11): image_name = 'bird'+str(test).zfill(2)+'.jpg' img = open_image(path/image_name) pred_class,pred_idx,outputs = learn.predict(img) print ('For image ' + image_name + ' predicted class: '); print (pred_class) image_name = 'flower'+str(test).zfill(2)+'.jpg' img = open_image(path/image_name) pred_class,pred_idx,outputs = learn.predict(img) print ('For image ' + image_name + ' predicted class: '); print (pred_class)
Изображение flower07.jpg предсказано как птица, а изображение bird08.jpg предсказано как цветок.
Здесь изображения путаницы:
Резюме
Как мы могли видеть, разница в обучении одной и той же модели на сбалансированном наборе данных и несбалансированном фундаментальна; однако, если набор данных минимален, возможно, достигнутой нами точности недостаточно.
В предыдущей статье мы создали очень похожую модель из 200 изображений каждой категории. Когда мы представили ему те же 20 новых тестовых изображений, у него была такая же путаница на тех же фотографиях.
Мы использовали метод apply_tfms из fast.ai для преобразования нескольких исходных изображений птиц, создав набор данных в 10 раз больше.
Источники и ссылки
[1] — https://towardsdatascience.com/an-overview-of-resnet-and-its-variants-5281e2f56035
В упомянутом курсе fast.ai эта техника вдохновлена: Франциско Ингамом и Джереми Ховардом / [Адриан Роузброк]
https://course.fast.ai, урок 2, Jupyter Notebook: урок2-скачать