Как визуализировать модели деревьев решений с помощью этой полезной библиотеки
Когда дело доходит до объяснимости модели, деревья решений являются одними из самых интуитивно понятных и объяснимых моделей. Каждая модель дерева решений может быть объяснена как набор интерпретируемых человеком правил. Возможность визуализировать модели деревьев решений важна для объяснимости моделей и может помочь заинтересованным сторонам и бизнес-менеджерам завоевать доверие к этим моделям.
К счастью, мы можем легко визуализировать и интерпретировать деревья решений с помощью библиотеки dtreeviz. В этой статье я покажу, как можно использовать dtreeviz для визуализации древовидных моделей для регрессии и классификации.
Установка dtreeviz
Вы можете легко установить dtreeviz с помощью pip, используя следующую команду:
pip install dtreeviz
Подробный список зависимостей и дополнительных библиотек, которые могут потребоваться для установки в зависимости от вашей операционной системы, см. в этом репозитории GitHub.
Визуализация деревьев регрессии
В этом разделе мы обучим регрессор дерева решений на наборе данных о диабете. Обратите внимание, что вы можете найти весь код для этого руководства в этом репозитории GitHub. Имейте в виду, что я использую Jupyter в качестве среды для запуска этого кода Python. Вы можете найти весь код, который я написал для этого урока, в этом репозитории Github.
Импорт библиотек
В приведенном ниже блоке кода я просто импортировал несколько распространенных библиотек, включая модули scikit-learn DecisionTree и dtreeviz.
import numpy as np import pandas as pd from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor import dtreeviz
Чтение данных
Набор данных о диабете доступен в scikit-learn, поэтому мы можем использовать приведенный ниже код для импорта набора данных и сохранения функций и целевых значений в массивах numpy с именами X и y.
from sklearn.datasets import load_diabetes diabetes_data = load_diabetes() X = pd.DataFrame(data = diabetes_data['data'], columns=diabetes_data['feature_names']) y = diabetes_data['target']