利用XGBoost模型进行多分类任务下的SHAP解释附代码讲解及GUI展示

技术

picture.image

目标

在这篇文章中,我们将介绍如何利用XGBoost模型进行多分类任务,并使用SHAP对模型进行解释,并生成SHAP解释图、依赖图、力图和热图,从而直观地理解模型的决策过程和特征的重要性

二分类模型和多分类模型在SHAP上的差异

二分类模型

在二分类任务中,模型的目标是将数据划分为两个类别(例如,0和1),SHAP值用于解释每个特征对模型输出的贡献,在二分类模型中,每个样本的SHAP值只有一个,表示该特征对预测结果(通常是正类概率)的贡献

多分类模型

在多分类任务中,模型需要将数据划分为三个或更多类别,每个样本的预测结果不仅包含一个类别,还包括每个类别的概率,SHAP值在多分类任务中的应用需要分别计算每个类别的SHAP值,因此,对于每个样本,SHAP值将是一个矩阵,其中每个元素表示一个特征对某个类别的贡献

代码实现

数据读取处理


          
from sklearn.datasets import load_iris
          
import pandas as pd
          

          
# Load the iris dataset
          
iris = load_iris()
          

          
# Create a DataFrame from the iris dataset
          
iris_df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
          
iris_df['target'] = iris.target
          

          
from sklearn.model_selection import train_test_split
          
X = iris_df.drop(['target'],axis=1)
          
y = iris_df['target']
          

          
X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
          

          
# 然后将训练集进一步划分为训练集和验证集
          
X_train, X_val, y_train, y_val = train_test_split(X_temp, y_temp, test_size=0.125,stratify=y_temp, random_state=42)  # 0.125 x 0.8 = 0.1
      

加载鸢尾花数据集,将鸢尾花数据集分割为训练集、验证集和测试集,具体过程是: 从 整个数据集中抽取20%作为测试集; 剩余的80%数据中抽取12.5%作为验证集,最终验证集占整个数据集的10%,训练集占整个数据集的70%, 这一步骤为后续使用XGBoost进行多分类模型的训练和评估奠定基础

模型建立


          
import xgboost as xgb
          

          
# 更新后的多分类模型参数
          
params_xgb = {
          
    'learning_rate': 0.02,            # 学习率
          
    'booster': 'gbtree',              # 提升方法
          
    'objective': 'multi:softprob',    # 损失函数,多分类使用softmax
          
    'num_class': 3,                   # 类别数,鸢尾花数据集有三类
          
    'max_leaves': 127,                # 每棵树的叶子节点数量
          
    'verbosity': 1,                   # 输出信息的详细程度
          
    'seed': 42,                       # 随机种子
          
    'nthread': -1,                    # 并行运算的线程数量
          
    'colsample_bytree': 0.6,          # 每棵树随机选择的特征比例
          
    'subsample': 0.7,                 # 每次迭代时随机选择的样本比例
          
    'early_stopping_rounds': 100,     # 早停轮数
          
    'eval_metric': 'mlogloss'         # 评估指标,多分类使用mlogloss
          
}
          

          
# 创建并训练多分类模型
          
model_xgb = xgb.XGBClassifier(**params_xgb)
          
model_xgb.fit(X_train, y_train, eval_set=[(X_val, y_val)], verbose=False)
      

picture.image

配置并训练一个XGBoost多分类模型来预测鸢尾花数据集的类别,使用特定的参数设置和早停机制,比如这里是三分类,如果需要更改为其他多分类(如四分类),只需修改参数num_class,并相应调整其他参数以达到最优模型效果

模型评价指标输出

评价报告


          
from sklearn.metrics import classification_report
          
# 预测测试集
          
y_pred = model_xgb.predict(X_test)
          

          
# 输出模型报告, 查看评价指标
          
print(classification_report(y_test, y_pred))
      

picture.image

混淆矩阵热力图


          
from sklearn.metrics import confusion_matrix
          
import seaborn as sns
          
import matplotlib.pyplot as plt
          
# 输出混淆矩阵
          
conf_matrix = confusion_matrix(y_test, y_pred)
          

          
# 绘制热力图
          
plt.figure(figsize=(10, 7))
          
sns.heatmap(conf_matrix, annot=True, annot_kws={'size':15}, fmt='d', cmap='YlGnBu')
          
plt.xlabel('Predicted Label', fontsize=12)
          
plt.ylabel('True Label', fontsize=12)
          
plt.title('Confusion matrix heat map', fontsize=15)
          
plt.show()
      

picture.image

Shap实现

创建Shap解释器


          
import shap
          
# 创建SHAP解释器
          
explainer = shap.Explainer(model_xgb)
          
# 计算SHAP值
          
shap_values = explainer(X_test)
          
print("shap值维度;",shap_values.shape)
          
shap_values
      

picture.image

可以看见针对测试集的shap值的维度为(30,4,3),也就是计算的每个类别的SHAP值,对于每个样本,SHAP值将是一个矩阵,其中每个元素表示一个特征对某个类别的贡献

绘制Shap解释图


          
# 特征标签
          
labels = X_train.columns
          

          
# 设置 matplotlib 的全局字体配置
          
