Пример проекта обучения и тестирования моделей машинного обучения (ML) с caret
в R для прогнозирования использования противозачаточных средств.
Этот пост посвящен техническому аспекту этого проекта. Если вы хотите просмотреть этот учебный документ, чтобы узнать об общем рабочем процессе, включая стратегию анализа, сравнение моделей и ограничения, просмотрите:
1. Обзор
Введение в статистическое обучение: с приложениями в R или ISLR была моей первой книгой по предиктивной аналитике, и я настоятельно рекомендую всем, кто интересуется машинным обучением, прочитать эту книгу. Я научился программировать на R и использовать различные статистические пакеты, такие как glm
и randomForest
, но это было неэффективно из-за большого количества разных пакетов. К счастью, доступно несколько библиотек, которые пытаются упростить процесс построения прогнозных моделей. Здесь я сосредотачиваюсь на пакете caret
(сокращение от C классификация A nd RE gression T raining), созданном Макс Кун.
Цель проекта - построить классификационные модели для прогнозирования использования противозачаточных средств в Таиланде, Монголии и Лаосе. Данные взяты из Кластерных обследований по множественным показателям (MICS), опубликованных Детским фондом ООН (ЮНИСЕФ). Поскольку они представляют собой микроданные для данных индивидуальных ответов в опросах и переписи, пользователи должны зарегистрироваться, чтобы получить доступ к базе данных. После быстрой регистрации я загрузил 6-й набор данных MICS для трех стран и выполнил очистку и объединение данных. В результате набор данных содержит 58 356 наблюдений и 9 переменных:
- В настоящее время используется контрацепция,
use
: двоичный (да / нет) - Возраст,
age
: Числовой - Высший уровень образования,
edu
: Категориальный (ниже начального / начального / неполного среднего / полного среднего / высшего образования) - Процентиль богатства для конкретной страны,
wealth
: числовой - Семейное положение,
mstat
: категорическое (никогда / бывшее / текущее) - Жилой дом,
urban
: категориальный (городской / сельский) - Страна,
country
: Категория (Таиланд / Монголия / Лаос) - Когда-либо родившие,
given_birth
: двоичный (да / нет) - Когда-либо был ребенок или дети, которые позже умерли,
child_died
: двоичный (да / нет)
Импорт данных и пакетов
library(tidyverse) # data manipulation library(caret) # predictive modelling library(rpart.plot) # decision tree visualisation mics <- read_dta("MICS.dta")
Все категориальные переменные представляют собой текстовые строки или символы, с которыми модели машинного обучения не могут работать. Поэтому я закодировал или факторизовал категориальные функции.
# Factorise variables mics$mstat <- factor(mics$mstat) mics$edu <- factor(mics$edu) mics$country <- factor(mics$country) # Factor recode for clarity mics$use <- factor(mics$use, levels = c(1, 0), labels = c("Using", "Not Using")) mics$residence <- factor(mics$residence, levels = c(1, 0), labels = c("Urban", "Rural")) mics$given_birth <- factor(mics$given_birth, levels = c(1, 0), labels = c("Yes", "No")) mics$child_died <- factor(mics$child_died, levels = c(1, 0), labels = c("Yes", "No")) mics$edu <- factor(mics$edu, levels = c("PRE-PRIMARY OR NONE", "PRIMARY", "LOWER SECONDARY", "UPPER SECONDARY", "HIGHER"), labels = c("Less than Primary", "Primary", "Lower Secondary", "Upper Secondary", "Higher Education")) mics$country <- factor(mics$country, levels = c("THAILAND", "LAO", "MONGOLIA"), labels = c("Thailand", "Laos", "Mongolia"))
Структура кадра данных после перекодирования следующая:
tibble [58,356 × 10] (S3: tbl_df/tbl/data.frame) $ age : num [1:58356] 29 35 36 24 34 38 24 16 15 36 ... $ edu : Factor w/ 5 levels "Less than Primary",..: 5 5 5 5 5 $ mstat : Factor w/ 3 levels "Current","Former",..: 1 1 3 3 1 $ wealth : num [1:58356] 9 9 5 8 10 4 10 7 7 8 ... $ residence : Factor w/ 2 levels "Urban","Rural": 1 1 1 1 1 1 1 1 $ country : Factor w/ 3 levels "Thailand","Laos",..: 1 1 1 1 1 1 $ given_birth: Factor w/ 2 levels "Yes","No": 1 1 2 2 1 1 2 2 2 1 $ child_died : Factor w/ 2 levels "Yes","No": 2 2 2 2 2 2 2 2 2 2 $ use : Factor w/ 2 levels "Using","Not Using": 1 1 2 2 1 1
Сплит-набор для обучения и тестирования
Я отложил 15% наблюдений для набора тестирования, который зарезервирован для окончательного тестирования после обучения и оптимизации моделей. Остальные 85% используются для разработки классификационных моделей. Поскольку я буду экспериментировать с разными параметрами, я также использую 10-кратную перекрестную проверку обучающего набора для оценки производительности.
Первое разделение обучения / тестирования выполняется с помощью команды createDataPartition
, которая создает сбалансированные разделения данных в соответствии с результатом.
# Split data into testing and training train_index <- createDataPartition(mics$use, # 85% for training p = .85, times = 1, list = FALSE) micsTrain <- mics[ train_index, ] # Training micsTest <- mics[-train_index, ] # Testing
10-кратная перекрестная проверка настроена с createFolds
и trainControl
. Первый разбивает обучающий набор на десять сверток, а второй определяет перекрестную проверку с использованием сверток. Обычно достаточно простого trainControl (method="cv", k=10)
, но результат может отличаться каждый раз при выполнении команды. В то время как trainControl
предоставляет параметр seed
для воспроизводимости, у меня возникли проблемы с его настройкой, и я решил использовать createFolds
.
# 10-folds fold_index <- createFolds(micsTrain$use, # number of folds k = 10, # return as list list = T, # return numbers corresponding positions returnTrain = T) # Cross validation ctrl <- trainControl(method="cv", index = fold_index)
Подготовка данных завершена, и я готов обучать модели. Обратите внимание, что подготовка данных включает в себя гораздо более сложные шаги в реальном приложении, такие как работа с недостающими данными, выбор переменных, инженерные особенности, масштабирование и центрирование переменных и т. Д. Однако, поскольку это был мой первый Строя модели от начала до конца, я не выполнял указанные шаги. Более того, большая часть подготовки данных работает с числовыми функциями. Поскольку большинство переменных являются категориальными, я мало что могу сделать.
Модель обучения и перекрестная проверка
Функция train
упрощает процесс построения и оценки модели. Моя первая модель - это k-ближайшие соседи (KNN). Если бы я использовал пакет knn
, следуя инструкции ISLR, мне пришлось бы запускать knn.cv
несколько раз и сравнивать результат, чтобы найти лучший k. С caret
мне нужна только одна команда. Сначала я ищу значение method
и доступные параметры в документации. Для knn
существует только один параметр настройки, k. У меня есть три варианта тюнинга:
- Ничего не делать: в этом случае
train
пробует 3 случайных числа для k. Хотя я говорю случайный, на самом деле это не так. Но это выходит за рамки моего проекта. tuneGrid
: Я указываю числа, которые нужно попробовать,seq(2, 20, 1)
, последовательность чисел от 2 до 20 с интервалом 1.tuneLength
: Вместо того, чтобы указывать числа, я указываю функцию, чтобы попробовать 10 разных чисел. Это может быть 1–10, 101–110 или десять четных чисел.
Параметр form
сообщает модели целевую переменную и предикторы. Похоже, outcome ~ var1 + var2
. Здесь я использовал .
как широкую карту, позволяющую модели выбирать входные переменные. Наконец, как упоминалось ранее, я использовал 10-кратную перекрестную проверку для оценки модели, которая выполняется с помощью trControl
.
# Option 1: No specification on tuning parameter m_knn <- train(form = use~., data = micsTrain, method = 'knn', trControl = ctrl) # Option 2: Try all specified parameters m_knn <- train(form = use~., data = micsTrain, method = 'knn', trControl = ctrl, # Cross-validation tuneGrid = data.frame(k = seq(2, 20, 1))) # Option 3: Try 10 random parameters m_knn <- train(form = use~., data = micsTrain, method = 'knn', trControl = ctrl, # Cross-validation tuneLength = 10)
Используя вариант 3 в качестве примера, функция train
генерирует десять чисел для параметра настройки, k, и для каждого числа выполняется 10-кратная перекрестная проверка для вычисления средней точности. Процесс может занять некоторое время. После выполнения он сравнивает среднюю точность по десяти числам и возвращает число с наивысшей производительностью в качестве последнего параметра. Как показано ниже, print(m_knn)
показывает среднюю точность для каждого k и выбирает наилучшую, 9.
k-Nearest Neighbors 49604 samples 9 predictor 2 classes: 'Using', 'Not Using' No pre-processing Resampling: Cross-Validated (10 fold) Summary of sample sizes: 44644, 44644, 44644, 44644, 44643, 44643, ... Resampling results across tuning parameters: k Accuracy Kappa 5 0.7392345 0.4790542 7 0.7409078 0.4824688 9 0.7422585 0.4852000 11 0.7404442 0.4816136 13 0.7418755 0.4844997 15 0.7408474 0.4824703 17 0.7404241 0.4816458 19 0.7401217 0.4810566 21 0.7404038 0.4816320 23 0.7398998 0.4806354 Accuracy was used to select the optimal model using the largest value. The final value used for the model was k = 9.
caret
также упрощает визуализацию результатов перекрестной проверки, вызывая plot(m_knn, main = “KNN 10-fold Cross-Validation")
Протестируйте модель
Тестировать модель очень просто: предсказать целевую переменную и оценить результат.
pred_knn <- predict(m_knn, newdata = micsTest)
Поскольку это модель классификации, я могу использовать матрицу неточностей для изучения других показателей производительности, сравнивая предсказанные классы с фактическими классами.
tbl_knn <- confusionMatrix(pred_knn, micsTest$use) tbl_knn
Помимо точности, выходные данные включают другие общие меры, такие как специфичность и чувствительность. Точность модели составляет 74,81%, что не впечатляет. Можно видеть, что ошибка типа I (ложноположительный результат) встречается гораздо чаще, чем ошибка типа II (ложноотрицательный результат), что позволяет предположить, что модели в целом лучше позволяют идентифицировать женщин, которые в настоящее время используют противозачаточные средства, чем маркировать тех, кто этого не делает.
Confusion Matrix and Statistics Reference Prediction Using Not Using Using 3800 1652 Not Using 553 2747 Accuracy : 0.7481 95% CI : (0.7388, 0.7571) No Information Rate : 0.5026 P-Value [Acc > NIR] : < 2.2e-16 Kappa : 0.4968 Mcnemar's Test P-Value : < 2.2e-16 Sensitivity : 0.8730 Specificity : 0.6245 Pos Pred Value : 0.6970 Neg Pred Value : 0.8324 Prevalence : 0.4974 Detection Rate : 0.4342 Detection Prevalence : 0.6229 Balanced Accuracy : 0.7487 'Positive' Class : Using
Вывод
В этом посте приводится пример построения моделей классификации с caret
с использованием R. caret
- отличный пакет для машинного обучения, но вначале может быть сложно ориентироваться в нем. Надеюсь, этот пример поможет тем, кто плохо знаком с caret
. Удачи в модельном бизнесе и получайте удовольствие!