Кодирование мощного алгоритма на Python с использованием (в основном) массивов и циклов

Эта статья призвана демистифицировать популярный алгоритм случайного леса (здесь и по всему тексту - RF) и показать его принципы с помощью графиков, фрагментов кода и выходных данных.

Полная реализация написанного мной алгоритма RF на python доступна по адресу: https://github.com/Eligijus112/decision-tree-python

Я настоятельно рекомендую всем, кто наткнулся на эту статью, углубиться в код, потому что понимание кода сделает любое дальнейшее чтение документации по RF гораздо более простым и менее напряженным.

Любые предложения по оптимизации настоятельно приветствуются и приветствуются через пулреквест на GitHub.

Строительными блоками RF являются простые деревья решений. Эту статью будет намного легче читать, если читатель знаком с концепцией дерева решений по классификации. Настоятельно рекомендуется прочитать следующую статью, прежде чем продолжить:



Scikit-learn на Python имеет реализацию RF-алгоритма, которая работает быстро и сотни раз проверяется:



В документации первым гиперпараметром, который необходимо определить, является параметр n_estimators со значением по умолчанию 100. Описание этого параметра очень элегантно:

Количество деревьев в лесу.

Это объяснение очень романтично, но упрямый математик, читающий впервые, может найти это немного озадачивающим: о каком лесу и о каких деревьях идет речь?

Интуиция, лежащая в основе алгоритма случайного леса, может быть разделена на две большие части: часть random и часть forest. . Начнем с последнего.

В реальной жизни лес состоит из группы деревьев. Классификатор случайного леса состоит из набора классификаторов дерева решений (здесь и по всему тексту - DT). Точное количество DT, составляющих весь лес, определяется с помощью упомянутой ранее переменной n_estimators.

Каждый DT в алгоритме RF полностью независим друг от друга. Если для переменной n_estimators задано значение 50, то лес состоит из 50 деревьев решений, которые были выращены полностью независимо друг от друга и не разделяют никакой информации.

Дерево решений двоичной классификации можно рассматривать как функцию, которая принимает входные данные X и выдает либо 1, либо 0:

DT: X → {0, 1}

Окончательный прогноз RF - это большинство прогнозов, сделанных для каждого отдельного DT.

Если из 50 деревьев 30 деревьев помечают новое наблюдение как 1, а 20 деревьев маркируют то же наблюдение как 0 окончательное предсказание случайного леса будет 1.

Из статьи о простых классификационных деревьях решений ясно, что с одними и теми же входными данными и теми же гиперпараметрами, один и тот же выход и одни и те же правила будут изучены деревом решений. Так зачем выращивать 50 (или 100, или 1000, или k) из них? Здесь вступает в действие вторая часть интуиции RF: часть случайного.

Случайную часть в RF можно разделить на две части.

Первая часть - это загрузка данных. Загрузка данных - это подвыборка его строк с заменой. Часть реализации python, которая создает загрузочный образец:

Например, если весь набор данных d состоял из 10 строк:

Переменная ответа - это столбец Churn, а три других столбца - это характеристики.

Три независимых набора данных начальной загрузки d могут выглядеть так:

Как видно из рисунка выше, в первом примере 5-я и 8-я строки совпадают. Кроме того, присутствуют не все оригинальные наблюдения.

Если мы определим, что RF имеет k деревьев решений (n_estimators = k), тогда будет создано k различных наборов данных начальной загрузки, и каждое дерево будет расти с другим набором данных. Каждый набор данных может иметь такое же количество строк, что и исходный, или может иметь меньше строк, чем исходный набор данных.

Таким образом, если RF состоит из 50 деревьев решений, то высокоуровневый граф RF будет выглядеть следующим образом:

Каждое из пятидесяти деревьев решений будет выращено с уникальным набором данных.

Вторая часть random в случайном лесу находится в фазе обучения. При подгонке дерева решений по классификации к данным во время каждого выбора разделения функции остаются неизменными.

При создании классификатора случайного леса каждое дерево решений растет немного иначе, чем в реализованном здесь алгоритме.

Https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier

Единственное отличие состоит в том, что во время каждого из моментов создания разбиения рисуется случайная подвыборка функций i, и для этих функций рассчитывается максимальное усиление Джини.

Https://gist.github.com/Eligijus112/4c47be2f7566299bb8c4f97c107d82c6f

Например, предположим, что у нас есть загруженная матрица d и X выборки данных со 100 начальными характеристиками. Мы можем определить, сколько функций оставить на каждой фазе разделения с аргументом X_features_fraction (max_features в реализации scikit-learn). Давайте установим его на 0,05, что означает, что при каждом разбиении для разбиения будут выбираться 5 случайных характеристик (или 5%).

В первом узле есть пять признаков X: 1, 25, 28, 30 и 98. Наилучшее разбиение достигается при разбиении 25-го признака по значению x. Два нижеприведенных узла имеют еще 5 случайных функций, в которых ищется лучший разбиение.

Таким образом, алгоритм RF выглядит следующим образом:

Выберите набор гиперпараметров.

Для деревьев от 1 до k:

  • Создайте случайную выборку данных начальной загрузки d.
  • Подберите дерево решений к данным d. Во время каждого разделения на этапе роста выберите случайную подвыборку объектов для разделения.

При прогнозировании получите прогноз для каждого из k деревьев, а затем выберите метку большинством голосов.

Подгоним классификатор RF к некоторым реальным данным. Полную записную книжку и данные можно получить через:

Https://github.com/Eligijus112/decision-tree-python/blob/main/RandomForestShowcase.ipynb

Набор данных, который находится в моем репозитории GitHub, имеет 3333 строки, и мы будем использовать следующие столбцы:

Цель состоит в том, чтобы создать классификатор, который предсказывает, откажется от услуг телекоммуникационный клиент. Отток - это событие, когда покупатель больше не использует продукт компании. Если отток = 1, это означает, что пользователь переключился на другого поставщика услуг или по какой-либо другой причине не проявляет активности.

Давайте разделим данные на обучающую и тестовую части. В обучающих данных будет 75% строк (2500), а в тестовом наборе будет 25% строк (833).

Выращивая 5 деревьев решений с 25% функций в каждом разбиении и каждое дерево с глубиной 2, мы получаем следующий лес:

------ 

Tree number: 1 

Root
   | GINI impurity of the node: 0.27
   | Class distribution in the node: {0: 2099, 1: 401}
   | Predicted class: 0
|-------- Split rule: DayCalls <= 113.5
           | GINI impurity of the node: 0.24
           | Class distribution in the node: {0: 1598, 1: 257}
           | Predicted class: 0
|---------------- Split rule: DayMins <= 287.7
                   | GINI impurity of the node: 0.21
                   | Class distribution in the node: {0: 1591, 1: 218}
                   | Predicted class: 0
|---------------- Split rule: DayMins > 287.7
                   | GINI impurity of the node: 0.26
                   | Class distribution in the node: {1: 39, 0: 7}
                   | Predicted class: 1
|-------- Split rule: DayCalls > 113.5
           | GINI impurity of the node: 0.35
           | Class distribution in the node: {0: 501, 1: 144}
           | Predicted class: 0
|---------------- Split rule: DayMins <= 225.0
                   | GINI impurity of the node: 0.25
                   | Class distribution in the node: {0: 431, 1: 76}
                   | Predicted class: 0
|---------------- Split rule: DayMins > 225.0
                   | GINI impurity of the node: 0.5
                   | Class distribution in the node: {1: 68, 0: 70}
                   | Predicted class: 0
------ 

------ 

Tree number: 2 

Root
   | GINI impurity of the node: 0.26
   | Class distribution in the node: {0: 2124, 1: 376}
   | Predicted class: 0
|-------- Split rule: OverageFee <= 13.235
           | GINI impurity of the node: 0.24
           | Class distribution in the node: {0: 1921, 1: 307}
           | Predicted class: 0
|---------------- Split rule: DayMins <= 261.45
                   | GINI impurity of the node: 0.18
                   | Class distribution in the node: {0: 1853, 1: 210}
                   | Predicted class: 0
|---------------- Split rule: DayMins > 261.45
                   | GINI impurity of the node: 0.48
                   | Class distribution in the node: {0: 68, 1: 97}
                   | Predicted class: 1
|-------- Split rule: OverageFee > 13.235
           | GINI impurity of the node: 0.38
           | Class distribution in the node: {1: 69, 0: 203}
           | Predicted class: 0
|---------------- Split rule: DayMins <= 220.35
                   | GINI impurity of the node: 0.13
                   | Class distribution in the node: {0: 186, 1: 14}
                   | Predicted class: 0
|---------------- Split rule: DayMins > 220.35
                   | GINI impurity of the node: 0.36
                   | Class distribution in the node: {1: 55, 0: 17}
                   | Predicted class: 1
------ 

------ 

Tree number: 3 

Root
   | GINI impurity of the node: 0.25
   | Class distribution in the node: {1: 366, 0: 2134}
   | Predicted class: 0
|-------- Split rule: DataUsage <= 0.315
           | GINI impurity of the node: 0.29
           | Class distribution in the node: {1: 286, 0: 1364}
           | Predicted class: 0
|---------------- Split rule: MonthlyCharge <= 62.05
                   | GINI impurity of the node: 0.18
                   | Class distribution in the node: {1: 144, 0: 1340}
                   | Predicted class: 0
|---------------- Split rule: MonthlyCharge > 62.05
                   | GINI impurity of the node: 0.25
                   | Class distribution in the node: {1: 142, 0: 24}
                   | Predicted class: 1
|-------- Split rule: DataUsage > 0.315
           | GINI impurity of the node: 0.17
           | Class distribution in the node: {0: 770, 1: 80}
           | Predicted class: 0
|---------------- Split rule: RoamMins <= 13.45
                   | GINI impurity of the node: 0.12
                   | Class distribution in the node: {0: 701, 1: 49}
                   | Predicted class: 0
