Прогнозирование выживаемости при сердечной недостаточности с помощью моделей машинного обучения - Часть II

Вторая часть пошагового руководства для анализа и прогнозирования выживаемости пациентов с сердечной недостаточностью.

Предисловие

В предыдущем посте мы рассмотрели набор данных о сердечной недостаточности у 299 пациентов, который включал в себя несколько особенностей образа жизни и клинических проявлений. Этот пост был посвящен исследовательскому анализу данных, а этот пост направлен на построение моделей прогнозирования.

Мотивация

Мотивирующий вопрос - «Каковы шансы на выживание пациента с сердечной недостаточностью?». В этом пошаговом руководстве я пытаюсь ответить на этот вопрос, а также даю несколько советов по работе с несбалансированными наборами данных.

Код этого проекта можно найти в моем репозитории GitHub.

Краткое резюме

В предыдущем посте мы видели, что -

  • Возраст и сывороточный креатинин имели слегка положительную корреляцию, тогда как сывороточный натрий и сывороточный креатинин имели слегка отрицательную корреляцию.
  • Большинство умерших пациентов не имели сопутствующих заболеваний или, самое большее, страдали анемией или диабетом.
  • Фракция выброса у умерших пациентов оказалась ниже, чем у выживших.
  • Уровень креатининфосфокиназы у умерших пациентов оказался выше, чем у выживших.

(Ознакомьтесь с предыдущей публикацией, чтобы ознакомиться с используемыми терминами)

Контур

  1. Работа с классовым дисбалансом
  2. Выбор модели машинного обучения
  3. Показатели эффективности
  4. Подготовка данных
  5. Стратифицированная k-кратная перекрестная проверка
  6. Построение модели
  7. Обобщение результатов

1. Работа с классовым дисбалансом

Прежде чем надевать каски, давайте взглянем на баланс целевых классов. Мы смотрим на долю умерших и выживших в исходном наборе данных.

print('% of heart failure patients who died = {}'.format(df.death.value_counts(normalize=True)[1]))
print('% of heart failure patients who survived = {}'.format(df.death.value_counts(normalize=True)[0]))
% of heart failure patients who died = 0.3210702341137124
% of heart failure patients who survived = 0.6789297658862876

Мы видим, что 32% больных умерли, а 68% выжили. Это явно несбалансированный набор данных !. В этом случае, какая бы модель мы ни выбрали, необходимо учитывать этот дисбаланс.

Работа с несбалансированными данными довольно распространена в реальном мире, и эти статьи German Lahera и на DataCamp - хорошее место, чтобы узнать о них.

Технический обзор решения этой проблемы выглядит следующим образом: Вы можете назначить штраф за неправильную классификацию класса меньшинства (тот, который имеет меньшую долю) и тем самым позволить алгоритму изучить это наказание. Другой подход заключается в использовании метода выборки: либо понижающая выборка класса большинства, либо передискретизация класса меньшинства, либо и то, и другое [1].

В нашем упражнении мы попытаемся справиться с этим дисбалансом:

  1. Использование метода стратифицированной k-кратной перекрестной проверки, чтобы убедиться, что совокупные показатели нашей модели не слишком оптимистичны (что означает: слишком хороши, чтобы быть правдой!) И отражать внутренний дисбаланс в данных обучения и тестирования;
  2. Использование модели со штрафными санкциями (вместо метода выборки, такого как SMOTE) с простой схемой взвешивания, которая является инверсией частоты класса.

Следуя этим шагам, мы увидим влияние дисбаланса на прогноз модели и попытаемся получить некоторые выводы!

2. Выбор модели машинного обучения.

В этом посте мы рассмотрим проблему как проблему контролируемой классификации и рассмотрим две основные линейные модели:

  1. Логистическая регрессия (LogReg)
  2. Машины опорных векторов (SVM)

