с набором данных Sparkify от Udacity

Обзор

Этот пост представляет собой отчет о моем личном проекте с использованием Spark для прогнозирования оттока клиентов на основе подмножества макета набора данных, предоставленного Udacity, который называется набором данных Sparkify: данные веб-событий приложения, подобного Spotify.

Основная проблема заключается в том, как определить, собирается ли пользователь уйти, и мы должны быть в состоянии идентифицировать его, чтобы предотвратить его уход, например, путем проведения рекламных акций и т. д.

Ушедшие клиенты определяются тем, отменяют ли они свою регистрацию, и подход, который я использую для решения этой проблемы, заключается в построении модели машинного обучения, которая может точно предсказать пользователей, которые собираются уйти.

Поскольку важно, может ли модель предсказать отток клиентов, подходящей метрикой для ее оценки является ее точность (читайте здесь для обсуждения точности и оценки F1).

Вы также можете получить доступ к репозиторию GitHub, где размещен этот проект, по этой ссылке.

Предварительная обработка данных

Вот столбцы и их форматы, содержащиеся в наборе данных:

> df.printSchema()
root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)

Поскольку набор данных содержит веб-события, мне нужно преобразовать их, чтобы создать новый набор данных с точкой данных об одном уникальном пользователе с функциями и меткой, указывающей, уходит она или нет. Преобразование будет подробно описано в разделе «Обработка данных» ниже.

Что касается пропущенных значений, вот результаты их подсчета:

artist missing values total: 58392
auth missing values total: 0
firstName missing values total: 8346
gender missing values total: 8346
itemInSession missing values total: 0
lastName missing values total: 8346
length missing values total: 58392
level missing values total: 0
location missing values total: 8346
method missing values total: 0
page missing values total: 0
registration missing values total: 8346
sessionId missing values total: 0
song missing values total: 58392
status missing values total: 0
ts missing values total: 0
userAgent missing values total: 8346
userId missing values total: 0

Можно увидеть шаблон с числом 8346 пропущенных значений в столбцах имени пользователя, их пола, местоположения, даты регистрации и пользовательского агента: скорее всего, это люди, которые не подписали- на платформу. Их обязательно удалят из датасета, так как меня интересуют только точки данных с идентификатором: отточат ли.

Я удаляю их, ориентируясь на столбцы userId или sessionId: удаление их недопустимых значений также приведет к удалению пользователей, которые не вошли в систему.

При поиске недопустимых значений в userId и sessionId было обнаружено, что sessionId не имеет недопустимых значений, а userId содержит 8346 пустых строк в качестве значений:

> df.select('userId').dropDuplicates().orderBy('userId').show(10)
> print('userId empty string values:', df.filter(df.userId == '').count())
+------+
|userId|
+------+
|      |
|    10|
|   100|
|100001|
|100002|
|100003|
|100004|
|100005|
|100006|
|100007|
+------+
only showing top 10 rows

userId empty string values: 8346

Таким образом, их удаление решит проблему.

Исследование данных

Вот сравнение общего количества пользователей, которые уходят и нет:

Общее количество пользователей, содержащихся в наборе данных, составляет 225 человек, 52 из которых ушли. Это может быть слишком маленький набор данных, а несбалансированность может вызвать некоторые проблемы с отзывом и точностью: пользователи, которые не перегружены, составляют 3/4 набора данных.

Относительно пола:

Кажется, что соотношение женщин и мужчин значительно различается между теми, кто сбивается, и теми, кто этого не делает. Это говорит о том, что пол может быть хорошей характеристикой.

Относительно того, насколько они активны в течение одного сеанса:

Несмотря на то, что оба распределения пользователей имеют правильный перекос, те, кто уходят, имеют более высокую распространенность вокруг своего центрального значения (~ 50). Эта разница может оказаться полезной функцией для включения.

Относительно того, сколько дней с момента регистрации:

Те, кто уходит, как правило, делают это в течение первых 50 дней с момента регистрации. Таким образом, эти функции также могут быть полезны для включения в обучение.

Относительно того, заплатили ли они за премиальную услугу:

Соотношение бесплатного и платного между двумя группами, по-видимому, существенно различается: количество платящих людей, как правило, больше в группе, которая не уходит. Таким образом, эта функция будет включена в обучение.

Относительно средней продолжительности песни, воспроизводимой за день, неделю или месяц:

Эти функции также могут оказаться полезными для процесса обучения, поскольку их распределение кажется различным в центральных значениях или в преобладании определенных значений.

