Практическое руководство по XAI-анализу с помощью SHAP для задачи мультиклассовой классификации

Объясняемость модели становится основной частью конвейера машинного обучения. Сохранение модели машинного обучения в виде черного ящика больше не вариант. К счастью, существуют аналитические инструменты, такие как (Lime, ExplainerDashboard, Shapash, Dalex и другие), которые быстро развиваются и становятся все более популярными. В предыдущем посте мы объяснили, как использовать SHAP для решения задачи регрессии. В этом руководстве представлен практический пример того, как использовать и интерпретировать пакет Python с открытым исходным кодом, SHAP, для анализа XAI в задачах мультиклассовой классификации и использовать его для улучшения модели.

SHAP (Аддитивные объяснения Шепли) Лундберга и Ли (2016) - это метод объяснения индивидуальных прогнозов, основанный на теоретически оптимальных в игре значениях Шепли [1]. Расчет значения Шепли для получения дополнительных функций требует больших вычислительных ресурсов. Существует два метода аппроксимации значений SHAP для повышения эффективности вычислений: KernelSHAP, TreeSHAP (только для древовидных моделей).

SHAP предоставляет глобальные и локальные методы интерпретации, основанные на агрегировании значений Шепли. В этом руководстве мы будем использовать пример набора данных межсетевого экрана Интернета из наборов данных Kaggle [2], чтобы продемонстрировать некоторые выходные графики SHAP для задачи мультиклассовой классификации.

# load the csv file as a data frame
df = pd.read_csv('log2.csv')
y = df.Action.copy()
X = df.drop('Action',axis=1)

Создайте модель и подгоните ее, как всегда.

X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, random_state=0)
cls = RandomForestClassifier(max_depth=2, random_state=0)
cls.fit(X_train, y_train)

Теперь, чтобы получить общее представление о модели, я рекомендую просмотреть средства импорта функций и матрицу путаницы. Чтобы понять, где мы находимся с важностью функции, я использовал scikit-learn, который вычисляет уменьшение примесей в каждом дереве [3].

֫importances = cls.feature_importances_
indices = np.argsort(importances)
features = df.columns
plt.title('Feature Importances')
plt.barh(range(len(indices)), importances[indices], color='g', align='center')
plt.yticks(range(len(indices)), [features[i] for i in indices])
plt.xlabel('Relative Importance')
plt.show()

Позже мы можем сравнить эти результаты с важностью характеристик, рассчитанной по значениям Шепли.

Матрица путаницы

Матрица неточностей - это способ визуализировать работу модели. И что еще более важно, мы можем легко увидеть, где именно модель не работает.

class_names = ['drop', 'allow', 'deny', 'reset-both']
disp = plot_confusion_matrix(cls, X_test, y_test, display_labels=class_names, cmap=plt.cm.Blues, xticks_rotation='vertical')

Модель не смогла обнаружить ни одного экземпляра из класса reset-both. Причина в несбалансированном наборе данных, который дает небольшое количество примеров класса reset-both, на которых можно учиться.

y.value_counts()
allow         37640
deny          14987
drop          12851
reset-both       54
Name: Action, dtype: int64

Сводный график SHAP

Значения SHAP выходных данных модели объясняют, как функции влияют на выходные данные модели.

# compute SHAP values
explainer = shap.TreeExplainer(cls)
shap_values = explainer.shap_values(X)

Теперь мы можем построить соответствующие графики, которые помогут нам проанализировать модель.

shap.summary_plot(shap_values, X.values, plot_type="bar", class_names= class_names, feature_names = X.columns)

На этом графике влияние объекта на классы суммируется, чтобы создать график важности объекта. Таким образом, если вы создали объекты для того, чтобы отличить определенный класс от остальных, вы можете увидеть его на этом графике. Другими словами, сводный график для мультиклассовой классификации может показать вам, что машине удалось узнать из функций.

В приведенном ниже примере мы видим, что отбрасывание класса почти не использует функции «График назначения», «Порт источника» и «Отправленные байты». Мы также можем видеть, что классы разрешают и запрещают одинаково использовать одни и те же функции. Это причина того, что путаница между ними относительно высока. Чтобы лучше разделить классы разрешения и запрета, необходимо создать новые функции, которые будут однозначно выделены для этих классов.

Вы также можете увидеть summary_plot определенного класса.

shap.summary_plot(shap_values[1], X.values, feature_names = X.columns)

Сводный график сочетает в себе важность функций с эффектами функций. Каждая точка на сводном графике представляет собой значение Шепли для объекта и экземпляра. Положение по оси Y определяется функцией, а по оси X - значением Шепли. Вы можете видеть, что функция pkts_sent, будучи наименее важной функцией, имеет низкие значения Шепли. Цвет представляет ценность функции от низкого до высокого. Перекрывающиеся точки колеблются в направлении оси Y, поэтому мы получаем представление о распределении значений Шепли для каждого объекта. Функции упорядочены по степени важности.

