Регрессионный анализ и контрфактические объяснения

InterpretML — это интерпретируемая библиотека машинного обучения, разработанная Microsoft с целью сделать модели машинного обучения более понятными и открытыми для интерпретации человеком.

Это имеет особое значение при передаче результатов заинтересованным сторонам бизнеса, которые во многих случаях не являются техническими специалистами и стремятся понять последствия для бизнеса результатов, полученных с помощью модели машинного обучения.

Цель этой статьи — проиллюстрировать, как интерпретируемое машинное обучение и контрфактический анализ могут позволить лучше понять основные тенденции в наборе данных, а также способы, с помощью которых InterpretML может интуитивно сообщать такие результаты.

В этом примере используется набор данных Комиссия по такси и лимузинам Нью-Йорка — записи о поездках на желтом такси. Этот набор данных был получен с использованием Открытых данных Azure, которые, в свою очередь, были получены с веб-сайта nyc.gov и регулируются Условиями использования nyc.gov. Набор данных предоставляется NYC Open Data, которая делает свои данные доступными в соответствии с лицензией CC0: Public Domain, как указано в учетной записи компании Kaggle.

Обратите внимание, что для проведения приведенного ниже анализа использовался Python 3.8.0.

Набор данных и предварительная обработка

Вышеупомянутый набор данных содержит точки данных о многочисленных аспектах поездок в желтом такси по Нью-Йорку, включая общую сумму взимаемой платы, расстояние поездки, сумму чаевых и сумму дорожных сборов.

Набор данных предоставляет множество других переменных, таких как время и место посадки и высадки, количество пассажиров и типы тарифов.

Однако в целях определения основных факторов, влияющих на общую сумму оплаты (которая будет зависимой переменной для приведенного ниже анализа) — расстояние поездки, сумма чаевых и сумма дорожных сборов выбраны в качестве независимых (характеристических) переменных для этого анализа.

Для этого анализа для целей моделирования использовались данные за один месяц (с 6 мая 2018 г. по 6 июня 2018 г.).

import numpy as np
from azureml.opendatasets import NycTlcYellow
from datetime import datetime
from dateutil import parser

end_date = parser.parse('2018-06-06')
start_date = parser.parse('2018-05-06')

nyc_tlc = NycTlcYellow(start_date=start_date, end_date=end_date)
nyc_tlc_df = nyc_tlc.to_pandas_dataframe()
nyc_tlc_df

Для анализа было получено более 9 миллионов строк данных — вот фрагмент данных:

>>> nyc_tlc_df
       vendorID  tpepPickupDateTime tpepDropoffDateTime  passengerCount  tripDistance puLocationId  ... extra  mtaTax  improvementSurcharge  tipAmount  tollsAmount  totalAmount
0             2 2018-05-27 17:50:34 2018-05-27 17:56:41               3          0.82          161  ...   0.0     0.5                   0.3       0.00          0.0         6.80
1             2 2018-05-23 08:20:41 2018-05-23 08:37:06               1          1.69          142  ...   0.0     0.5                   0.3       3.08          0.0        15.38
3             2 2018-05-23 09:02:54 2018-05-23 09:17:59               2          6.64          140  ...   0.0     0.5                   0.3       0.00          0.0        20.30
5             2 2018-05-23 13:28:48 2018-05-23 13:35:15               1          0.61          170  ...   0.0     0.5                   0.3       1.00          0.0         7.80
7             2 2018-05-23 07:05:50 2018-05-23 07:07:40               2          0.48           48  ...   0.0     0.5                   0.3       0.00          0.0         4.30
...         ...                 ...                 ...             ...           ...          ...  ...   ...     ...                   ...        ...          ...          ...
339945        2 2018-06-04 14:03:37 2018-06-04 14:17:11               1          1.95          262  ...   0.0     0.5                   0.3       2.00          0.0        13.30
339946        2 2018-06-04 17:15:23 2018-06-04 17:16:38               1          0.55          262  ...   1.0     0.5                   0.3       0.00          0.0         5.30
339947        2 2018-06-04 16:59:23 2018-06-04 18:24:02               6         16.95           88  ...   1.0     0.5                   0.3       0.00          0.0        62.30
339948        2 2018-06-04 10:34:44 2018-06-04 10:40:46               1          1.16          229  ...   0.0     0.5                   0.3       0.00          0.0         6.80
339949        1 2018-06-04 12:35:57 2018-06-04 12:58:32               1          2.80          231  ...   0.0     0.5                   0.3       0.00          0.0        17.30

[9066744 rows x 21 columns]

