В BlaBlaCar мы хотим, чтобы пользователи сосредоточились на том, чтобы делиться своими поездками, а не тратить время на ожидание загрузки страницы поиска. Для достижения этой цели команда Trip Search работает над созданием быстрого и точного поискового сервиса. Мы также экспериментируем с результатами поиска на основе машинного обучения, чтобы улучшить соответствие пользователей. Наша поисковая служба получает сотни запросов в секунду, а время ответа в среднем составляет 60 миллисекунд. Мы хотим, чтобы время ответа было как можно меньше. Однако сохранение низкого времени отклика и применение прогнозов машинного обучения для тысяч поездок для каждого из сотен поисковых запросов в секунду - сложная задача.

Готовые к производству решения для прогнозирования в реальном времени существуют как SaaS. Например, Amazon в режиме реального времени отвечает на большинство (отдельных) запросов прогнозов в течение 100 миллисекунд. Механизм машинного обучения Google имеет максимальный предел квоты - 10 000 прогнозов на каждые 100 секунд. Кроме того, у нас есть внутренний API машинного обучения, который отвечает примерно за 80 миллисекунд на предсказание набора выборок. Однако эти решения оказываются внешними по отношению к нашим производственным контейнерам. Это генерирует внешние запросы, добавляющие сетевое время, которое мы не можем сжать. Следовательно, время отклика предсказания увеличивается с размером выборки из-за процесса (де) сериализации.

Мы переходим к сервис-ориентированной архитектуре, в которой сервисы написаны на Java. Хотя экосистемы R и Python предоставляют бесчисленное количество фреймворков машинного обучения, в Java дело обстоит иначе. В BlaBlaCar люди, работающие над моделированием машинного обучения, используют R или Python, поэтому каким бы ни было наше решение, оно должно уметь читать модели из R и Python и выполнять прогнозы на Java. Хотя на виртуальной машине Java (JVM) работает все больше и больше языков, наше решение должно быть написано на Java, чтобы в полной мере использовать уже существующую экосистему Java.

Наша цель - иметь библиотеку, которую можно быстро интегрировать в наши сервисы и которая будет выполнять прогнозы на лету, поэтому запросы к сети не требуются. Конечным результатом является написанная мной библиотека Java, с которой мы экспериментируем. Он обеспечивает быстрые прогнозы машинного обучения, поскольку он предназначен для использования в вышеупомянутом сценарии, например, для поиска в реальном времени, предложений точек встречи и т. Д. Мы стремимся иметь время отклика прогноза около 10 мс для пакета выборок, например , 1000 поездок, точек встречи и т. Д. Библиотека сделана так, чтобы не зависеть от используемого механизма прогнозирования, в настоящее время поддерживающего Extreme Gradient Boosting (XGBoost).

В следующих разделах подробно описаны существующие базовые библиотеки, которые мы оценили, их плюсы и минусы, довольно простой рабочий процесс для каждой из них и тест, подходящий для нашего варианта использования. Это постоянная работа, тесты, графики и модели доступны здесь.

Обзор библиотек

R-Java

Плюсы

  • Фактический интерфейс между Java и R
  • В активном состоянии, прочный, более 10 лет разработки
  • Предварительная обработка может происходить в реальном времени в R

Минусы

  • При использовании каретки размер моделей, экспортированных из R, может вызвать ошибку OutOfMemoryError.
  • Никакое простое взаимодействие с Python не подразумевает использования такой библиотеки, как reticulate

Для каждого из примеров рабочего процесса библиотек, показанных ниже, будет «сторона R», на которой создается файл модели, и «сторона Java», на которой модель считывается и используется для прогнозирования. Конечно, модель также могла быть создана с использованием Python.

Используя R-Java, поток будет довольно простым и в то же время мощным, то есть подобрать модель, затем сохранить ее объектное представление R в файл, возможно, с возможностями предварительной обработки, заданными курсором, а затем прочитать его обратно в службу Java.

Сторона R

