Деревья принятия решений являются одними из самых мощных инструментов машинного обучения, доступных сегодня, и используются в самых разных реальных приложениях, от Прогнозирование кликов по рекламе в Facebook ¹ до Ранжирование опыта Airbnb. Тем не менее, они интуитивно понятны, их легко интерпретировать и легко реализовать. В этой статье мы обучим наш собственный классификатор дерева решений всего на 66 строках кода Python.

Что такое дерево решений?

Деревья решений могут использоваться для регрессии (непрерывный результат с действительной стоимостью,
например прогноз цены дома) или классификации ( категориальный вывод,
например прогнозирование спама в электронной почте или отсутствие спама), но здесь мы сосредоточимся на классификации. Классификатор дерева решений - это двоичное дерево, в котором прогнозы делаются путем обхода дерева от корня к листу - на каждом узле мы идем влево, если характеристика меньше порогового значения, в противном случае вправо . Наконец, каждый лист связан с классом, который является выходом предсказателя.

Например, рассмотрите этот Набор данных беспроводной локальной локализации .² Он дает 7 характеристик, представляющих мощность 7 сигналов Wi-Fi, воспринимаемых телефоном в квартире, а также местоположение телефона внутри помещения, которое может быть в комнатах 1, 2, 3. или 4.

+-------+-------+-------+-------+-------+-------+-------+------+
| Wifi1 | Wifi2 | Wifi3 | Wifi4 | Wifi5 | Wifi6 | Wifi7 | Room |
+-------+-------+-------+-------+-------+-------+-------+------+
|  -64  |  -55  |  -63  |  -66  |  -76  |  -88  |  -83  |   1  |
|  -49  |  -52  |  -57  |  -54  |  -59  |  -85  |  -88  |   3  |
|  -36  |  -60  |  -53  |  -36  |  -63  |  -70  |  -77  |   2  |
|  -61  |  -56  |  -55  |  -63  |  -52  |  -84  |  -87  |   4  |
|  -36  |  -61  |  -57  |  -27  |  -71  |  -73  |  -70  |   2  |
                               ...

Цель состоит в том, чтобы предсказать, в какой комнате находится телефон, основываясь на силе сигналов Wi-Fi с 1 по 7. Обученное дерево решений глубины 2 может выглядеть следующим образом:

Примесь Джини

Прежде чем мы углубимся в код, давайте определим метрику, используемую во всем алгоритме. В деревьях решений используется концепция примеси Джини, чтобы описать, насколько однородным или «чистым» является узел. Узел является чистым (G = 0), если все его образцы принадлежат одному классу, тогда как узел с множеством образцов из разных классов будет иметь индекс Джини, близкий к 1.

Более формально примесь Джини n обучающих выборок, разделенных на k классов, определяется как

где p [k] - доля выборок, принадлежащих классу k.

Например, если узел содержит пять образцов, два из которых относятся к классу 1, два - к классу 2, один - к классу 3 и ни один - к классу 4, тогда

CART алгоритм

Алгоритм обучения представляет собой рекурсивный алгоритм, называемый CART, сокращенно от Classification And Regression Trees .³ Каждый узел разбивается таким образом, чтобы примесь Джини дочерних элементов (точнее, среднее из коэффициент Джини детей, взвешенный по их размеру) минимизирован.

Рекурсия останавливается, когда достигается максимальная глубина, гиперпараметр , или когда никакое разделение не может привести к появлению двух дочерних элементов более чистых, чем их родитель. Другие гиперпараметры могут управлять этим критерием остановки (критически важным на практике, чтобы избежать переобучения), но мы не будем рассматривать их здесь.

Например, если X = [[1.5], [1.7], [2.3], [2.7], [2.7]] и y = [1, 1, 2, 2, 3], тогда оптимальным разделением будет feature_0 < 2, потому что, как вычислено выше, Джини родительского элемента составляет 0,64, а Джини дочерних элементов после разделения равно

Вы можете убедить себя, что ни один другой сплит не дает более низкого индекса Джини.

Поиск оптимальной характеристики и порога

Ключом к алгоритму CART является поиск оптимальной характеристики и порога, которые минимизируют примесь Джини. Для этого мы пробуем все возможные расщепления и вычисляем результирующие примеси Джини.

Но как мы можем попробовать все возможные пороговые значения для непрерывных значений? Есть простой трюк - отсортировать значения для данной функции и рассмотреть все средние точки между двумя соседними значениями. Сортировка дорогостоящая, но она все равно необходима, как мы вскоре увидим.

