Выполняйте быстрые прототипы функции потерь, чтобы в полной мере воспользоваться гибкостью XGBoost.

Мотивация

Запуск XGBoost с настраиваемыми функциями потерь может значительно повысить производительность классификации/регрессии в определенных приложениях. Возможность быстро протестировать множество различных функций потерь является ключевым моментом в исследовательских средах, критичных ко времени. Таким образом, дифференциация вручную не всегда осуществима (а иногда даже подвержена человеческим ошибкам или численной нестабильности).

Автоматическое дифференцирование позволяет нам автоматически получать производные функции с учетом ее вычисления. Он делает это, представляя нашу функцию как композицию функций с известными производными, что не требует никаких усилий со стороны разработчика.

Мы начнем с краткого введения, разъяснения нашей проблемы. Затем мы углубимся в реализацию автоматической дифференциации с помощью PyTorch и JAX и интегрируем ее с XGBoost. Наконец, мы выполним тесты во время выполнения и покажем, что JAX примерно в 10 раз быстрее, чем PyTorch для этого приложения.

Фон

Gradient Boosting — это основа алгоритмов машинного обучения. Он выводит прогнозы на основе ансамбля слабых учеников, обычно деревьев решений. Слабые ученики могут быть оптимизированы в соответствии с произвольной дифференцируемой функцией потерь, что дает нам значительную гибкость. Мы сосредоточимся на случае деревьев решений как слабых учеников — деревьях решений с градиентным усилением (GBDT).

В задачах, где нейронным сетям не хватает, например, табличных данных и небольших обучающих наборов, GBDT демонстрируют современную производительность.

XGBoost — популярная библиотека, эффективно реализующая GBDT. Он предоставляет простой интерфейс для написания пользовательских функций потерь для наших деревьев решений. Учитывая пользовательскую функцию потерь, все, что нам нужно сделать, это предоставить XGBoost расчеты ее градиента и гессиана. Давайте посмотрим, как мы можем добиться этого с помощью автоматического дифференцирования за считанные минуты.

Постановка проблемы

Мы проведем наши эксперименты с набором данных California Housing — регрессионной задачей для прогнозирования цен на жилье.

Нашей функцией потерь будет Squared Log Error (SLE):

Обратите внимание, что эта потеря наказывает заниженную оценку больше, чем завышенную оценку. Это может отражать реальное требование бизнеса при прогнозировании цен на жилье, и мы можем выполнить его, выбрав пользовательскую функцию потерь.

Давайте применим это в XGBoost.

Автоматический расчет гессиана с помощью PyTorch

Далее мы сосредоточимся на работе с PyTorch, так как он понятнее — сравнение с JAX будет позже.

Вычисление градиентов с помощью PyTorch — это знакомая рабочая нагрузка из программирования нейронных сетей. Однако нам редко требуется вычислять гессианы. К счастью, PyTorch реализовал для нас удобную функцию torch.autograd.functional.hessian. С этими техническими деталями мы можем приступить к реализации.

Во-первых, мы реализуем нашу функцию потерь:

Далее наша автоматическая дифференциация:

Соединяем их вместе:

И работает на реальных данных:

Это дает нам простую рабочую реализацию автоматического дифференцирования с помощью PyTorch. Однако этот код плохо масштабируется для больших наборов данных. Таким образом, в следующем разделе мы покажем более сложный подход, улучшающий время выполнения.

Оптимизация производительности во время выполнения с помощью JAX

Мы можем добиться значительного ускорения во время выполнения, если будем правильно использовать JAX. Давайте напишем код PyTorch сверху в JAX:

Мы меняем расчет убытков таким образом, что он использует jax.numpy (импортируется как jnp),

И используйте соответствующий синтаксис для автоматической дифференциации в JAX,

Работая с предыдущими данными, по сравнению с реализацией PyTorch (рис. 3) мы видим ускорение примерно в 2 раза:

Обратите внимание, что мы используем функцию hvp (произведение вектора Гессе) (на векторе из единиц) из JAX's Autodiff Cookbook для вычисления диагонали гессиана. Этот трюк возможен только тогда, когда гессиан диагональный (все недиагональные элементы равны нулю), что верно в нашем случае. Таким образом, мы никогда не сохраняем весь гессиан, а вычисляем его на лету, уменьшая потребление памяти.

Однако наиболее значительное ускорение связано с эффективным вычислением гессиана методом дифференцирования в прямом и обратном направлениях. Технические детали выходят за рамки этого поста, вы можете прочитать о них в JAX’s Autodiff Cookbook.

Кроме того, мы используем JIT-компиляцию JAX, чтобы еще больше сократить время выполнения, примерно в 3 раза.

Тест производительности во время выполнения

Давайте представим более тщательное сравнение производительности во время выполнения.

Отметим, что реализация PyTorch имеет квадратичную сложность времени выполнения (по количеству примеров), а реализация JAX имеет линейную сложность времени выполнения. Это огромное преимущество, которое позволяет нам использовать реализацию JAX на больших наборах данных.

Теперь давайте сравним автоматическое дифференцирование с ручным дифференцированием:

Действительно, по сравнению с этим ручное дифференцирование молниеносно (в 40 раз быстрее). Однако для сложных функций потерь или небольших наборов данных автоматическое дифференцирование все еще может быть ценным навыком в вашем наборе инструментов.

Полный код этого бенчмарка можно найти здесь:



Заключение

Мы использовали возможности автоматического дифференцирования, чтобы беспрепятственно использовать настраиваемые функции потерь в XGBoost с возможным компромиссом в производительности во время выполнения. Конечно, приведенный выше код также применим к другим популярным библиотекам повышения градиента, таким как LightGBM и CatBoost.

Мы увидели, что JAX обеспечивает существенное ускорение благодаря эффективной реализации гессиана и прозрачному использованию JIT-компиляции. Кроме того, мы наметили несколько строк кода, которые позволяют нам в общем вычислить градиенты и гессианы. А именно, наш подход можно обобщить на дополнительные рабочие нагрузки, требующие автоматической дифференциации высокого порядка.

Спасибо за прочтение! Буду рад услышать ваши мысли и комментарии 😃