Мы придерживаемся этих рабочих лошадок, потому что у них есть несколько хитрых приемов для работы с несбалансированными целевыми метками, и они просты для понимания. Не стесняйтесь пробовать другие алгоритмы, такие как Случайные леса, Деревья решений, Нейронные сети и т. Д., Среди контролируемых моделей и k- ближайшие соседи, DBSCAN и т. д. среди неконтролируемых моделей.

3. Показатели эффективности

Любая модель прогнозирования должна оцениваться по ее производительности с помощью определенных показателей прогнозирования. Прежде чем мы это сделаем, давайте определим наши типы дел -

  1. Истинно-положительные результаты (TP): когда модель предсказывает смерть, и пациент умер;
  2. Истинно отрицательные (TN): когда модель предсказывает выживаемость, и пациент выжил;
  3. Ложные срабатывания (FP): когда модель предсказывает смерть, но пациент выжил;
  4. Ложноотрицательные (FN): когда модель предсказывает выживаемость, но пациент умер.

Используя эти типы случаев, мы определяем следующие 5 показателей прогнозирования:

  1. Напоминание: это также известно как истинно положительный результат или чувствительность модели к истинно положительному результату. Он рассчитывается как TP / (TP + FN).
  2. Точность: это показатель того, насколько точны истинные положительные результаты, предсказанные моделью. Он рассчитывается как TP / (TP + FP).
  3. Точность. Это совокупный показатель общей производительности модели, рассчитываемый как (TP + TN) / (TP + TN + FP + FN).
  4. Сбалансированная точность: это совокупный показатель способности модели классифицировать каждый класс. Это среднее значение чувствительности (TPR) и специфичности (TNR) и выражается как (TPR + TNR) / 2.
  5. ROC AUC: это область под кривой рабочих характеристик приемника (ROC), которая генерируется на основе частоты истинных положительных и ложных положительных результатов для различных пороговых значений прогноза. Для случайного предиктора это значение равно 0,5, и наша модель должна быть лучше этого.

Определив эти метрики, важно обрисовать в общих чертах, какую производительность мы ожидаем от нашей модели. Мы ожидаем, что модель прогнозирования будет иметь -

  1. Высокая запоминаемость - модель должна предсказывать как можно больше смертей;
  2. Высокая точность - смерти, предсказываемые моделью, должны быть точными, т. е. как можно чаще совпадать с наблюдаемыми случаями смерти;
  3. Высокая сбалансированная точность - модель должна одинаково хорошо предсказывать случаи смерти и выживания, т. е. модель должна быть чувствительна к как можно большему количеству смертей и в то же время иметь конкретную характеристику смерти и выживаемости. прогнозы;
  4. Высокая точность - модель должна иметь высокую общую точность;
  5. AUC с высоким ROC. Общая площадь модели под кривой должна быть больше 0,5 любого случайного предиктора.

4. Подготовка данных

Масштабирование данных

Нашей основной подготовкой данных будет масштабирование функций. Мы делаем это с помощью числовых функций, потому что они измеряются в разных масштабах. Мы используем метод StandardScaler() в sklearn.preprocessing и масштабируем значения так, чтобы они имели среднее значение 0 и дисперсию 1.

cat_feat = df[['sex', 'smk', 'dia', 'hbp', 'anm']]
num_feat = df[['age', 'plt', 'ejf', 'cpk', 'scr', 'sna']]
predictors = pd.concat([cat_feat, num_feat],axis=1)
target = df['death']
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
scaled_feat = pd.DataFrame(scaler.fit_transform(num_feat.values),
                           columns = num_feat.columns)
scaled_predictors = pd.concat([cat_feat, scaled_feat], axis=1)

(Мы опускаем time функцию в текущем анализе)

5. Стратифицированная k-кратная перекрестная проверка

Быстрый праймер