Теперь, как мы можем вычислить Джини всех возможных разбиений?

Первое решение - фактически выполнить каждое разбиение и вычислить результирующий коэффициент Джини. К сожалению, это медленно, так как нам нужно будет просмотреть все образцы, чтобы разделить их на левую и правую. Точнее, это будет n разделений с O (n) операциями для каждого разделения, что сделает общую операцию O (n²).

Более быстрый подход - 1. перебираем отсортированные значения функций как возможные пороговые значения, 2. отслеживайте количество образцов для каждого класса слева и справа, а также 3. увеличивают / уменьшают их на 1 после каждого порога. По ним мы можем легко вычислить Джини за постоянное время.

Действительно, если m - это размер узла, а m [k] - количество выборок класса k в узле, тогда

и поскольку после просмотра i -го порога слева появляются элементы i, а справа m – i,

а также

Полученный индекс Джини представляет собой простое средневзвешенное значение:

Вот весь метод _best_split.

Условие в строке 61 - это последняя тонкость. Перебирая все значения функций в цикле, мы разрешаем разбиение на выборки с одинаковым значением. На самом деле мы можем разделить их только в том случае, если они имеют различное значение для этой функции, отсюда и дополнительная проверка.

Рекурсия

Самое сложное сделано! Теперь все, что нам нужно сделать, это рекурсивно разделить каждый узел, пока не будет достигнута максимальная глубина.

Но сначала давайте определим Node класс:

Подгонка дерева решений к данным X и целям y выполняется с помощью метода fit(), который вызывает рекурсивный метод _grow_tree():

Прогнозы

Мы увидели, как подогнать дерево решений, а теперь как мы можем использовать его для прогнозирования классов для невидимых данных? Что может быть проще - идите влево, если значение функции ниже порогового значения, в противном случае идите вправо.

Обучите модель

Наш DecisionTreeClassifier готов! Давайте обучим модель на наборе данных беспроводной локальной сети:

В качестве проверки работоспособности вот результат реализации Scikit-Learn:

Сложность

Легко видеть, что прогноз выражается в O (log m), где m - глубина дерева.

А как насчет тренировок? Здесь вам пригодится Основная теорема. Временная сложность подгонки дерева к набору данных с помощью n выборок может быть выражена следующим соотношением повторения:

где в оптимальном случае, когда левый и правый дочерние элементы имеют одинаковый размер, a = 2 и b = 2; и f (n) - сложность разделения узла на двух дочерних узлов, другими словами сложность _best_split. Первый цикл for выполняет итерацию по функциям, и для каждой итерации существует сортировка сложности O (n log n) и еще один цикл for. через O (n). Следовательно, f (n) равно O (k n log n), где k - количество функций.

С этими предположениями основная теорема говорит нам, что общая временная сложность равна

Это не так уж далеко, но все же хуже, чем сложность реализации Scikit-Learn, очевидно, в O (kn log n). Если кто-то знает, как это возможно, дайте мне знать в комментарии!

Полный код

Полный код можно найти в этом репозитории Github. И ради интереса вот, как и было обещано, версия урезанная до 66 строк.

¹ Синьран Хэ, Цзюньфэн Пан, Оу Цзинь, Тяньбин Сюй, Бо Лю, Тао Сюй, Янсинь Ши, Антуан Аталлах, Ральф Хербрих, Стюарт Бауэрс и Хоакин Киньонеро Кандела. 2014. Практические уроки прогнозирования кликов по рекламе в Facebook. В материалах восьмого международного семинара по интеллектуальному анализу данных для интернет-рекламы (ADKDD’14). ACM, Нью-Йорк, Нью-Йорк, США, Статья 5, 9 страниц. DOI = http: //dx.doi.org/10.1145/2648584.2648589

² Джаянт Дж. Рора, Боминатан Перумал, Свати Джамджала Нараянан, Прия Тхакур и Раджен Б. Бхатт, «Локализация пользователя во внутренней среде с использованием нечеткого гибридного алгоритма оптимизации роя частиц и гравитационного поиска с нейронными сетями», в материалах Шестой Международной конференции по Мягкие вычисления для решения проблем, 2017, стр. 286–295.

³ Брейман, Лео; Friedman, J. H .; Ольшен, Р. А .; Стоун, К. Дж. (1984). Деревья классификации и регрессии. Монтерей, Калифорния: Уодсворт и Брукс / Продвинутые книги и программное обеспечение Коула.