library(caret)
# Generate and export a model
model_fit <- glm(CLASS ~., data, family=binomial())saveRDS(model_fit, "model.R")

Сторона Java

import org.rosuda.JRI.Rengine;
private Rengine engine = Rengine.getMainEngine();
// Read model exported from the R side
String resource = getClass().getClassLoader().getResource("model.R").getFile();
engine.assign("model_file", resource);
engine.eval("library(caret)");
engine.eval("model_fit <- readRDS(model_file)");
// Perform the prediction
REXP dataFrame = engine.eval("as.data.frame(predict(model_fit, newdata = data, type=\"prob\"))");

JPMML

Плюсы

  • Современная реализация спецификации PMML
  • Предварительная обработка из коробки
  • Поддержка моделей из R, Python, Tensorflow,…
  • Интегрируется со Spark
  • Активно поддерживается, используется AirBnB

Минусы

  • В основном это проект одного человека, который создает огромную зависимость от единственного сопровождающего, хотя всегда можно войти в код и внести свой вклад.

Эта библиотека предоставляет отличное готовое решение из-за отсутствия взаимодействия между экосистемами машинного обучения на Python и R с Java. Модели могут быть экспортированы из отдельных функций R, таких как glm, в библиотеки, такие как каретка, scikit-learn Python и т. Д. Экспорт моделей из R представляет собой рабочий процесс, аналогичный, например, R-Java.

Сторона R

library(r2pmml)
# Generate and export a model
model_fit <- glm(CLASS ~., data, family=binomial())
r2pmml(model_fit, "model.pmml"))

Сторона Java

import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;

// Read model exported from the R side
InputStream resource = PMMLPredictor.class.getResourceAsStream("model.pmml");
PMML pmml = org.jpmml.model.PMMLUtil.unmarshal(resource);
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
evaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
// Set the sample's dummy data
Map<FieldName, FieldValue> samples;
// Perform the prediction
predictions = evaluator.evaluate(samples);

XGBoost Predictor

Плюсы

  • Реализация XGBoost на Java
    Поддержка моделей из R и Python
  • Интегрируется со Spark

Минусы

  • Ограничения для моделей XGBoost
  • Только прогнозирование по выборке
  • Без предварительной обработки

Эта библиотека предоставляет уникальный движок XGBoost. Хотя он не дает свободы выбора движков, таких как JPMML и R-Java, он не является препятствием, поскольку наша библиотечная архитектура была задумана так, чтобы охватить несколько существующих движков. Таким образом, если мы хотим добавить новые движки, мы можем подключить к ним другую библиотеку.

Сторона R

library(caret)
library(xgboost)
model_fit <- # E.g., standard XGBoost train with caret
bst <- xgboost:::xgb.Booster.check(model_fit$finalModel, saveraw = FALSE)
xgb.save(bst, fname = "model.xgb")

Сторона Java

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.util.FVec;
// Read model exported from the R side
InputStream resource = XGBoostPredictor.class.getResourceAsStream("model.xgb");
// Perform the (per-sample) prediction
for (int i = 0; i < numberOfSamples; i++) {
    // Fill dummy dense matrix
    double[] denseArray =      DoubleStream.generate(random::nextDouble).limit(numberOfColumns).toArray();
    double[] denseArray = DoubleStream.generate(random::nextDouble).limit(numberOfFeatures).toArray();
    featureVector = FVec.Transformer.fromArray(denseArray, false);
    double[] prediction = predictor.predict(featureVector);
 }

XGBoost4J

  • Реализация Java, обернутая вокруг реализации XGBoost на C ++
  • Активно поддерживается
  • Поддержка моделей из R и Python
  • Интегрируется со Spark

Минусы

  • Ограничения для моделей XGBoost
  • Без предварительной обработки
  • Полагается на собственную системную библиотеку

Те же замечания, что и для XGBoost Predictor. Более того, зависимость собственной библиотеки XGBoost4J может быть настоящей болевой точкой при использовании нескольких архивов. Рабочий процесс XGBoost4J такой же, как и XGBoost Predictor на стороне R, хотя сторона Java немного меняется, как показано ниже.