Теперь эти модели, которые мы выбрали, обладают определенной степенью стохастичности, особенно когда дело доходит до решения для коэффициентов. Это означает, что наши результаты будут немного меняться каждый раз при запуске модели. Чтобы убедиться, что мы минимизируем эту случайность и предотвращаем недопустимую или чрезмерную подгонку, мы запускаем модель несколько раз и вычисляем среднее значение выбранной нами метрики.

k-кратная перекрестная проверка - хорошо известный метод выполнения этой итеративной проверки, особенно с небольшими наборами данных, которые могут не быть полностью репрезентативными для исследуемой совокупности. Набор данных разбивается на k подмножеств, и модель обучается на первых k-1 подмножествах и тестируется на последнем k-м подмножестве. Этот процесс повторяется k раз, и вычисляется среднее значение показателей эффективности [2].

Стратифицированная k-кратная перекрестная проверка помогает, когда целевые метки не сбалансированы. Поскольку обычная k-кратная перекрестная проверка несбалансированных целей может привести к тому, что несколько обучающих наборов будут иметь только одну целевую метку для обучения, выполняется стратификация. Другими словами, предыдущий процесс повторяется, но на этот раз, убедившись, что пропорция целевых меток сохраняется в каждом обучающем наборе [3] [4].

Мы используем StratifiedKFold и cross_validate из sklearn.model_selection для выполнения 10-кратной перекрестной проверки, после которой мы подсчитываем перечисленные показатели.

(Я считаю machinelearningmaster.com Джейсона Браунли чрезвычайно полезным ресурсом, чтобы узнать об этом больше)

6. Построение модели

Логистическая регрессия

Логистическая регрессия - это класс моделей линейной регрессии, который обычно подходит для прогнозирования двоичных результатов. Он выдает нелинейные выходные данные для линейных входных данных. В ее основе лежит логистическая функция (сигмовидная функция), и вероятности классов назначаются на основе этой функции после соответствующего обновления коэффициентов регрессии.

Чтобы подчеркнуть эффект предвзятости, возникающей из-за несбалансированных целевых классов, мы запускаем модель логистической регрессии с наложением штрафов и без них. Наказание может быть просто активировано class_weight=’balanced’ при создании экземпляра модели логистической регрессии.

#Stratified 8 fold cross validation
strat_kfold = StratifiedKFold(n_splits=10, shuffle=True)
#Instantiating the logistic regressor
logreg_clf = LogisticRegression() 
#To enable penalization, assign 'balanced' to the class_weight parameter
x = scaled_predictors.values
y = target.values
#Running the model and tallying results of stratified 10-fold cross validation
result = cross_validate(logreg_clf, x, y, cv=strat_kfold, scoring=['accuracy','balanced_accuracy', 'precision', 'recall', 'roc_auc'])                                                             

Мы смотрим на результаты прогнозов модели логистической регрессии без штрафов и штрафов.

pd.concat([pd.DataFrame(result1).mean(),
           pd.DataFrame(result2).mean()],axis=1).rename(columns={0:'Non-Penalized LogReg',1:'Penalized LogReg'})

Некоторые интересные наблюдения -

  • Общая точность двух моделей примерно одинакова и составляет ~ 72%, что достаточно хорошо.
  • Но когда мы посмотрим на сбалансированную точность, мы увидим большую разницу. Штрафной LogReg был чувствителен к обоим классам (71%), а LogReg без штрафных санкций был менее чувствителен (66%).
  • Точность находится на более низком уровне для LogReg со штрафными санкциями (54%), чем для LogReg без штрафов (67%), при этом значения недостаточно высоки.
  • Наибольшая нечувствительность к прыжку или отзыв до смертей наблюдается у штрафного LogReg (72%) по сравнению с нештрафованным LogReg (44%).
  • ROC AUC на уровне 0,76–0,77 по-прежнему лучше, чем случайный классификатор.

Классификатор опорных векторов

