Изучение несбалансированной классификации с помощью Tidymodels

Представьте, что вы специалист по данным в крупном многонациональном банке, и директор по работе с клиентами обращается к вам с просьбой разработать средство прогнозирования оттока клиентов. Вы создаете набор данных моментального снимка 10 000 клиентов с дисбалансом классов 1:4 в пользу клиентов, которые не уходят, чтобы использовать такую ​​модель бинарной классификации. Чтобы помочь в разработке модели, вы решаете исследовать различные методы выборки, которые могут помочь с дисбалансом классов.

В этой серии статей, состоящей из нескольких частей, будут рассмотрены следующие темы процесса разработки и объяснения модели с упором на преобразование результатов модели в бизнес-результаты и информирование заинтересованных сторон старшего уровня.

Конечная цель состоит в том, чтобы создать модель, которая позволит банку ориентироваться на текущих клиентов, которых можно классифицировать как отток клиентов, и применять определенные меры для предотвращения этого оттока. Вмешательства имеют свою цену, поэтому мы будем стремиться сбалансировать частоту ложноотрицательных результатов с частотой ложноположительных результатов. Мы разработаем функцию стоимости и пороговый анализ в Части 2 с пакетом вероятно.

Часть 3 будет посвящена пониманию переменного пространства, в котором преобладает отток, и позволит нам понять на локальном и глобальном уровнях ведущие факторы, которые приводят к оттоку клиентов, с помощью пакетов XAI DALEX/DALEXtra.

В части 4 мы воспользуемся другим подходом и применим методы анализа выживаемости к этому набору данных, поскольку он подвергается правильной цензуре и показывает время до события (должность) и результат с использованием недавно опубликованных функций анализа выживаемости в tidymodels.

Загрузить пакеты

library(tidymodels)
library(themis) #Recipe functions to deal with class imbalances
library(tidyposterior) #Bayesian Resampling Comparisons
library(baguette) #Bagging Model Specifications
library(corrr) #Correlation Plots
library(readr) #Read .csv Files
library(magrittr) #Pipe Operators
library(stringr) #String Manipulation
library(forcats) #Handling Factors
library(skimr) #Quick Statistical EDA
library(patchwork) #Create ggplot Patchworks
library(GGally) #Pair Plots
options(yardstick.event_first = FALSE) #Evaluate second factor level as factor of interest for yardstick metrics

Загрузить данные

Данные берутся с https://www.kaggle.com/shivan118/churn-modeling-dataset (Лицензия CC0: Public Domain) и загружаются, как показано ниже.

train <- read_csv("Churn_Modelling.csv") %>% 
  select(-c(Surname, RowNumber, CustomerId))

Исследовательский анализ данных

Мне нравится, что пакет skimr быстро предоставляет сводку всех переменных набора данных.

skim(train)

Наша целевая переменная Exited имеет приблизительное соотношение 4:1 между двумя возможными результатами, где Exited = 1 относится к оттоку клиентов. Чтобы визуализировать это, мы берем пользовательскую функцию ниже.

viz_by_dtype <- function (x,y) {
  title <- str_replace_all(y,"_"," ") %>% 
           str_to_title()
  if ("factor" %in% class(x)) {
    ggplot(train, aes(x, fill = x)) +
      geom_bar() +
      theme_minimal() +
      theme(legend.position = "none",
            axis.text.x = element_text(angle = 45, hjust = 1),
            axis.text = element_text(size = 8)) +
      scale_fill_viridis_d()+
      labs(title = title, y = "", x = "")
  }
  else if ("numeric" %in% class(x)) {
    ggplot(train, aes(x)) +
      geom_histogram()  +
      theme_minimal() +
      theme(legend.position = "none") +
      scale_fill_viridis_d() +
      labs(title = title, y = "", x = "")
  } 
  else if ("integer" %in% class(x)) {
    ggplot(train, aes(x)) +
      geom_histogram() +
      theme_minimal() +
      theme(legend.position = "none") +
      scale_fill_viridis_d()+
      labs(title = title, y = "", x = "")
  }
  else if ("character" %in% class(x)) {
    ggplot(train, aes(x, fill = x)) +
      geom_bar() +
      theme_minimal() +
      scale_fill_viridis_d() +
      theme(legend.position = "none",
            axis.text.x = element_text(angle = 45, hjust = 1),
            axis.text = element_text(size = 8)) +
      labs(title = title, y  ="", x= "")
  }
}
variable_list <- colnames(train) %>% as.list()
variable_plot <- map2(train, variable_list, viz_by_dtype) %>%
  wrap_plots(               
    ncol = 3,
    heights = 150,
    widths = 150)
