Получить AUC данных обучения из адаптированного рабочего процесса в Tidymodels?

Я борюсь с тем, как получить AUC из модели логистической регрессии с использованием tidymodels.

Вот пример использования встроенного набора данных mpg.

library(tidymodels)
library(tidyverse)

# Use mpg dataset
df <- mpg

# Create an indicator variable for class="suv"
df$is_suv <- as.factor(df$class == "suv")

# Create the split object
df_split <- initial_split(df, prop=1/2)

# Create the training and testing sets
df_train <- training(df_split)
df_test <- testing(df_split)

# Create workflow
rec <-
  recipe(is_suv ~ cty + hwy + cyl, data=df_train)

glm_spec <-
  logistic_reg() %>%
  set_engine(engine = "glm")

glm_wflow <- 
  workflow() %>%
  add_recipe(rec) %>%
  add_model(glm_spec)

# Fit the model
model1 <- fit(glm_wflow, df_train)

# Attach predictions to training dataset
training_results <- bind_cols(df_train, predict(model1, df_train))

# Calculate accuracy
accuracy(training_results, truth = is_suv, estimate = .pred_class)

# Calculate AUC??
roc_auc(training_results, truth = is_suv, estimate = .pred_class)

Последняя строка возвращает эту ошибку:

> roc_auc(training_results, truth = is_suv, estimate = .pred_class)
Error in metric_summarizer(metric_nm = "roc_auc", metric_fn = roc_auc_vec,  : 
  formal argument "estimate" matched by multiple actual arguments

person max    schedule 12.04.2021    source источник


Ответы (1)


Поскольку вы выполняете двоичную классификацию, roc_auc() ожидает вектор вероятностей класса, соответствующий соответствующему классу, а не предсказанному классу.

Вы можете получить это с помощью predict(model1, df_train, type = "prob"). В качестве альтернативы, если вы используете рабочие процессы версии 0.2.2 или новее, вы можете использовать augment() для получения прогнозов и вероятностей классов без использования bind_cols().

library(tidymodels)
library(tidyverse)

# Use mpg dataset
df <- mpg

# Create an indicator variable for class="suv"
df$is_suv <- as.factor(df$class == "suv")

# Create the split object
df_split <- initial_split(df, prop=1/2)

# Create the training and testing sets
df_train <- training(df_split)
df_test <- testing(df_split)

# Create workflow
rec <-
  recipe(is_suv ~ cty + hwy + cyl, data=df_train)

glm_spec <-
  logistic_reg() %>%
  set_engine(engine = "glm")

glm_wflow <- 
  workflow() %>%
  add_recipe(rec) %>%
  add_model(glm_spec)

# Fit the model
model1 <- fit(glm_wflow, df_train)

# Attach predictions to training dataset
training_results <- augment(model1, df_train)

# Calculate accuracy
accuracy(training_results, truth = is_suv, estimate = .pred_class)
#> # A tibble: 1 x 3
#>   .metric  .estimator .estimate
#>   <chr>    <chr>          <dbl>
#> 1 accuracy binary         0.795

# Calculate AUC
roc_auc(training_results, truth = is_suv, estimate = .pred_FALSE)
#> # A tibble: 1 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.879

Создано 12 апреля 2021 года пакетом REPEX (v1.0.0)

person EmilHvitfeldt    schedule 12.04.2021
comment
Я обновился до последней версии, но не могу заставить работать функцию augment. Вот ошибка: > training_results <- augment(model1, df_train) Error: No augment method for objects of class workflow - person max; 13.04.2021
comment
Ничего страшного, мне просто нужно было обновить свою версию workflows - person max; 13.04.2021