|---------------- Split rule: RoamMins > 13.45
                   | GINI impurity of the node: 0.43
                   | Class distribution in the node: {0: 69, 1: 31}
                   | Predicted class: 0
------ 

------ 

Tree number: 4 

Root
   | GINI impurity of the node: 0.26
   | Class distribution in the node: {0: 2119, 1: 381}
   | Predicted class: 0
|-------- Split rule: DayCalls <= 49.5
           | GINI impurity of the node: 0.49
           | Class distribution in the node: {1: 8, 0: 6}
           | Predicted class: 1
|---------------- Split rule: MonthlyCharge <= 31.5
                   | GINI impurity of the node: 0.0
                   | Class distribution in the node: {0: 4}
                   | Predicted class: 0
|---------------- Split rule: MonthlyCharge > 31.5
                   | GINI impurity of the node: 0.32
                   | Class distribution in the node: {1: 8, 0: 2}
                   | Predicted class: 1
|-------- Split rule: DayCalls > 49.5
           | GINI impurity of the node: 0.26
           | Class distribution in the node: {0: 2113, 1: 373}
           | Predicted class: 0
|---------------- Split rule: DayMins <= 264.6
                   | GINI impurity of the node: 0.21
                   | Class distribution in the node: {0: 2053, 1: 279}
                   | Predicted class: 0
|---------------- Split rule: DayMins > 264.6
                   | GINI impurity of the node: 0.48
                   | Class distribution in the node: {1: 94, 0: 60}
                   | Predicted class: 1
------ 

------ 

Tree number: 5 

Root
   | GINI impurity of the node: 0.24
   | Class distribution in the node: {0: 2155, 1: 345}
   | Predicted class: 0
|-------- Split rule: OverageFee <= 7.945
           | GINI impurity of the node: 0.15
           | Class distribution in the node: {0: 475, 1: 43}
           | Predicted class: 0
|---------------- Split rule: AccountWeeks <= 7.0
                   | GINI impurity of the node: 0.28
                   | Class distribution in the node: {1: 5, 0: 1}
                   | Predicted class: 1
|---------------- Split rule: AccountWeeks > 7.0
                   | GINI impurity of the node: 0.14
                   | Class distribution in the node: {0: 474, 1: 38}
                   | Predicted class: 0
|-------- Split rule: OverageFee > 7.945
           | GINI impurity of the node: 0.26
           | Class distribution in the node: {0: 1680, 1: 302}
           | Predicted class: 0
|---------------- Split rule: DayMins <= 259.9
                   | GINI impurity of the node: 0.2
                   | Class distribution in the node: {0: 1614, 1: 203}
                   | Predicted class: 0
|---------------- Split rule: DayMins > 259.9
                   | GINI impurity of the node: 0.48
                   | Class distribution in the node: {0: 66, 1: 99}
                   | Predicted class: 1
------

Каждое дерево немного отличается друг от друга. Оценки точности и запоминания на тестовом наборе с выбранными гиперпараметрами:

Попробуем повысить метрики точности, вырастив более сложный случайный лес.

При выращивании случайного леса с 30 деревьями, 50% функций в каждом разбиении и max_depth из 4 точность в тестовом наборе составляет:

Если мы вырастим 100 деревьев, с 75% характеристик и максимальной глубиной деревьев 5, результаты будут следующими:

Реализация scikit learn дает очень похожие результаты:

# Skicit learn implementation
from sklearn.ensemble import RandomForestClassifier
# Initiating
rf_scikit = RandomForestClassifier(n_estimators=100, max_features=0.75, max_depth=5, min_samples_split=5)
# Fitting 
start = time.time()
rf_scikit.fit(X=train[features], y=train[‘Churn’])
print(f”Time took for scikit learn: {round(time.time() — start, 2)} seconds”)
# Forecasting 
yhatsc = rf_scikit.predict(test[features])
test[‘yhatsc’] = yhatsc
print(f”Total churns in test set: {test[‘Churn’].sum()}”)
print(f”Total predicted churns in test set: {test[‘yhat’].sum()}”)
print(f”Precision: {round(precision_score(test[‘Churn’], test[‘yhatsc’]), 2) * 100} %”)
print(f”Recall: {round(recall_score(test[‘Churn’], test[‘yhatsc’]), 2) * 100} %”)

Обратите внимание, что разница между результатами склеарна и заказного кода может различаться из-за случайного характера данных начальной загрузки для каждого дерева и случайных функций, которые используются для поиска наилучшего разделения. Даже запустив одну и ту же реализацию с одними и теми же гиперпараметрами, мы каждый раз будем получать немного разные результаты.

Таким образом, алгоритм случайного леса состоит из независимых простых деревьев решений.

Каждое дерево решений создается с использованием настраиваемого загружаемого набора данных. В каждом разбиении и в каждом дереве решений рассматривается случайная подвыборка функций при поиске лучшей функции и значения функции, которая увеличивает усиление GINI.

Окончательное предсказание классификатора RF - это большинство голосов всех независимых деревьев решений.