ggsave("eda.png", dpi = 600)

Из вышеизложенного мы лучше понимаем распределение непрерывных переменных и количество дискретных переменных.

  • Кредитный рейтинг примерно нормально распределен
  • География разделена на три страны, с преобладанием Франции.
  • Пол разделен почти поровну
  • Возраст примерно смещен вправо, нормально распределен
  • Срок владения не имеет явного распределения, при этом основная часть клиентов остается на срок от 2 до 9 лет.
  • Баланс нормально распределяется при большом количестве клиентов с нулевым балансом
  • У большинства клиентов либо 1, либо 2 продукта.
  • Есть кредитная карта указывает, что 70% клиентов имеют кредитную карту
  • Is Active Member показывает, что 51,5% клиентов являются активными пользователями.
  • Предполагаемая заработная плата не имеет явного распределения

Двумерный числовой анализ

Теперь мы попытаемся понять, есть ли какая-либо связь между числовыми переменными, используя GGally::ggpairs().

ggpairs(train %>% 
          select(-c(HasCrCard,IsActiveMember,NumOfProducts, Gender, Geography)) %>% 
          drop_na() %>% 
          mutate(Exited = if_else(Exited == 1, "Y","N")), ggplot2::aes(color = Exited, alpha = 0.3)) + 
  scale_fill_viridis_d(end = 0.8, aesthetics = c("color", "fill")) + 
  theme_minimal() +
  labs(title = "Numeric Bivariate Analysis of Customer Churn Data")

Единственное, что стоит отметить в приведенном выше, это смещение вправо клиентов, которые ушли, что указывает на то, что у пожилых клиентов может быть больше шансов уйти.

Анализ категориальных переменных

Следующий шаг — установить, существуют ли какие-либо отношения между категориальными переменными и целью.

Ниже описано создание сводного фрейма данных, который вычисляет среднее значение и 95% доверительный интервал для каждой категориальной переменной и целевой переменной.

train %>% 
  mutate(Exited = if_else(Exited == 1, "Y", "N"),
         HasCrCard = if_else(HasCrCard == 1, "Y", "N"),
         IsActiveMember = if_else(IsActiveMember == 1, "Y", "N"),
         NumOfProducts = as.character(NumOfProducts)) %>% 
  select(Exited,where(is.character)) %>% 
  drop_na() %>% 
  mutate(Exited = if_else(Exited == "Y",1,0)) %>% 
  pivot_longer(2:6, names_to = "Variables", values_to = "Values") %>% 
  group_by(Variables, Values) %>% 
    summarise(mean = mean(Exited),
              conf_int = 1.96*sd(Exited)/sqrt(n())) %>% 
  ggplot(aes(x=Values, y=mean, color=Values)) +
    geom_point() +
    geom_errorbar(aes(ymin = mean - conf_int, ymax = mean + conf_int), width = 0.1) +
    theme_minimal() +
    theme(legend.position = "none",
        axis.title.x = element_blank(),
        axis.title.y = element_blank()) +
   scale_color_viridis_d(aesthetics = c("color", "fill"), end = 0.8) +
   facet_wrap(~Variables, scales = 'free') +
   labs(title = 'Categorical Variable Analysis', subtitle = 'With 95% Confidence Intervals')