Относительно среднего количества песен, воспроизводимых за день, неделю или месяц:

Как и в случае средней продолжительности песен, воспроизводимых в указанном выше интервале, среднее количество песен также может быть хорошей характеристикой, учитывая разницу в плотности или центральном значении между распределениями по группам.

Относительно среднего количества сеансов в день, неделю или месяц:

Разница между групповыми распределениями также предполагает включение этих особенностей в обучение. Также интересно отметить, что распределение людей, которые уходят, обычно имеет более толстый верхний хвост.

Это может свидетельствовать о том, что они, как правило, исследуют платформу больше, чем те, кто не убегает, даже если они играют меньше песен: они могут чаще посещать страницу справки, получать больше ошибок, получать больше рекламы и так далее. Это было бы интересной проблемой для расширения в будущем.

Разработка функций

Прежде всего, мне нужно создать набор данных с каждым уникальным пользователем в качестве точки данных, а столбцы будут функциями и метками. Сначала я объединил каждого пользователя с функцией isChurn:

> ischurn = udf(lambda page: 1 if page == 'Cancellation Confirmation' else 0, IntegerType())
> df_user_churn = df.withColumn('churn', ischurn(df.page)).groupBy('userId').agg(sum('churn').alias('isChurn'))

Кроме того, мне нужно извлечь поведение пользователей, записанное в наборе данных, которое можно извлечь, манипулируя столбцом page. Столбец page включает следующие значения:

> df.groupBy('page').count().orderBy('page').show()
+--------------------+------+
|                page| count|
+--------------------+------+
|               About|   495|
|          Add Friend|  4277|
|     Add to Playlist|  6526|
|              Cancel|    52|
|Cancellation Conf...|    52|
|           Downgrade|  2055|
|               Error|   252|
|                Help|  1454|
|                Home| 10082|
|              Logout|  3226|
|            NextSong|228108|
|         Roll Advert|  3933|
|       Save Settings|   310|
|            Settings|  1514|
|    Submit Downgrade|    63|
|      Submit Upgrade|   159|
|         Thumbs Down|  2546|
|           Thumbs Up| 12551|
|             Upgrade|   499|
+--------------------+------+

Все они, кроме значений NextSong и Cancellation Confirmation, будут извлечены, поскольку NextSong будет отражаться в среднем количестве песен, воспроизводимых функцией интервала, а Cancellation Confirmation будет меткой.

Я преобразовал их в поведенческие характеристики, суммировав их для каждого пользователя с помощью следующего кода:

> pd_page_names = df.select('page').distinct().toPandas()
> page_names = pd_page_names.page.tolist()
> page_names.remove('Cancellation Confirmation')
> page_names.remove('NextSong')
>
> for page in page_names:
>     df_page = df.groupBy(['userId', 'page']).agg(count(col('page')).alias('sumPage' + page.replace(' ', ''))).filter('page == "' + page + '"').drop('page')
>
>     df_transformed = df_transformed.join(df_page, ['userId'], 'left').fillna(0)
>
> df_transformed.head()
Row(userId='100010', isChurn=0, sumPageCancel=0, sumPageSubmitDowngrade=0, sumPageThumbsDown=5, sumPageHome=11, sumPageDowngrade=0, sumPageRollAdvert=52, sumPageLogout=5, sumPageSaveSettings=0, sumPageAbout=1, sumPageSettings=0, sumPageAddtoPlaylist=7, sumPageAddFriend=4, sumPageThumbsUp=17, sumPageHelp=2, sumPageUpgrade=2, sumPageError=0, sumPageSubmitUpgrade=0)

И, с учетом функций, рассмотренных в разделе «Исследование данных» выше, общие разработанные функции следующие:

> df_transformed.printSchema()
root
 |-- avgItemsPerSession: double (nullable = true)
 |-- avgLengthPerDay: double (nullable = true)
 |-- avgLengthPerMonth: double (nullable = true)
 |-- avgLengthPerWeek: double (nullable = true)
 |-- avgSessionsPerDay: double (nullable = true)
 |-- avgSessionsPerMonth: double (nullable = true)
 |-- avgSessionsPerWeek: double (nullable = true)
 |-- avgSongsPerDay: double (nullable = true)
 |-- avgSongsPerMonth: double (nullable = true)
 |-- avgSongsPerWeek: double (nullable = true)
 |-- gender: long (nullable = true)
 |-- havePaid: long (nullable = true)
 |-- isChurn: long (nullable = true)
 |-- sumPageAbout: long (nullable = true)
 |-- sumPageAddFriend: long (nullable = true)
 |-- sumPageAddtoPlaylist: long (nullable = true)
 |-- sumPageCancel: long (nullable = true)
 |-- sumPageDowngrade: long (nullable = true)
 |-- sumPageError: long (nullable = true)
 |-- sumPageHelp: long (nullable = true)
 |-- sumPageHome: long (nullable = true)
 |-- sumPageLogout: long (nullable = true)
 |-- sumPageRollAdvert: long (nullable = true)
 |-- sumPageSaveSettings: long (nullable = true)
 |-- sumPageSettings: long (nullable = true)
 |-- sumPageSubmitDowngrade: long (nullable = true)
 |-- sumPageSubmitUpgrade: long (nullable = true)
 |-- sumPageThumbsDown: long (nullable = true)
 |-- sumPageThumbsUp: long (nullable = true)
 |-- sumPageUpgrade: long (nullable = true)
 |-- userId: string (nullable = true)

Что касается столбцов location и userAgent, они могут оказаться полезными в качестве функций. Однако, поскольку набор данных содержит 40 различных состояний и 56 различных пользовательских агентов, включение каждого из них в качестве функций горячего кодирования приведет к тому, что количество функций будет более чем вдвое меньше количества точек данных (225). Таким образом, эти функции в настоящее время игнорируются и зарезервированы для обработки полного набора данных в будущем.

Выполнение

С преобразованными данными, представленными выше, я собрал и масштабировал их, чтобы они были вектором в столбце features, и извлек isChurn как label. Вот строка и распечатанная схема набора данных, которые будут переданы в машинное обучение:

> df_ml.head()
Row(label=0, features=DenseVector([0.0005, 0.1022, 0.9711, 0.2158, 0.0, 0.0001, 0.0, 0.0004, 0.004, 0.0009, 0.0, 0.0, 0.0, 0.0002, 0.0002, 0.0, 0.0001, 0.0, 0.0001, 0.0006, 0.0002, 0.0007, 0.0, 0.0001, 0.0, 0.0, 0.0001, 0.0003, 0.0001]))
> df_ml.printSchema()
root
 |-- label: long (nullable = true)
 |-- features: vector (nullable = true)

Кроме того, я разделил 20% данных для тестовой группы, чтобы измерить точность каждой модели машинного обучения, обученной оставшимися 80%.

Подход, который я использовал при поиске лучшего алгоритма и параметров машинного обучения, заключался в выполнении перекрестной проверки с 3-кратным повторением следующих алгоритмов:

  • Логистическая регрессия

regParam=[0.0, 0.1, 0.2] и maxIter=[25, 50, 100]

  • Случайный лес

numTrees=[10, 20, 30] и maxDepth=[2, 5, 10]

  • Дерево с градиентным усилением

maxIter=[10, 20, 30] и maxDepth=[2, 5, 10]

  • Древо решений

impurity=['gini', 'entropy'] и maxDepth=[2, 5]

  • Наивный Байес

smoothing=[0.0, 0.5, 1.0]

  • Линейная машина опорных векторов

maxIter=[10, 20, 30] и regParam=[0.0, 0.1, 0.2]

Каждый из них будет оцениваться путем расчета точности предсказанной классификации набора тестовых данных.

Оценка и проверка модели

  • Логистическая регрессия

Результат перекрестной проверки логистической регрессии с 3 раза:

> lr_model.avgMetrics
[0.9181782116564725,
 0.921785079393775,
 0.921785079393775,
 0.9095974856844422,
 0.908388859475816,
 0.908388859475816,
 0.9002231447883622,
 0.8985575398618877,
 0.8985575398618877]

Таким образом, лучшими параметрами являются regParam=0.0 и maxIter=50 или 100. Результат точности следующий:

> lr_result = lr_model.transform(test)
> acc = lr_result.filter(lr_result.label == lr_result.prediction).count() / lr_result.count()
>
> print('Logistic Regression Accuracy:', acc)
Logistic Regression Accuracy: 0.9615384615384616
  • Случайный лес

Результат перекрестной проверки случайного леса с 3-кратным:

> rf_model.avgMetrics
[1.0, 1.0, 1.0, 0.9855072463768115, 1.0, 1.0, 1.0, 1.0, 1.0]