Сторона Java

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
private Booster predictor;
// Read model exported from the R side
InputStream resource = XGBoost4JPredictor.class.getResourceAsStream("model.xgb");
predictor = XGBoost.loadModel(resource);
float[] data = new float[numberOfFeatures * numberOfSamples];
// Fill dummy dense matrix
DMatrix matrix = new DMatrix(data, numberOfRows, numberOfColumns);
// Perform the prediction
float[][] predictions = predictor.predict(matrix);

Контрольный показатель

Я изучил 9 вариантов использования (UC), которые представляют собой комбинации фреймов данных со 100, 500 и 1000 строками (выборками) и столбцами (функциями) с использованием логистической регрессии (LR), деревьев повышения (GBM или XGBoost) для прогнозирования двух категориальные переменные класса. Эти UC - это то, что ожидается в масштабе наших сценариев использования в BlaBlaCar, например, до 1000 точек встречи для ранжирования, содержащих до 1000 функций.

Приведенные ниже результаты представлены с 95 ДИ, планки погрешностей представляют собой стандартную ошибку среднего (SEM) для 500 повторов. Для LR, GBM и XGBoost мы делаем прогнозы для этих 9 вариантов использования. В таблице ниже представлены вышеупомянутые варианты использования.

          ╔═══════════════════════════════════════════╗
          ║        Number of columns (features)       ║
          ╠═══════════════╦══════╦══════╦══════╦══════╣
          ║               ║      ║ 100  ║ 500  ║ 1000 ║ 
          ║   Number of   ║══════╬══════╬══════╬══════╣
          ║     rows      ║ 100  ║ UC 1 ║ UC 4 ║ UC 7 ║
          ║   (samples)   ║ 500  ║ UC 2 ║ UC 5 ║ UC 8 ║
          ║               ║ 1000 ║ UC 3 ║ UC 6 ║ UC 9 ║
          ╚═══════════════╩════════════════════╩══════╝

Единственная метрика, которую я исследовал, - это время прогноза. Это означает, что время для построения выборки данных не оценивалось, так как оно может сильно варьироваться в зависимости от требований. Например, составление отдельных функций, например, продолжительность поездки, связанная с ее отклонением для группы, может быть быстро вычислено, но если есть сотни производных функций, которые должны быть сгенерированы до прогнозирования, эта предварительная обработка, вероятно, займет гораздо больше время.

Поэтому, чтобы проиллюстрировать вычисление времени в одном из приведенных выше примеров Java, время прогнозирования рассчитывается следующим образом

// Fill dummy dense matrix
DMatrix matrix = new DMatrix(data, numberOfRows, numberOfColumns);
// Perform the prediction
Instant start = Instant.now();
float[][] predictions = predictor.predict(matrix);
Instant end = Instant.now();
long repeatDuration = Duration.between(start, end).toMillis();

Контрольный показатель логистической регрессии

Мы начали изучение моделирования машинного обучения в BlaBlaCar с помощью моделей логистической регрессии. Таким образом, я хотел увидеть, как каждая из библиотек ведет себя в каждой из объединенных коммуникаций. Для этого сравнения XGBoost был применен с objective = "binary:logistic”andnrounds = 1.

Приведенная выше диаграмма показывает, что R-Java и PMML менее эффективны в «более сложных» сценариях использования, то есть тех, которые содержат больше примеров и функций. Ниже у нас есть такая же диаграмма, но на этот раз с разбивкой по библиотеке, она ясно покажет, как каждый из них ведет себя.

Как проиллюстрировано выше, время прогнозирования увеличивается, когда необходимо прогнозировать больше данных на PMML и R-Java для тестируемых UC. Это не означает, что библиотеки XGBoost не следуют тому же правилу, но может оказаться, что исследованные здесь универсальные коммуникационные системы не так велики с точки зрения функций и образцов, как это необходимо, чтобы увидеть эту тенденцию и для них.