Следует отметить, что мы видим, что пол, неактивное членство, количество продуктов и география демонстрируют значительную различную склонность к оттоку. И наоборот, наличие у клиента кредитной карты не оказывает существенного влияния на вероятность оттока. К этому следует относиться с небольшой долей скептицизма, учитывая дисбаланс классов.

Разработка модели

Разделение данных — rsample

Используя rsample::initial_split(), мы указываем разделение обучающих данных 3:1.

set.seed(246)
cust_split <- initial_split(train, prop = 0.75, strata = Exited)

Характеристики модели — Парнсип и багет

Мы собираемся ограничить диапазон спецификаций модели теми, которые представляют собой древовидные модели возрастающей сложности с использованием пакетов пастернака и багета (для bag_trees). Каждая модель указывает, что их соответствующие гиперпараметры настроены на настройку () для проверки на следующем шаге.

dt_cust <- 
decision_tree(cost_complexity = tune(), tree_depth = tune(), min_n = tune()) %>% 
  set_engine("rpart") %>% 
  set_mode("classification")
rf_cust <- 
rand_forest(mtry = tune(), trees = tune(), min_n = tune()) %>% 
  set_engine("ranger", importance = "impurity") %>% 
  set_mode("classification")
xgboost_cust <- 
boost_tree(mtry = tune(), trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), loss_reduction = tune(), sample_size = tune())  %>% 
  set_engine("xgboost") %>% 
  set_mode("classification")
bagged_cust <- 
bag_tree(cost_complexity = tune(), tree_depth = tune(), min_n = tune()) %>% 
  set_engine("rpart") %>% 
  set_mode("classification")

Разработка функций — рецепты

Далее мы укажем этапы разработки функций с помощью пакета recipes. Мы будем использовать этот этап для разработки 10 рецептов, каждый из которых имеет разные методы выборки для обработки дисбаланса классов. Пакет themis содержит пошаговые рецепты для облегчения различных методов отбора проб.

  • SMOTE — синтетический метод передискретизации меньшинства
  • ROSE — Метод случайной передискретизации
  • BSMOTE — пограничный синтетический метод передискретизации меньшинства
  • UPSAMPLING — Добавьте повторяющиеся данные класса меньшинства к заданному соотношению с классом большинства
  • ADASYN — адаптивная синтетическая передискретизация
  • TOMEK — удалить ссылки TOMEK в мажоритарном классе
  • NEARMISS — удаление экземпляров класса большинства путем недостаточной выборки
  • NOSAMPLING — без процедуры отбора проб
  • SMOTE-DOWNSAMPLING — создание синтетических экземпляров меньшинства и удаление большинства экземпляров.
  • ROSE-DOWNSAMPLING — Передискретизируйте экземпляры меньшинства и понизьте выборку большинства
recipe_template <-
    recipe(Exited ~., data = training(cust_split)) %>% 
    step_integer(HasCrCard, IsActiveMember, zero_based = T) %>% 
    step_integer(NumOfProducts) %>% 
    step_mutate(SalaryBalanceRatio = EstimatedSalary/Balance,
              CreditScoreAgeRatio = CreditScore/Age,
              TenureAgeRatio = Tenure/Age,
              SalaryBalanceRatio = if_else(is.infinite(SalaryBalanceRatio),0,SalaryBalanceRatio)) %>% 
    step_scale(all_numeric_predictors(), -c(HasCrCard, Age, IsActiveMember, NumOfProducts)) %>% 
    step_dummy(all_nominal_predictors()) %>% 
    step_samplingmethod(Exited) #Change or Add Sampling Steps Here as Necessary

Выше мы разработали дополнительные функции, взяв частные из различных непрерывных переменных.

График корреляции — corrr

Мы визуализируем график корреляции всего обученного набора данных, используя рецепт без применения методов выборки.