Таким образом, лучшими параметрами с наименьшим количеством деревьев и глубиной являются numTrees=10 и maxDepth=2. Результат точности следующий:

> rf_result = rf_model.transform(test)
> acc = rf_result.filter(rf_result.label == rf_result.prediction).count() / rf_result.count()
>
> print('Random Forest Accuracy:', acc)
Random Forest Accuracy: 1.0
  • Дерево с градиентным усилением

Результат перекрестной проверки дерева с градиентным усилением в 3 раза:

> gbt_model.avgMetrics
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

Таким образом, лучшими параметрами с наименьшим числом итераций и глубиной являются maxIter=10 и maxDepth=2. Результат точности следующий:

> gbt_result = gbt_model.transform(test)
> acc = gbt_result.filter(gbt_result.label == gbt_result.prediction).count() / gbt_result.count()
>
> print('Gradient-boosted Tree Accuracy:', acc)
Gradient-boosted Tree Accuracy: 1.0
  • Древо решений

Результат перекрестной проверки дерева решений с 3 раза:

> dt_model.avgMetrics
[1.0, 1.0, 1.0, 1.0]

Таким образом, лучшими параметрами с наименьшим числом глубины являются maxDepth=2 и impurity='gini' или entropy. Результат точности следующий:

> dt_result = dt_model.transform(test)
> acc = dt_result.filter(dt_result.label == dt_result.prediction).count() / dt_result.count()
> 
> print('Gradient-boosted Tree Accuracy:', acc)
Gradient-boosted Tree Accuracy: 1.0
  • Наивный Байес

Результат наивной байесовской перекрестной проверки с 3-кратным:

> nb_model.avgMetrics
[0.40776921211703815, 0.4048949809819375, 0.4041442302311867]

Таким образом, лучший параметр — smoothing=0.0. Результат точности следующий:

> nb_result = nb_model.transform(test)
> acc = nb_result.filter(nb_result.label == nb_result.prediction).count() / nb_result.count()
> 
> print('Naive Bayes Accuracy:', acc)
Naive Bayes Accuracy: 1.0
  • Линейная машина опорных векторов

Результат перекрестной проверки машины линейных опорных векторов с 3-кратным:

> svm_model.avgMetrics
[0.8095160975595759,
 0.7697038892691067,
 0.8004574091530613,
 0.9855072463768115,
 0.9761555033294163,
 0.9644046220133176,
 0.9879227053140096,
 0.985213474343909,
 0.9716509988249118]

Таким образом, лучшие параметры maxIter=30 и regParam=0.0. Результат точности следующий:

> svm_result = svm_model.transform(test)
> acc = svm_result.filter(svm_result.label == svm_result.prediction).count() / svm_result.count()
>
> print('Linear SVM Accuracy:', acc)
Linear SVM Accuracy: 1.0

Конечный результат

Оценка точности 100% часто показывается из приведенных выше результатов. Основываясь на этой ветке StackExchange, это может указывать на несколько проблем:

  • Произошла утечка данных: столбец с метками добавлен как функции
  • Данные теста дублируются из данных поезда.
  • Данные теста слишком малы
  • Проверка не правильная
  • Для прогнозирования оттока в этом наборе данных не требуется машинное обучение.

После нескольких проведенных тестов у меня осталось два варианта: либо данные слишком малы, либо мне не нужно машинное обучение для прогнозирования оттока в этом наборе данных.

Тем не менее было показано, что логистическая регрессия и наивный байесовский алгоритм дают более низкую точность обучения или теста. Остальные части алгоритма, насколько можно судить по результатам, так же хороши, хотя SVM показывает лучшее обобщение: 100% точность теста с точностью обучения 98,79%.

Дальнейшее улучшение

Было показано, что вывод о лучшем алгоритме машинного обучения для прогнозирования оттока в этом наборе данных все еще остается неопределенным. Дальнейшая работа над этим проектом необходима для получения окончательной модели машинного обучения и функций, которые лучше всего предсказывают отток.

Вот несколько будущих задач по улучшению проекта:

  • Проведите анализ полного набора данных, чтобы устранить проблему слишком малого набора данных.
  • Выполняйте анализ на AWS, а не на локальном компьютере, чтобы использовать возможности Spark в распределенном программировании.
  • Обработка дополнительных функций: location и userAgent
  • Найдите лучший алгоритм машинного обучения на полном наборе данных
  • Уменьшите размерность, чтобы уменьшить эффект мультиколлинеарности.