plt.rcParams['font.family'] = 'serif'
          
plt.rcParams['font.serif'] = 'Times New Roman'
          
plt.rcParams['font.size'] = 13
          

          
# 提取每个类别的 SHAP 值
          
shap_values_class_1 = shap_values.values[:, :, 0]
          
shap_values_class_2 = shap_values.values[:, :, 1]
          
shap_values_class_3 = shap_values.values[:, :, 2]
          
shap_values_class_1
          

          
# 绘制 SHAP 总结图,使用viridis配色方案
          
plt.figure()
          
plt.title('class_1')
          
shap.summary_plot(shap_values_class_1, X_val, feature_names=labels, plot_type="dot", cmap="viridis")
          
plt.show()
      

picture.image

这里针对鸢尾花的测试集第一个类别0进行shap解释图绘制

绘制Shap依赖图


          
shap.dependence_plot('sepal length (cm)', shap_values_class_1, X_val, interaction_index='sepal width (cm)')
          
plt.show()
      

picture.image

针对鸢尾花的测试集 第一个类 别0的特征sepal length (cm)、sepal width (cm)进行 shap依赖图绘制

绘制Shap力图


          
# 选择一个样本索引进行解释
          
sample_index = 1
          
expected_value = explainer.expected_value[0]  # 需要指定个类别的基准值,这里是第一个类别
          
# 获取单个样本的 SHAP 值
          
sample_shap_values = shap_values_class_1[sample_index]
          

          
# 绘制 SHAP 解释力图 (Force Plot)
          
shap.force_plot(expected_value, sample_shap_values, X_val.iloc[sample_index], matplotlib=True)
          
# 显示绘图
          
plt.show()
      

picture.image

shap力图解释同样会选择数据集以及类别,但是还会多一个应选择的基准值,比如这里选择的第一个类比那基准值也要选择第一个类比的基准值

生成Shap交互作用图


          
shap_interaction_values = explainer.shap_interaction_values(X_val)
          
# 提取每个类别的值
          
shap_interaction_values_class_1 = shap_interaction_values[:, :, :, 0] # 类别1
          
shap_interaction_values_class_2 = shap_interaction_values[:, :, :, 1] # 类别2
          
shap_interaction_values_class_3 = shap_interaction_values[:, :, :, 2] # 类别3
          
# 绘制 SHAP 交互值的总结图
          
plt.figure()
          
shap.summary_plot(shap_interaction_values_class_1, X_val, feature_names=labels)
          
plt.show()
      

picture.image

计算的交互值比Shap值维度多一,同理得提取每一个类比的交互值,具体怎么提取参考这个三分类代码

生成Shap热图


          
expected_value = explainer.expected_value[0]  # 需要指定个类别的基准值,这里是第一个类别 
          
# 创建 shap.Explanation 对象
          
shap_explanation = shap.Explanation(values=shap_values_class_1[0:10, :], 
          
                                    base_values=  expected_value, 
          
                                    data=X_val.iloc[0:10, :], 
          
                                    feature_names=X_val.columns)
          

          
# 绘制热图
          
plt.figure()
          
shap.plots.heatmap(shap_explanation)
          
plt.show()
      

picture.image

代码绘制的是第一个类比测试集前10个样本的Shap热图,和力图一样得确定一个基准值,对哪一个类比做就采用哪一个类比的基准值

项目GUI实现

picture.image

创建一个使用Tkinter库构建的GUI应用程序,旨在通过按钮、标签、组合框和文本框等组件实现数据上传、选择目标特征、设置分类任务的类别数、选择数据集、选择颜色方案、选择特征、输入样本索引、输入样本范围等功能,从而对XGBoost分类模型进行训练并生成相关的解释图,并确保将这些图保存为高DPI的PDF文件,以保证可视化效果不受损失,如需获取添加以下微信进行交流

往期推荐

K-means聚类与t-SNE降维:多维数据的二维可视化

小白轻松上手:一键生成SHAP解释图的GUI应用,支持多种梯度提升模型选择

利用SHAP解释二分类模型的四种机器学习GUI工具

特征工程进阶:暴力特征字典的构建与应用 实现模型精度质的飞跃

基于CatBoost回归预测模型的多种可解释性图表绘制

快速选择最佳模型:轻松上手LightGBM、XGBoost、CatBoost和NGBoost!

无网络限制!同步官网所有功能!让编程小白也能轻松上手进行代码编写!!!

picture.image

picture.image

picture.image

微信号|deep_ML

欢迎添加作者微信进入Python、ChatGPT群

进群请备注Python或AI进入相关群
无需科学上网、同步官网所有功能、使用无限制

如果你对类似于这样的文章感兴趣。

欢迎关注、点赞、转发~

个人观点,仅供参考

0
0
0
0
关于作者

文章

0

获赞

0

收藏

0

相关资源
基于 ByteHouse 引擎的增强型数据导入技术实践
ByteHouse 基于自研 HaMergeTree,构建增强型物化 MySQL、HaKafka 引擎,实现数据快速集成,加速业务数据分析性能与效率,本次 talk 主要介绍物化 MySQL 与 HaKafka 数据导入方案和业务实践。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论