cust_train <- recipe_8 %>% prep() %>% bake(new_data = NULL)
cust_test <- recipe_8 %>% prep() %>% bake(testing(cust_split))
cust_train %>% 
  bind_rows(cust_test) %>% 
  mutate(Exited = as.numeric(Exited)) %>% 
  correlate() %>%
  rplot(print_cor = T, .order = "alphabet") +
    scale_color_gradient2(low = 'orange', high = 'light blue') + 
    theme(axis.text.x = element_text(angle = 90)) +
    labs(title = "Correlation Plot for Trained Dataset")

Ничего удивительного в приведенных выше отрицательных корреляциях между возрастом и коэффициентами производных, которые мы сгенерировали. Exited имеет положительную корреляцию с возрастом, небольшую отрицательную корреляцию с CreditScoreAgeRatio и небольшую положительную корреляцию с SalaryBalanceRatio.

Карта рабочего процесса — наборы рабочих процессов

Используя пакет workflowsets, мы можем создать список из 40 рабочих процессов между комбинациями пастернака и рецептов И просмотреть 20 комбинаций гиперпараметров для каждого рабочего процесса. Мы создаем 5-кратный объект перекрестной проверки, используя rsample::vfold_cv(), используя слои kwarg для Exited, так что каждая складка имеет постоянное соотношение уровней целевой переменной.

recipe_list <- 
list(SMOTE = recipe_1, ROSE = recipe_2, BSMOTE = recipe_3, UPSAMPLE = recipe_4, ADASYN = recipe_5, TOMEK=recipe_6, NEARMISS = recipe_7, NOSAMPLING = recipe_8, SMOTEDOWNSAMPLE= recipe_9, ROSEDOWNSAMPLE = recipe_10)
model_list <- 
list(Decision_Tree = dt_cust, Boosted_Trees = xgboost_cust, Random_Forest = rf_cust, Bagged_Trees = bagged_cust)
wf_set <- 
workflow_set(preproc = recipe_list, models = model_list, cross = T)
set.seed(246)
train_resamples <- 
vfold_cv(training(cust_split), v = 5, strata = Exited)
class_metric <- metric_set(accuracy, f_meas, j_index, kap, precision, sensitivity, specificity, roc_auc, mcc, pr_auc)
doParallel::registerDoParallel(cores = 12)
wf_sample_exp <- 
  wf_set %>% 
  workflow_map(resamples = train_resamples, 
               verbose = TRUE, 
               metrics = class_metric, 
               seed = 246)

Мы используем функцию parsnip::metric_set() для создания пользовательского набора метрик для оценки. Эти пользовательские метрики передаются вместе с объектом wf_set в workflow_map для просмотра всех 40 рабочих процессов и вывода всех расчетов метрик для каждого рабочего процесса. Полученный объект workflow_set wf_sample_exp теперь можно проанализировать и использовать для сравнения моделей.

Это довольно сложно с вычислительной точки зрения, поэтому я рекомендую запустить все доступные ядра, чтобы облегчить это.

collect_metrics(wf_sample_exp) %>% 
  separate(wflow_id, into = c("Recipe", "Model_Type"), sep = "_", remove = F, extra = "merge") %>% 
  group_by(.metric) %>% 
  select(-.config) %>% 
  distinct() %>%
  group_by(.metric, wflow_id) %>% 
  filter(mean == max(mean)) %>% 
  group_by(.metric) %>% 
  mutate(Workflow_Rank =  row_number(-mean),
         .metric = str_to_upper(.metric)) %>% 
  arrange(Workflow_Rank) %>% 
  ggplot(aes(x=Workflow_Rank, y = mean, color = Model_Type)) +
    geom_point(aes(shape = Recipe)) +
    scale_shape_manual(values = 1:n_distinct(recipe_list)) +
    geom_errorbar(aes(ymin = mean-std_err, ymax = mean+std_err)) +
    theme_minimal() +
    scale_color_viridis_d() +
    labs(title = "Performance Comparison of Workflows", x = "Workflow Rank", y="Error Metric", color = "Model Types", shape = "Recipes") +
    facet_wrap(~.metric,scales = 'free_y',ncol = 4)