На сводном графике мы видим первые признаки взаимосвязи между ценностью характеристики и влиянием на прогноз. Но чтобы увидеть точную форму связи, мы должны взглянуть на графики зависимости SHAP.

График зависимости SHAP

График частичной зависимости (короткий график PDP или PD) показывает предельное влияние одной или двух характеристик на прогнозируемый результат модели машинного обучения (J. H. Friedman 2001 [3]). График частичной зависимости может показать, является ли связь между целью и элементом линейной, монотонной или более сложной.

График частичной зависимости - это глобальный метод: метод рассматривает все экземпляры и дает утверждение о глобальной связи признака с прогнозируемым результатом. PDP предполагает, что первая функция не коррелирует со второй функцией. Если это предположение нарушается, средние значения, рассчитанные для графика частичной зависимости, будут включать точки данных, которые очень маловероятны или даже невозможны.

График зависимости - это график разброса, который показывает влияние отдельной функции на прогнозы, сделанные моделью. В этом примере стоимость недвижимости значительно увеличивается, когда среднее количество комнат в доме превышает 6.

  • Каждая точка - это отдельный прогноз (строка) из набора данных.
  • Ось x - это фактическое значение из набора данных.
  • Ось Y - это значение SHAP для этой функции, которое показывает, насколько знание значения этой функции изменяет выходные данные модели для прогноза этой выборки.

Цвет соответствует второй функции, которая может иметь эффект взаимодействия с функцией, которую мы строим (по умолчанию эта вторая функция выбирается автоматически). Если между этой другой функцией и функцией, которую мы строим, присутствует эффект взаимодействия, он будет отображаться как отдельный вертикальный узор окраски.

֫# If we pass a numpy array instead of a data frame then we
# need pass the feature names in separately
shap.dependence_plot(0, shap_values[0], X.values, feature_names=X.columns)

В приведенном выше примере мы можем видеть четкий вертикальный узор окраски для взаимодействия между функциями, исходным портом и исходным портом NAT.

График SHAP Force

График силы дает нам объяснимость предсказания одной модели. На этом графике мы можем увидеть, как особенности повлияли на прогноз модели для конкретного наблюдения. Его очень удобно использовать для анализа ошибок или глубокого понимания конкретного случая.

i=8
shap.force_plot(explainer.expected_value[0], shap_values[0][i], X.values[i], feature_names = X.columns)

Из сюжета мы видим:

  1. Значение прогнозируемой_пробки модели: 0,79
  2. Базовое значение: это значение, которое можно было бы спрогнозировать, если бы мы не знали никаких функций для текущего экземпляра. Базовое значение - это среднее значение выходных данных модели по набору обучающих данных (объяснительное_ ожидаемое_значение в коде). В этом примере базовое значение = 0,5749.
  3. Цифры на стрелках графика обозначают значение функции для этого экземпляра. Истекшее время (с) = 5, а пакетов = 1
  4. Красным цветом обозначены особенности, которые повысили оценку модели, а синим - особенности, которые снизили оценку.
  5. Чем больше стрелка, тем больше влияние функции на вывод. Величину уменьшения или увеличения воздействия можно увидеть на оси абсцисс.
  6. Истекшее время 5 секунд увеличивает свойство, разрешенное классом, пакеты 6.546 уменьшают значение свойства.

Сюжет водопада SHAP

График водопада - это еще один график локального анализа для единичного прогноза. В качестве примера возьмем экземпляр номер 8:

row = 8
shap.waterfall_plot(shap.Explanation(values=shap_values[0][row], 
                                              base_values=explainer.expected_value[0], data=X_test.iloc[row],  
                                         feature_names=X_test.columns.tolist()))

  1. f (x) - значение прогнозируемой_пробы модели: 0,79.
  2. E [f (x)] - базовое значение = 0,5749.
  3. Слева указаны значения признаков, а на стрелках - вклад признаков в прогноз.
  4. Каждая строка показывает, как положительный (красный) или отрицательный (синий) вклад каждой функции перемещает значение из ожидаемого вывода модели по фоновому набору данных в вывод модели для этого прогноза [2].

Резюме

Фреймворк SHAP оказался важным достижением в области интерпретации моделей машинного обучения. SHAP объединяет несколько существующих методов для создания интуитивно понятного, теоретически обоснованного подхода к объяснению прогнозов для любой модели. Значения SHAP количественно определяют величину и направление (положительное или отрицательное) влияния функции на прогноз [6]. Я считаю, что анализ XAI с помощью SHAP и других инструментов должен быть неотъемлемой частью конвейера машинного обучения. Код в этом посте можно найти здесь.