Машинное обучение для домашнего задания аналитика данных
Классификация
Набор данных звездной классификации
загрузить набор данных
library(tidyverse)
## -- Attaching packages --------------------------------------- tidyverse 1.3.1 --
## v ggplot2 3.3.5 v purrr 0.3.4 ## v tibble 3.1.6 v dplyr 1.0.8 ## v tidyr 1.2.0 v stringr 1.4.0 ## v readr 2.1.2 v forcats 0.5.1
## -- Conflicts ------------------------------------------ tidyverse_conflicts() -- ## x dplyr::filter() masks stats::filter() ## x dplyr::lag() masks stats::lag()
library(caret)
## Loading required package: lattice
## ## Attaching package: 'caret'
## The following object is masked from 'package:purrr': ## ## lift
# Load modeling and tree-plotting packages — one library() call per line
# (the original had three calls collapsed onto a single line, which is
# a syntax error in R).
library(mlbench)
library(rpart)
library(rpart.plot)
# Read the star classification dataset (100,000 rows x 18 columns).
# show_col_types = FALSE silences the column-specification message,
# as suggested by readr's own output.
star_class <- read_csv("star_classification.csv", show_col_types = FALSE)
## Rows: 100000 Columns: 18
## -- Column specification -------------------------------------------------------- ## Delimiter: "," ## chr (1): class ## dbl (17): obj_ID, alpha, delta, u, g, r, i, z, run_ID, rerun_ID, cam_col, fi... ## ## i Use `spec()` to retrieve the full column specification for this data. ## i Specify the column types or set `show_col_types = FALSE` to quiet this message.
проверка пропущенных значений (все ли наблюдения полные?)
# Share of fully observed rows; a value of 1 means no missing data.
sum(complete.cases(star_class)) / nrow(star_class)
## [1] 1
изучить данные
## Rows: 100,000
## Columns: 18
## $ obj_ID <dbl> 1.237661e+18, 1.237665e+18, 1.237661e+18, 1.237663e+18, 1.~
## $ alpha <dbl> 135.689107, 144.826101, 142.188790, 338.741038, 345.282593~
## $ delta <dbl> 32.4946318, 31.2741849, 35.5824442, -0.4028276, 21.1838656~
## $ u <dbl> 23.87882, 24.77759, 25.26307, 22.13682, 19.43718, 23.48827~
## $ g <dbl> 22.27530, 22.83188, 22.66389, 23.77656, 17.58028, 23.33776~
## $ r <dbl> 20.39501, 22.58444, 20.60976, 21.61162, 16.49747, 21.32195~
## $ i <dbl> 19.16573, 21.16812, 19.34857, 20.50454, 15.97711, 20.25615~
## $ z <dbl> 18.79371, 21.61427, 18.94827, 19.25010, 15.54461, 19.54544~
## $ run_ID <dbl> 3606, 4518, 3606, 4192, 8102, 8102, 7773, 7773, 3716, 5934~
## $ rerun_ID <dbl> 301, 301, 301, 301, 301, 301, 301, 301, 301, 301, 301, 301~
## $ cam_col <dbl> 2, 5, 2, 3, 3, 3, 2, 2, 5, 4, 3, 4, 2, 2, 3, 4, 5, 5, 5, 5~
## $ field_ID <dbl> 79, 119, 120, 214, 137, 110, 462, 346, 108, 122, 27, 112, ~
## $ spec_obj_ID <dbl> 6.543777e+18, 1.176014e+19, 5.152200e+18, 1.030107e+19, 6.~
## $ class <chr> "GALAXY", "GALAXY", "GALAXY", "GALAXY", "GALAXY", "QSO", "~
## $ redshift <dbl> 6.347936e-01, 7.791360e-01, 6.441945e-01, 9.323456e-01, 1.~
## $ plate <dbl> 5812, 10445, 4576, 9149, 6121, 5026, 11069, 6183, 6625, 24~
## $ MJD <dbl> 56354, 58158, 55592, 58039, 56187, 55855, 58456, 56210, 56~
## $ fiber_ID <dbl> 171, 427, 299, 775, 842, 741, 113, 15, 719, 232, 525, 855,~
преобразование в фактор и использование только класса GALAXY и STAR
# Keep only the GALAXY and STAR classes and convert character columns
# to factors.
# BUG FIX: the original used `class == c("GALAXY", "STAR")`, which
# recycles the two-element vector row by row (row 1 vs "GALAXY",
# row 2 vs "STAR", ...) and silently drops roughly half of the
# matching rows. `%in%` performs the intended set-membership test.
# `mutate_if` is superseded; `across(where(...))` is the modern form.
star_class <- star_class %>%
  filter(class %in% c("GALAXY", "STAR")) %>%
  mutate(across(where(is.character), as.factor))
подсчёт наблюдений по классам
## # A tibble: 2 x 3
## class n percent
## <fct> <int> <dbl>
## 1 GALAXY 29622 0.734
## 2 STAR 10752 0.266
Базовый прогноз: если предсказывать «GALAXY» для всех наблюдений, точность базового прогноза составит 73,4%.
1. разделить данные
# Reproducible, class-stratified 80/20 train/test split.
set.seed(42)
id <- createDataPartition(y = star_class$class,
                          p = 0.8,
                          # FALSE spelled out: `F` is a reassignable binding
                          # and should never be used for logicals.
                          list = FALSE)
