Как визуализировать модели деревьев решений с помощью этой полезной библиотеки

Когда дело доходит до объяснимости модели, деревья решений являются одними из самых интуитивно понятных и объяснимых моделей. Каждая модель дерева решений может быть объяснена как набор интерпретируемых человеком правил. Возможность визуализировать модели деревьев решений важна для объяснимости моделей и может помочь заинтересованным сторонам и бизнес-менеджерам завоевать доверие к этим моделям.

К счастью, мы можем легко визуализировать и интерпретировать деревья решений с помощью библиотеки dtreeviz. В этой статье я покажу, как можно использовать dtreeviz для визуализации древовидных моделей для регрессии и классификации.

Установка dtreeviz

Вы можете легко установить dtreeviz с помощью pip, используя следующую команду:

pip install dtreeviz

Подробный список зависимостей и дополнительных библиотек, которые могут потребоваться для установки в зависимости от вашей операционной системы, см. в этом репозитории GitHub.

Визуализация деревьев регрессии

В этом разделе мы обучим регрессор дерева решений на наборе данных о диабете. Обратите внимание, что вы можете найти весь код для этого руководства в этом репозитории GitHub. Имейте в виду, что я использую Jupyter в качестве среды для запуска этого кода Python. Вы можете найти весь код, который я написал для этого урока, в этом репозитории Github.

Импорт библиотек

В приведенном ниже блоке кода я просто импортировал несколько распространенных библиотек, включая модули scikit-learn DecisionTree и dtreeviz.

import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
import dtreeviz

Чтение данных

Набор данных о диабете доступен в scikit-learn, поэтому мы можем использовать приведенный ниже код для импорта набора данных и сохранения функций и целевых значений в массивах numpy с именами X и y.

from sklearn.datasets import load_diabetes

diabetes_data = load_diabetes()
X = pd.DataFrame(data = diabetes_data['data'], columns=diabetes_data['feature_names'])
y = diabetes_data['target']