Как получить и интерпретировать объяснения предсказаний

BigQuery ML - это простой в использовании способ вызова моделей машинного обучения для структурированных данных с использованием только SQL. Хотя он начинался только с линейной регрессии, были добавлены более сложные модели, такие как Deep Neural Networks и AutoML Tables, путем соединения BigQuery ML с TensorFlow и Vertex AI в качестве бэкэнда. Другими словами, хотя мы пишем SQL, выполняется TensorFlow.

BigQuery ML перенимает больше функций из серверной части Vertex AI *. В предыдущем посте я показал вам настройку гиперпараметров. В этом посте я покажу вам объяснимость.

Что такое объяснимость?

Объяснимость - это способ понять, что делает модель машинного обучения. Есть два типа объяснимости. При локальной объяснимости мы спрашиваем модель машинного обучения, как она дала результат для индивидуального прогноза. Почему сказано, что этот прокат продлится 2000 секунд? Почему предполагается, что эта транзакция является мошеннической?

В глобальной объяснимости мы спрашиваем о важности функции. Насколько важно время суток для прогнозирования продолжительности аренды? Насколько важна сумма транзакции для прогнозирования мошенничества?

BigQuery ML поддерживает оба типа объяснений, но, судя по моему опыту, конечные пользователи хотят объяснения на местном уровне, поэтому я сосредоточусь на этом.

Регрессионная модель

Я воспользуюсь простой моделью для прогнозирования продолжительности аренды велосипеда, которую я использую в своей книге (BigQuery: The Definitive Guide):

CREATE OR REPLACE MODEL ch09eu.bicycle_linear
TRANSFORM(
  * EXCEPT(start_date),
  CAST(EXTRACT(dayofweek from start_date) AS STRING)
         as dayofweek,
  CAST(EXTRACT(hour from start_date) AS STRING)
         as hourofday
)
OPTIONS(model_type='linear_reg', input_label_cols=['duration'])
AS
SELECT
  duration
  , start_station_name
  , start_date    
FROM 
  `bigquery-public-data`.london_bicycles.cycle_hire

Объяснение прогноза

Вместо вызова ML.PREDICT мы вызываем ML.EXPLAIN_PREDICT, чтобы получить объяснения наряду с прогнозами:

SELECT * 
FROM ML.EXPLAIN_PREDICT(
    MODEL `ai-analytics-solutions.ch09eu.bicycle_linear`, 
    (
        SELECT 'Park Street, Bankside' AS start_station_name,
        CURRENT_TIMESTAMP() AS start_date
    )    
)

Результат:

Прогнозируемая продолжительность - 1929 секунд.

Модель начинается с базового значения прогноза 4598 секунд - это среднее значение прогноза по всему набору обучающих данных. Из второго ввода (метка времени) модель извлекла час дня и день недели.

Поскольку день недели - среда, модель добавляет 247 секунд к базовому значению. Другими словами, по средам люди берут напрокат велосипеды немного дольше, чем в среднем.

Но тогда мы смотрим на 4 часа ночи. А это приводит к тому, что продолжительность аренды резко снижается, на 2800 секунд!

Наконец, Parkside - это станция, где арендная плата обычно ниже. Итак, модель применяет еще одну поправку, на этот раз 116 секунд.

4598 + 247–2800–116 = 1929

Модель классификации

В качестве второго примера возьмем модель прогнозирования мошенничества:

CREATE OR REPLACE MODEL advdata.ulb_fraud_detection 
TRANSFORM(
    * EXCEPT(Amount),
    SAFE.LOG(Amount) AS log_amount
)
OPTIONS(
    INPUT_LABEL_COLS=['class'],
    AUTO_CLASS_WEIGHTS = TRUE,
    DATA_SPLIT_METHOD='seq',
    DATA_SPLIT_COL='Time',
    MODEL_TYPE='logistic_reg'
) AS
SELECT 
 *
FROM `bigquery-public-data.ml_datasets.ulb_fraud_detection`

Опять же, мы используем ML.EXPLAIN_PREDICT:

SELECT * 
FROM ML.EXPLAIN_PREDICT(
    MODEL `ai-analytics-solutions.advdata.ulb_fraud_detection`, 
    (
        SELECT *
        FROM `bigquery-public-data.ml_datasets.ulb_fraud_detection`
        WHERE Amount BETWEEN 300 AND 305
        LIMIT 1
    )    
)

Результат имеет прогнозируемый_класс мошенничества, равный 0 (отсутствие мошенничества) с вероятностью 0,946. Это соответствует значению логитов, равному -2,87 (вероятность - это сигмоид логитов).

К сожалению, этот набор данных не сообщает нам, что это за столбцы, но базовый показатель равен -2,70, что указывает на то, что в значениях данных нет ничего необычного.

Давайте попробуем другой прогноз, на этот раз выбрав транзакцию с очень высокой стоимостью:

SELECT * 
FROM ML.EXPLAIN_PREDICT(
    MODEL `ai-analytics-solutions.advdata.ulb_fraud_detection`, 
    (
        SELECT *
        FROM `bigquery-public-data.ml_datasets.ulb_fraud_detection`
        WHERE Amount > 10000
        LIMIT 1
    )    
)

Теперь модель предсказывает, что транзакция, скорее всего, будет мошеннической. Значение logits - 12,665, тогда как базовое значение - -2,7. Ключевыми факторами резкого увеличения вероятности этой транзакции являются V4 (что добавляет 10,67) и V14 (что снижает ее на 4). Конечно, это значило бы больше, если бы мы знали, что V4 и V14 были… но в общедоступном наборе финансовых данных такой информации нет.

Как видите, вы можете использовать объяснения экземпляров, чтобы объяснить, как разные входные данные вызвали изменение результирующего значения. Это может помочь укрепить уверенность в том, что модель работает правильно.

Наслаждаться!

* Объясняемость реализована наиболее эффективно для каждого типа модели. В некоторых случаях вызывается Vertex AI, а в других - в самом BigQuery. Но это деталь реализации - независимо от того, как она реализована, вы бы назвали это одинаково.