train_df <- star_class[id, ]
test_df <- star_class[-id, ]
2. обучить модель
# 5-fold cross-validation; class probabilities are required for
# ROC-based model selection (twoClassSummary reports ROC/Sens/Spec).
set.seed(42)
ctrl <- trainControl(method = "cv",
                     number = 5,
                     classProbs = TRUE,
                     summaryFunction = twoClassSummary,
                     verboseIter = TRUE)

# Decision tree. metric = "ROC" is stated explicitly so caret does not
# warn about falling back from "Accuracy" (see the warning in the
# rendered output) and so both models are tuned on the same metric.
rpart_model <- train(form = class ~ .,
                     data = train_df,
                     method = "rpart",
                     metric = "ROC",
                     trControl = ctrl)
## Warning in train.default(x, y, weights = w, ...): The metric "Accuracy" was not ## in the result set. ROC will be used instead.
## + Fold1: cp=0.0003488 ## - Fold1: cp=0.0003488 ## + Fold2: cp=0.0003488 ## - Fold2: cp=0.0003488 ## + Fold3: cp=0.0003488 ## - Fold3: cp=0.0003488 ## + Fold4: cp=0.0003488 ## - Fold4: cp=0.0003488 ## + Fold5: cp=0.0003488 ## - Fold5: cp=0.0003488 ## Aggregating results ## Selecting tuning parameters ## Fitting cp = 0.000349 on full training set
# Random forest with the same 5-fold CV control, tuned on ROC.
rf_model <- train(form = class ~ .,
                  data = train_df,
                  method = "rf",
                  metric = "ROC",
                  trControl = ctrl)
## + Fold1: mtry= 2 ## - Fold1: mtry= 2 ## + Fold1: mtry= 9 ## - Fold1: mtry= 9 ## + Fold1: mtry=17 ## - Fold1: mtry=17 ## + Fold2: mtry= 2 ## - Fold2: mtry= 2 ## + Fold2: mtry= 9 ## - Fold2: mtry= 9 ## + Fold2: mtry=17 ## - Fold2: mtry=17 ## + Fold3: mtry= 2 ## - Fold3: mtry= 2 ## + Fold3: mtry= 9 ## - Fold3: mtry= 9 ## + Fold3: mtry=17 ## - Fold3: mtry=17 ## + Fold4: mtry= 2 ## - Fold4: mtry= 2 ## + Fold4: mtry= 9 ## - Fold4: mtry= 9 ## + Fold4: mtry=17 ## - Fold4: mtry=17 ## + Fold5: mtry= 2 ## - Fold5: mtry= 2 ## + Fold5: mtry= 9 ## - Fold5: mtry= 9 ## + Fold5: mtry=17 ## - Fold5: mtry=17 ## Aggregating results ## Selecting tuning parameters ## Fitting mtry = 9 on full training set
3. предсказать тестовый набор (оценка)
# Hard class predictions on the held-out test set
# (type = "raw" is predict.train's default).
p_rp <- predict(rpart_model, newdata = test_df, type = "raw")
p_rf <- predict(rf_model, newdata = test_df, type = "raw")
4. оценить точность (accuracy) на тестовом наборе
# Test-set accuracy of the rpart model.
sum(p_rp == test_df$class) / nrow(test_df)
## [1] 0.9988853
# Test-set accuracy of the random forest model.
sum(p_rf == test_df$class) / nrow(test_df)
## [1] 0.9990092
5. матрица ошибок (confusion matrix) в caret
# Precision/recall-oriented confusion matrix for the rpart model.
confusionMatrix(data = p_rp, reference = test_df$class, mode = "prec_recall")
## Confusion Matrix and Statistics ## ## Reference ## Prediction GALAXY STAR ## GALAXY 5915 0 ## STAR 9 2150 ## ## Accuracy : 0.9989 ## 95% CI : (0.9979, 0.9995) ## No Information Rate : 0.7337 ## P-Value [Acc > NIR] : < 2.2e-16 ## ## Kappa : 0.9972 ## ## Mcnemar's Test P-Value : 0.007661 ## ## Precision : 1.0000 ## Recall : 0.9985 ## F1 : 0.9992 ## Prevalence : 0.7337 ## Detection Rate : 0.7326 ## Detection Prevalence : 0.7326 ## Balanced Accuracy : 0.9992 ## ## 'Positive' Class : GALAXY ##
# Precision/recall-oriented confusion matrix for the random forest.
confusionMatrix(data = p_rf, reference = test_df$class, mode = "prec_recall")
## Confusion Matrix and Statistics ## ## Reference ## Prediction GALAXY STAR ## GALAXY 5918 2 ## STAR 6 2148 ## ## Accuracy : 0.999 ## 95% CI : (0.998, 0.9996) ## No Information Rate : 0.7337 ## P-Value [Acc > NIR] : <2e-16 ## ## Kappa : 0.9975 ## ## Mcnemar's Test P-Value : 0.2888 ## ## Precision : 0.9997 ## Recall : 0.9990 ## F1 : 0.9993 ## Prevalence : 0.7337 ## Detection Rate : 0.7330 ## Detection Prevalence : 0.7332 ## Balanced Accuracy : 0.9990 ## ## 'Positive' Class : GALAXY ##
Заключение. Я выберу модель rpart: по качеству она почти не отличается от модели randomForest, но обучается и предсказывает заметно быстрее.