Если бы нашей целью было предсказать до сотни выборок только с помощью GLM, то есть случаев, охватываемых UC1, UC2 и UC3 (масштаб ниже), любая из библиотек была бы в порядке, около 12 миллисекунд вполне достаточно.

Чтобы завершить линейный тест, у нас есть библиотеки XGBoost, которые предсказывают намного быстрее, чем RJava и PMML. Помимо,

  • PMML очень разумно для увеличения функций и образцов
  • RJava очень чувствителен к увеличению семплов
  • Стабильное время работы XGBoost4J и XGBoost Predictor независимо от варианта использования

Тест повышения качества Tress

Улучшение деревьев не нуждается в представлении, они предоставляют мощный прогноз, как черный ящик. Ниже мы начнем с такой же диаграммы из линейного теста, чтобы получить представление о производительности для каждого объединенных коммуникаций. Для этого сравнения был применен XGBoost с objective = "binary:logistic”andnrounds = 150.

На этот раз время прогнозирования XGBoost Predictor и PMML увеличилось с увеличением количества выборок, рисунок ниже лучше иллюстрирует это с помощью фасетирования для каждой библиотеки.

Результаты показывают, что для улучшения деревьев XGBoost Predictor и PMML имели намного меньшее время прогнозирования за счет увеличения количества выборок.

  • PMML и XGBoost Predictor очень чувствительны к увеличению сэмплов. От 100 до 500 предполагалось увеличение времени прогнозирования примерно на 4,5. От 500 до 1000 предполагалось ~ 1,9 увеличение времени прогнозирования
  • Стабильное время работы XGBoost4J и RJava независимо от тестируемых вариантов использования
  • По крайней мере, для оцениваемых здесь объединенных коммуникаций XGBoost Predictor не работает так, как утверждает в своем репозитории [XGBoost Predictor] примерно в 6000–10 000 раз быстрее, чем XGBoost4J в задачах прогнозирования. На самом деле он работал хуже, чем XGBoost4J. Подобные результаты были упомянуты в этом обсуждении.

Дизайн нашей библиотеки

По мере запуска новых проектов с использованием машинного обучения люди могут не захотеть начинать с более сложных моделей повышения. Таким образом, GLM также должна поддерживаться библиотекой, поэтому здесь необходимо учитывать вышеупомянутые результаты.

XGBoost4J был выбран в качестве первой базовой библиотеки нашей внутренней библиотеки, учитывая, что скорость прогнозирования является нашей основной целью.

Хотя библиотека еще не открыта, поскольку мы интегрируем и тестируем ее внутри, она используется следующим образом

// Read model exported from the R side and create predictor
InputStream modelResource = getClass().getClassLoader().getResourceAsStream("model.XGB");
MachineLearningOperationBuilder predictor = MachineLearning.newXGB(modelResource);
// Tuple's key help to glue with the business object later and have a list of features
Tuple featTuple1 = new Tuple("trip1", Lists.newArrayList(0f, 1f, 3f, 5f));
Tuple featTuple2 = new Tuple("trip2", Lists.newArrayList(0f, 2f, 5f, 5f));
Tuple featTuple3 = new Tuple("trip3", Lists.newArrayList(0f, 3f, 7f, 5f));
// Create data frame from the given tuples
FeatureDataFrame featureDataFrame = new FeatureDataFrame(Lists.newArrayList(featTuple1, featTuple2, featTuple3));
// Perform prediction
ProbabilityDataFrame probDataFrame = predictor.with(featureDataFrame)
          .predict();

Рассматривая задачу с 3 классами, probDataFrame будет содержать что-то в строках отрывка ниже, то есть одну вероятность на класс. Затем можно создать компаратор, чтобы отсортировать их по вероятности принадлежности к определенному классу.

Tuple(id=trip1, features=[0.033174984, 0.0935384, 0.8732866])
Tuple(id=trip2, features=[0.05029368, 0.055675365, 0.894031])
Tuple(id=trip3, features=[0.12031816, 0.27852917, 0.60115266])