Стоит отметить, что набор данных в необработанном формате содержит некоторые ложные данные, с которыми необходимо разобраться на этапе предварительной обработки. Например, просмотр описательной статистики по интересующим переменным показывает, что такие переменные, как tipAmount, содержат отрицательные значения, когда очевидно, что оплата «отрицательных чаевых» невозможна.

>>> nyc_tlc_df['totalAmount'].describe()
count    9.066744e+06
mean     1.676839e+01
std      1.502198e+01
min     -4.003000e+02
25%      8.750000e+00
50%      1.209000e+01
75%      1.830000e+01
max      8.019600e+03
Name: totalAmount, dtype: float64

>>> nyc_tlc_df['tipAmount'].describe()
count    9.066744e+06
mean     1.912497e+00
std      2.658866e+00
min     -1.010000e+02
25%      0.000000e+00
50%      1.410000e+00
75%      2.460000e+00
max      4.000000e+02
Name: tipAmount, dtype: float64

>>> nyc_tlc_df['tollsAmount'].describe()
count    9.066744e+06
mean     3.693462e-01
std      1.883414e+00
min     -1.800000e+01
25%      0.000000e+00
50%      0.000000e+00
75%      0.000000e+00
max      1.650000e+03
Name: tollsAmount, dtype: float64

>>> nyc_tlc_df['tripDistance'].describe()
count    9.066744e+06
mean     3.022766e+00
std      3.905009e+00
min      0.000000e+00
25%      1.000000e+00
50%      1.650000e+00
75%      3.100000e+00
max      9.108000e+02
Name: tripDistance, dtype: float64

Чтобы решить эту проблему, отрицательные значения были заменены значением 0 для интересующих переменных:

y=nyc_tlc_df['totalAmount']
y[y < 0] = 0

tripDistance=nyc_tlc_df['tripDistance']
tripDistance[tripDistance < 0] = 0

tipAmount=nyc_tlc_df['tipAmount']
tipAmount[tipAmount < 0] = 0

tollsAmount=nyc_tlc_df['tollsAmount']
tollsAmount[tollsAmount < 0] = 0

Теперь проверка минимального значения для каждой из этих переменных дает минимум 0 — это то, что нам нужно.

>>> np.min(tollsAmount)
0.0
>>> np.min(tipAmount)
0.0
>>> np.min(tripDistance)
0.0
>>> np.min(y)
0.0

Обратите внимание, что необработанный набор данных чрезвычайно велик по размеру — с 1,5 млрд строк по состоянию на 2018 год — более 50 ГБ. Кроме того, данные для этого набора данных собираются с 2009 года. В этом отношении 9 миллионов строк данных, анализируемых в этом случае, по-прежнему являются лишь верхушкой айсберга.

Например, структура трафика в зимние месяцы может сильно отличаться от той, что была в мае и июне, и нет никакой гарантии, что результаты, полученные в этом конкретном месяце, обязательно будут перенесены на другие периоды времени.

При этом цель в данном случае — использовать InterpretML для лучшего понимания данных, что должно быть достижимо с учетом сегмента данных, который анализируется в этом случае.

InterpretML: регрессионный анализ

Для проведения анализа импортируются соответствующие библиотеки и выполняется разбиение набора данных на поезд-тест:

from interpret.glassbox import LinearRegression
from interpret import show
from sklearn.model_selection import train_test_split
seed = 1

X = np.column_stack((tripDistance, tipAmount, tollsAmount))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=seed)
X_train
y_train

Теперь по обучающим данным выполняется регрессионный анализ:

lr = LinearRegression(random_state=seed)
lr
lr.fit(X_train, y_train)
lr_global = lr.explain_global()
show(lr_global)
lr_local = lr.explain_local(X_test[:5], y_test[:5])
show(lr_local)

Из приведенного выше кода видно, что модель генерирует как глобальные, так и локальные объяснения.

Согласно техническому документу Microsoft InterpretML: набор инструментов для понимания моделей машинного обучения:

  • Общие пояснения позволяют пользователю лучше понять целостное поведение модели.
  • Локальные пояснения позволяют лучше понять отдельные прогнозы.

В приведенной выше модели глобальные объяснения иллюстрируют общую связь между функциями и зависимой переменной, а локальные объяснения иллюстрируют отношения между пятью отдельными наблюдениями в тестовом наборе.

Глядя на приведенный выше график для глобальных пояснений, мы видим, что feature_0001 (или tripDistance) оценивается как самая важная функция, за которой следуют tipAmount и tollsAmount.

Давайте посмотрим на локальные объяснения пяти наблюдений в тестовом наборе. Переменная y_test[:5] содержит следующие значения:

>>> y_test[:5]
451746     9.95
161571    15.30
72007     20.16
115597    21.36
37697     22.77
Name: totalAmount, dtype: float64

Создается график, аналогичный предыдущему, но с целью прогнозирования каждого значения, как указано выше.

Например, прогнозируется значение 12,2 для фактического значения y_test, равного 9,95, при этом наибольшая важность придается функции tripDistance с tipAmount в качестве второстепенной важной функции.

Однако при прогнозировании значения 13,3 для фактического значения y_test, равного 15,3, мы видим, что только tripDistance оценивается как имеющий значение — две другие функции не учитываются:

Вот график предсказанного значения 22,8 против фактического значения 19,1.

Контрфактические объяснения

Цель контрфактуальных объяснений состоит в том, чтобы оценить, как можно ожидать, что изменение определенных метрик характеристик повлияет на переменную результата.

Для целей этого примера переменная totalAmount преобразуется в категориальную — любое значение totalAmount выше 10 долларов США считается существенным тарифом, и ему присваивается значение 1. Любое значение totalAmount ниже 10 долларов США считается низким тарифом, и ему присваивается значение 0.

Вот вопрос, на который мы хотим ответить:

Какие изменения в расстоянии поездки и сумме чаевых приведут к изменению значения 1 на 0 и наоборот?

Для проведения анализа используется библиотека dice_ml — с заданными непрерывными функциями и переменной результата:

import dice_ml
from dice_ml.utils import helpers  # helper functions
d = dice_ml.Data(dataframe=nyc_tlc_df, continuous_features=['tripDistance', 'tipAmount'], outcome_name='totalAmount')

Данные разбиваются на обучающие и тестовые наборы с определенными числовыми и категориальными переменными, и создаются конвейеры предварительной обработки для числовых и категориальных данных, а затем модель обучается с использованием классификатора случайного леса. Полный код для этого доступен в репозитории DiCE.

Вот контрфактические результаты, предоставленные dice-ml:

>>> # generate counterfactuals
>>> dice_exp_random.visualize_as_dataframe(show_only_changes=True)
Query instance (original outcome : 1)
   tripDistance  tipAmount  totalAmount
0           1.8       1.85            1

Diverse Counterfactual set (new outcome: 0.0)
  tripDistance tipAmount totalAmount
0          0.4       1.0         0.0
1          0.3       1.0         0.0
Query instance (original outcome : 1)
   tripDistance  tipAmount  totalAmount
0           2.3        2.0            1

Diverse Counterfactual set (new outcome: 0.0)
  tripDistance tipAmount totalAmount
0          1.0         -         0.0
1          0.6         -         0.0
2          0.1         -         0.0
3          0.4         -         0.0

Глядя на эти контрфактические результаты, можно сделать несколько интересных выводов.

Во-первых, мы можем видеть, что когда переменная tipAmount была выше 1, переменная totalAmount также показывает значение 1, т. е. указывающее, что общая взимаемая сумма превышает 10 долларов.

При рассмотрении случаев, когда переменная результата изменяется на 0 (стоимость проезда менее 10 долларов США), мы видим, что переменная tipAmount не превышает 1,0, а расстояние поездки меньше 1,0 во всех случаях.

Это может указывать на то, что чаевые с большей вероятностью будут выплачиваться в поездках на большие расстояния. Это также может указывать на то, что чаевые вносят существенный вклад в общую сумму, полученную водителем такси за эту конкретную поездку — могут быть случаи, когда поездка может быть короткой по расстоянию, но оплата чаевых дает более высокую общую сумму, чем может быть. выплачиваться при поездках, где расстояние больше, но чаевые не выплачиваются.

Заключение

В этой статье мы рассмотрели:

  • Способы анализа и предварительной обработки больших наборов данных
  • Как использовать InterpretML для проведения регрессионного анализа
  • Разница между глобальными и локальными объяснениями в модели InterpretML
  • Использование DICE-ML для создания контрфактических объяснений и идей, которые могут быть получены с помощью этой техники.

Если вы хотите, вы также можете попробовать запустить приведенные выше модели в разные периоды времени для набора данных и посмотреть, что у вас получится. Надеюсь, вам понравилась эта статья, и вы будете признательны за любые вопросы или отзывы!

Рекомендации

Отказ от ответственности. Эта статья написана на условиях «как есть» и без каких-либо гарантий. Он был написан с целью дать обзор концепций науки о данных и не должен интерпретироваться как профессиональный совет. Выводы и интерпретации в этой статье принадлежат автору и не поддерживаются и не связаны с какой-либо третьей стороной, упомянутой в этой статье. Автор не имеет отношений с третьими лицами, упомянутыми в этой статье.