Несколько недель назад я опубликовал статью, описывающую, как можно сохранить модель XGBoost, созданную в Python, в виде файла.

Python отлично подходит для создания моделей машинного обучения, но установка тяжелых библиотек машинного обучения в системах с ограниченными ресурсами может оказаться невозможной. Вам также может понадобиться работать с существующими стеками и средами, в которых модель может быть загружена на другом языке.

Онлайн-документация XGBoost содержит инструкции по установке необходимых библиотек с помощью Maven, однако здесь гораздо проще использовать Gradle. Выполните следующие действия:

Шаг 1: Базовые установки

Убедитесь, что у вас установлены Java и IntelliJ. Мы используем версию Amazon Corretto 17.

Шаг 2. Создайте новый проект со следующей конфигурацией проекта.

  1. Язык: «Java» (дух)
  2. Система сборки: «IntelliJ»
  3. JDK: «Амазонка Корретто 17»
  4. Gradle DSL: «Котлин»

Шаг 3. Измените gradle.properties.

В корневом каталоге вашего проекта должен быть файл gradle.properties. Если его еще нет, создайте его. Мы будем использовать этот файл для указания необходимых версий нашей библиотеки. По состоянию на сентябрь 2022 года работали следующие версии:

Шаг 4. Измените build.gradle.kts

Объявите переменные scalarBinaryVersion и xgboostVersion в файле и добавьте их в раздел зависимостей. Файл должен выглядеть так:

Шаг 5a: загрузите Gradle, если требуется

Вы можете увидеть всплывающее окно в правом нижнем углу проекта с просьбой загрузить Gradle. Нажмите на нее, чтобы IntelliJ загрузил и установил Gradle. Если всплывающее окно не отображается, щелкните вкладку «Уведомления» в правом верхнем углу, чтобы найти его.

Шаг 5: Обновите Gradle

Откройте панель Gradle из верхней правой части окна IntelliJ и «Перезагрузите все проекты Gradle».

В идеале к концу этого шага на консоли IntelliJ должно быть напечатано сообщение «Сборка выполнена успешно». Это может занять несколько минут.

Шаг 7: Загрузите модель XGBoost в Java

Вызовите метод loadModel XGBoost, чтобы загрузить предварительно сохраненную модель в Java и сохранить ее в объекте типа Booster.

Создайте новый объект DMAtrix и передайте в его конструктор путь к ранее сохраненному файлу libsvm.

Прогнозы модели будут возвращены в виде 2D-матрицы. N-е значение хранится в ячейке [N][0] матрицы.

Обратитесь к коду в этом репо для справки.