SVC - это непараметрические классификаторы, которые используют гиперплоскости в пространстве функций, чтобы попытаться разделить точки данных на классы, близкие друг к другу. Посмотрите это видео StatQuest на Youtube, чтобы получить четкое объяснение!

Из EDA в предыдущем посте мы увидели, что довольно много точек данных, классифицированных как умершие, находятся на периферии графиков разброса. Возможно, мы можем предположить, что линейное ядро ​​не сможет адекватно разделить эти точки данных, и вместо этого использовать ядро ​​радиальной базисной функции.

Мы создаем как штрафные, так и не штрафные SVC таким же образом, как и раньше, и учитываем смещение, которое возникает при прогнозировании несбалансированных классов.

#Stratified 10 fold cross validation
strat_kfold = StratifiedKFold(n_splits=10, shuffle=True)
#Instantiating the SVC 
svc_clf = SVC(kernel='rbf')
x = scaled_predictors.values
y = target.values
#Running the model and tallying results of stratified 10-fold cross validation
result3 = cross_validate(svc_clf, x, y, cv=strat_kfold, scoring=['accuracy','balanced_accuracy','precision','recall','roc_auc'])                                                                  

Мы сравниваем результаты прогнозирования двух вариантов модели SVC.

pd.concat([pd.DataFrame(result3).mean(),
           pd.DataFrame(result4).mean()],axis=1).rename(columns={0:'Non-Penalized SVC',1:'Penalized SVC'})

Некоторые интересные наблюдения -

  • Общая точность (74%) и сбалансированная точность (74%) оштрафованного SVC выше, чем SVC без штрафов.
  • Точность, в отличие от модели LogReg, находится на нижней стороне для обоих вариантов SVC.
  • Наибольший скачок чувствительности или отзыва к смерти наблюдается у SVC со штрафными санкциями (75%) по сравнению с SVC без штрафов (43%).
  • ROC AUC на уровне 0,77–0,80 все же лучше, чем случайный классификатор.

7. Обобщение результатов

В конце этого упражнения важно подвести итоги уже полученных результатов и разобраться в выводах, полученных на этом пути.

  • В этом наборе данных из 299 пациентов с сердечной недостаточностью 68% выжили, а 32% не выжили;
  • 5 характеристик образа жизни и 5 клинических особенностей характеризуют этот набор данных и использовались в качестве потенциальных предикторов выживаемости;
  • Большинство умерших не имели сопутствующих заболеваний, имели более низкую фракцию выброса и более высокие уровни креатининфосфокиназы, чем выжившие;
  • Когда традиционные модели линейной классификации, такие как логистическая регрессия и машины опорных векторов, используются для прогнозирования выживаемости, дисбаланс в наборе данных влияет на производительность;
  • 10-кратная перекрестная проверка и схема штрафов за обратную частоту улучшают характеристики прогнозирования этих моделей;
  • Штрафной SVC немного лучше, чем штрафной LogReg для прогнозирования смертей для этого набора данных;
  • Эти две модели обладают хорошей способностью (›70%) различать тех, кто может выжить, и тех, кто может умереть, используя 10 предоставленных функций.
  • Учитывая историю болезни пациента с сердечной недостаточностью (5 стилей жизни и 5 историй болезни), эти две модели имеют точность не менее 70% в прогнозировании выживаемости пациентов.

Некоторые интересные аспекты, которые могут повысить достоинства этого проекта: PCA и CATPCA для устранения сильно коррелированных функций, тестирование гиперпараметров, пробные модели машинного обучения без учителя и т. д.

Это конец этого проекта, и я надеюсь, что эти две публикации были вам полезны. Мы будем рады вашим отзывам!

Чао!

использованная литература

[1] https://statistics.berkeley.edu/sites/default/files/tech-reports/666.pdf

[2] https://machinelearningmaster.com/k-fold-cross-validation/

[3] https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedKFold.html

[4] https://machinelearningmastery.com/k-fold-cross-validation/