В настоящее время библиотека поддерживает постобработку с помощью лямбда-выражения, передаваемого методу score, который под капотом применяет predict().score(lambda), т. Е. Связывает данную лямбду, как показано ниже.

PostProcessor<ProbabilityDataFrame> multiplier = (features) -> new ProbabilityDataFrame(
        features.getTuples()
                .stream()
                .map(row -> new Tuple(row.getId(),
                        Lists.newArrayList((float) row
                                .getFeatures()
                                .stream()
                                .mapToDouble(aFloat -> aFloat)
                                .sum()))
                ).collect(Collectors.toList()));

ProbabilityDataFrame probDataFrame = machineLearningOperation.with(featureDataFrame)
        .score(multiplier);

В приведенном выше примере суммируется каждая из вероятностей для каждого класса. Хотя он довольно надуманный, он иллюстрирует, как может быть реализована постобработка путем передачи функций, определенных клиентом, что дает полную гибкость пользователю библиотеки.

// Illustrative result of the function above
Tuple(id=trip1, features=[1])
Tuple(id=trip2, features=[1])
Tuple(id=trip3, features=[1])

Некоторые результаты

Результаты, приведенные ниже, показывают затраченное время для каждого из двух основных этапов выполнения рекомендации в BlaBlaCar, прогнозирования («прогнозировать») и обработки функций («функции»), например, сопоставление категории с соответствующим порядковым номером, используемым на этапе обучения, присвоение по умолчанию атрибутов с нулевым значением из входного запроса и т. д. Чтобы дать немного больше контекста, эта модель предсказывает, какие точки встречи следует предлагать для вождения при публикации поездки. Модель имеет 7 функций.

Выводы

В BlaBlaCar мы находимся в начале трансформации, в которой все больше и больше сервисов будут применять машинное обучение для повышения качества реакции, которую оно дает пользователям. Учитывая широкое распространение R и Python в этой области, не всегда легко выбрать лучший подход к созданию взаимодействия между ними и службами Java. Кроме того, если скорость прогнозирования является важным аспектом, это становится очень сложной задачей. Целью данной статьи было изучить возможности, рассматривающие скорость прогнозирования как главную цель.

Тем не менее, в зависимости от различных потребностей, например, менее ограниченного времени отклика, другие варианты более подходят. Например, если сотни миллисекунд времени предсказания не проблема, я рекомендую использовать PMML. Он предоставляет сверхширокий выбор библиотек, из которых можно читать экспортированные модели. Если те, кто обучает модели, используют исключительно R, я бы порекомендовал R-Java или PMML. Последнее связано с тем, что встраивание кода R в базу кода Java может быстро стать беспорядочным, если только строгий дизайн инкапсуляции не разделяет их в максимально возможной степени.

Хотя результаты тестов кажутся достаточно ясными, меня заинтриговал один последний момент, и у меня нет ему объяснений. При сравнении GBM и GLM от R-Java видно, что у первого более быстрое время отклика. Это поведение не связано с R-Java. Запустив такое же сравнение в среде R replicate(500, system.time(predict(model_fit, test))[3]), где model_fit - это та же модель, что и для эталонного теста R-Java, а test - тот же тест, мы получаем те же результаты, указывающие на то, что GLM медленнее в прогнозировании. Однако я не исследовал дальше.

Вот несколько четких следующих шагов:

  • Создайте конвейер предварительной обработки, чтобы иметь возможность справляться с простейшими задачами, такими как одноразовое кодирование и т. Д.
  • Улучшите интеграцию моделей, например, упростите синхронизацию порядка функций между людьми, которые обучают модели, и теми, кто внедряет ее в сервис. Довольно сложно поддерживать синхронизацию, когда у вас есть сотни функций
  • Продолжайте тестировать другие библиотеки, такие как Treelite, и последние версии библиотек, тестируемых в этой статье.