Цель вышеизложенного — продемонстрировать лишь пример метрик классификации, на которые можно посмотреть с несбалансированным набором данных. Нам нужна модель, которая в достаточной степени различает классы, но, учитывая нашу конкретную проблему, нам нужно свести к минимуму ложноотрицательные результаты (клиенты, которые уходят, но прогнозируется иначе). В части 2 обсуждается функция затрат для поиска баланса между стоимостью вмешательства и классовой дифференциацией.

Учитывая дисбаланс классов, ROC AUC и точность не являются подходящими метриками. Мы учитываем AUC Precision-Recall, KAP, J-индекс, коэффициент корреляции Мэтьюса и специфичность. Имея это в виду, UPSAMPLE_Boosted_Trees является сильным кандидатом с хорошими результатами по всем этим показателям.

Сравнение байесовских моделей — tidyposterior

Сосредоточившись на J-индексе, мы можем сравнить апостериорные распределения повторной выборки, используя tidyposterior. tidyposterior::perf_mod() берет объект wf_sample_exp, который содержит результаты из workflow_map, выполняет байесовское сравнение повторных выборок и генерирует апостериорные распределения интересующего метрического среднего. N.B объект workflow_set ДОЛЖЕН иметь рассчитанную целевую метрику, иначе это не сработает.

jindex_model_eval <- 
  perf_mod(wf_sample_exp, metric = "j_index", iter = 5000)
jindex_model_eval %>% 
  tidy() %>% 
  mutate(model = fct_inorder(model)) %>% 
  separate(model, into = c("Recipe", "Model_Type"), sep = "_", remove = F, extra = "merge") %>% 
  ggplot(aes(x=posterior, fill = Model_Type)) +
    geom_density(aes(alpha = 0.7)) +
    theme_minimal() +
    scale_fill_viridis_d(end = 0.8) +
    facet_wrap(~Recipe, nrow = 10) +
    labs(title = "Comparison of Posterior Distributions of Model Recipe Combinations", 
       x = expression(paste("Posterior for Mean J Index")), 
       y = "")

В приведенном выше много чего можно раскрыть, но вкратце это демонстрирует влияние процедуры выборки на обработку дисбаланса классов и влияние на J-Index. Это может быть выполнено для любой метрики, и я думаю, что это полезное упражнение для изучения влияния различных процедур повышения и понижения дискретизации. К процедуре выборки следует относиться как к гиперпараметру, и идеальный вариант будет отличаться в зависимости от характера набора данных, разработки функций и интересующих показателей. Мы отмечаем, что бустинг-деревья в целом работают хорошо, наряду с методами дерева в мешках, все зависит от процедуры выборки.

Заключительные замечания

Мы продемонстрировали процесс разработки модели для бинарного классификатора с несбалансированным набором данных 4:1. В нашем лучшем рабочем процессе используется модель XGBoost с процедурой повышения частоты дискретизации для выравнивания соотношения классов. Во второй части мы подгоним модель и завершим анализ порога принятия решения с использованием пакета вероятностей и разработаем два сценария — либо изменим порог принятия решения, чтобы минимизировать стоимость оттока и вмешательства, либо позволим лучше дифференцировать классы. Часть 2 будет в значительной степени сосредоточена на передаче результатов модели заинтересованным сторонам бизнеса и позволит им принять обоснованное решение в отношении затрат и риска оттока клиентов.

Спасибо за чтение и, пожалуйста, следите за последующими частями. Надеюсь, вам понравилось читать это так же, как мне понравилось писать. Если вы не являетесь участником Medium — используйте мою реферальную ссылку ниже и получайте регулярные обновления о новых публикациях от меня и других замечательных авторов Medium.