期刊配图:多分类模型SHAP如何通过热图进行模型整体或者单样本解读

机器学习大数据数据库

picture.image

✨ 欢迎关注 ✨

本节介绍: 多分类模型SHAP解释通过热图汇总柱状图显示 。数据采用模拟数据,作者根据个人对机器学习的理解进行代码实现与图表输出,仅供参考。 完整 数据和代码将在稍后上传至交流群,付费成员可在交流群中获取下载。需要的朋友可关注公众文末提供的购买方式。 购买前请咨询,避免不必要的问题。

✨ 文献信息 ✨

picture.image

文献中展示了一个SHAP特征重要性热图(多分类单样本解释),用于显示不同特征对分类模型的贡献(例如对于非小细胞肺癌的预测),并用颜色表示每个特征在不同类别(如“良性肿瘤”和“非小细胞肺癌”)中的重要性,数值越接近零表示该特征对模型影响较小

✨ 模拟实现结果 ✨

picture.image

和文献一样绘制SHAP值热图来可视化单一样本在不同矿床类型(5分类)中的特征重要性分布,采用的正是文献中常用的 per-class 特征重要性呈现方法,只是应用在多分类(five-class)任务中,便于深入分析每类的关键元素贡献

picture.image

同样可以应用到机器学习模型的整体可解释性上,可以通过展示各个特征(或元素)对模型预测结果的贡献程度。这通常用于特征重要性分析,帮助理解哪些特征在模型决策中起到了关键作用,为什么要用热图就是因为多分类模型shap解释针对每个类别都有一个一一对应的shap,而不是一个维度的shap值

✨ 代码实现 ✨

  
import pandas as pd  
import numpy as np  
import matplotlib.pyplot as plt  
plt.rcParams['font.family'] = 'Times New Roman'  
plt.rcParams['axes.unicode_minus'] = False  
from matplotlib.patches import Wedge  
import warnings  
# 忽略所有警告  
warnings.filterwarnings("ignore")  
from sklearn.model_selection import train_test_split  
df = pd.read_excel('2025-6-10公众号Python机器学习AI.xlsx')  
from sklearn.preprocessing import LabelEncoder  
label_encoder = LabelEncoder()  
df['Type_encoded'] = label_encoder.fit_transform(df['Type'])  
from sklearn.model_selection import train_test_split  
  
# 划分特征和目标变量  
X = df.drop(['Type' ,'Type_encoded'], axis=1)    
y = df['Type_encoded']    
# 划分训练集和测试集  
X_train, X_test, y_train, y_test = train_test_split(  
    X,    
    y,   
    test_size=0.3,   
    random_state=42,   
    stratify=df['Type_encoded']   
)  
from xgboost import XGBClassifier  
from sklearn.model_selection import GridSearchCV, StratifiedKFold  
  
# 定义模型  
xgb_clf = XGBClassifier(objective='multi:softmax', num_class=len(y.unique()), use_label_encoder=False, eval_metric='mlogloss')  
  
# 设置参数网格  
param_grid = {  
    'max_depth': [3, 5, 7],  
    'learning_rate': [0.01, 0.1, 0.2],  
    'n_estimators': [50, 100, 150],  
    'subsample': [0.7, 1.0]  
}  
  
# 设置K折交叉验证  
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)  
  
# 构建GridSearchCV对象  
grid_search = GridSearchCV(  
    estimator=xgb_clf,  
    param_grid=param_grid,  
    scoring='accuracy',  
    cv=cv,  
    n_jobs=-1,  
    verbose=1  
)  
  
# 拟合模型  
grid_search.fit(X_train, y_train)  
best_model = grid_search.best_estimator_

通过使用XGBoost分类器和GridSearchCV进行超参数调优,来训练一个多分类模型,目的是优化模型的准确率并选择最佳的超参数(如最大深度、学习率等)。首先对数据进行预处理、划分训练集和测试集,然后使用 K折交叉验证 对模型进行调优,最终选择出性能最佳的模型

  
import shap  
explainer = shap.TreeExplainer(best_model)  
shap_values = explainer.shap_values(X_test)  
# 分别提取各个类别对应的shap值  
shap_values_0 = shap_values[:, :, 0]  
shap_values_1 = shap_values[:, :, 1]  
shap_values_2 = shap_values[:, :, 2]  
shap_values_3 = shap_values[:, :, 3]  
shap_values_4 = shap_values[:, :, 4]  
# 计算每个类别的特征贡献度  
importance_0 = np.abs(shap_values_0).mean(axis=0)  
importance_1 = np.abs(shap_values_1).mean(axis=0)  
importance_2 = np.abs(shap_values_2).mean(axis=0)  
importance_3 = np.abs(shap_values_3).mean(axis=0)  
importance_4 = np.abs(shap_values_4).mean(axis=0)  
importance_df = pd.DataFrame({  
    '类别0': importance_0,  
    '类别1': importance_1,  
    '类别2': importance_2,  
    '类别3': importance_3,  
    '类别4': importance_4  
}, index=X_test.columns)  
# 根据编码对应的类别  
type_mapping = {  
    0: 'MVT',  
    1: 'SEDEX',  
    2: 'VMS',  
    3: 'epithermal',  
    4: 'skarn'  
}  
  
importance_df.columns = [type_mapping[int(col.split('类别')[1])] for col in importance_df.columns]  
import seaborn as sns  
# 添加一列用于存储行的和  
importance_df['row_sum'] = importance_df.sum(axis=1)  
sorted_importance_df = importance_df.sort_values(by='row_sum', ascending=True)  
sorted_importance_df = sorted_importance_df.drop(columns=['row_sum'])  
elements = sorted_importance_df.index  
colors = sns.color_palette("Set2", n_colors=len(sorted_importance_df.columns))  
fig, ax = plt.subplots(figsize=(12, 6), dpi=1200)  
bottom = np.zeros(len(elements))  
for i, column in enumerate(sorted_importance_df.columns):  
    ax.barh(  
        sorted_importance_df.index,     # y轴的特征名称  
        sorted_importance_df[column],  # 当前类别的SHAP值  
        left=bottom,                   # 设置条形图的起始位置  
        color=colors[i],               # 使用调色板中的颜色  
        label=column                   # 为图例添加类别名称  
    )  
    # 更新底部位置,以便下一个条形图能够正确堆叠  
    bottom += sorted_importance_df[column]  
ax.set_xlabel('mean(SHAP value|)(average impact on model output magnitude)', fontsize=12)  
ax.set_ylabel('Features ', fontsize=12)  
ax.set_title('Feature Importance by Class', fontsize=15)  
ax.set_yticks(np.arange(len(elements)))  
ax.set_yticklabels(elements, fontsize=10)  
for i, el in enumerate(elements):  
    ax.text(bottom[i], i, ' ' + str(el), va='center', fontsize=9)  
ax.legend(title='Class', fontsize=10, title_fontsize=12)  
ax.set_yticks([])  # 移除y轴刻度  
ax.set_yticklabels([])  # 移除y轴刻度标签  
ax.set_ylabel('')  # 移除y轴标签  
plt.savefig("1.pdf", format='pdf', bbox_inches='tight', dpi=1200)  
ax.spines['top'].set_visible(False)  
ax.spines['right'].set_visible(False)  
plt.show()

picture.image

使用SHAP分析训练好的XGBoost模型在多分类任务中各特征对不同类别预测的重要性,并以堆叠水平条形图形式可视化每个特征在不同类别中的平均影响力

  
plt.figure(figsize=(14, 12))  # 设置图像大小,提升可读性(尤其是大字体场景)  
heatmap_ax = sns.heatmap(  
    importance_df.drop(['row_sum'], axis=1) ,# 删掉汇总列当然也可以不删除表示特征不区分类别整体的贡献   
    annot=True,                  # 在单元格中显示数值  
    fmt=".3f",                   # 数值格式保留三位小数  
    cmap="YlGnBu",               # 配色方案(可选:'viridis'、'plasma'、'Blues')  
    linewidths=0.5,              # 单元格之间添加细线  
    linecolor='lightgray',       # 单元格边线颜色  
    annot_kws={"size": 18}       # 热图数值字体大小设置为18  
)  
  
# 设置坐标轴刻度字体  
plt.xticks(fontsize=20, rotation=45, ha='right')  # X轴刻度字体大小和旋转角度  
plt.yticks(fontsize=20, rotation=0)               # Y轴刻度字体大小,水平显示  
  
# 设置坐标轴标签和标题  
plt.ylabel('Elements', fontsize=22)                  # Y轴标签  
plt.xlabel('Deposit Type ', fontsize=22)                # X轴标签  
plt.title('Heatmap of Element Importance by Deposit Type', fontsize=24, pad=20)  # 图标题,带间距  
# 设置颜色条(Colorbar)  
cbar = heatmap_ax.collections[0].colorbar  
cbar.set_label('Importance Value', fontsize=20)           # 颜色条标签字体  
cbar.ax.tick_params(labelsize=18)                  # 颜色条刻度字体大小  
plt.savefig("2.pdf", format='pdf', bbox_inches='tight', dpi=1200)  
plt.tight_layout()  
plt.show()

picture.image

使用Seaborn绘制一个展示各元素在不同矿床类型中重要性的热图,和上面的柱状图表达的含义是一致的都是对模型整体进行解读特征影响模型程度,只是可视化形式不一致

  
# 类别标签  
class_labels = ['MVT', 'SEDEX', 'VMS', 'epithermal', 'skarn']  
  
# 提取第一个样本的 SHAP 值,并为每个类别赋予独特的变量名  
shap_values_mvt = shap_values_0[0]     # 对应 'MVT' 类别的第一个样本 SHAP 值  
shap_values_sedex = shap_values_1[0]    # 对应 'SEDEX' 类别的第一个样本 SHAP 值  
shap_values_vms = shap_values_2[0]     # 对应 'VMS' 类别的第一个样本 SHAP 值  
shap_values_epithermal = shap_values_3[0]  # 对应 'epithermal' 类别的第一个样本 SHAP 值  
shap_values_skarn = shap_values_4[0]    # 对应 'skarn' 类别的第一个样本 SHAP 值  
  
# 合并为 DataFrame  
importance_df = pd.DataFrame({  
    'MVT': shap_values_mvt,  
    'SEDEX': shap_values_sedex,  
    'VMS': shap_values_vms,  
    'epithermal': shap_values_epithermal,  
    'skarn': shap_values_skarn  
}, index=X_test.columns)  
  
# 可选:添加行和列总和  
importance_df['row_sum'] = importance_df.sum(axis=1)  
  
# 画热图(删除汇总列 row_sum)  
plt.figure(figsize=(14, 12))  # 图像尺寸  
heatmap_ax = sns.heatmap(  
    importance_df.drop(['row_sum'], axis=1),  
    annot=True,  
    fmt=".3f",  
    cmap="YlGnBu",  
    linewidths=0.5,  
    linecolor='lightgray',  
    annot_kws={"size": 18}  
)  
  
# 坐标轴设置  
plt.xticks(fontsize=20, rotation=45, ha='right')  
plt.yticks(fontsize=20, rotation=0)  
plt.ylabel('Elements', fontsize=22)  
plt.xlabel('Deposit Type', fontsize=22)  
plt.title('Heatmap of SHAP Values for Sample 0 by Deposit Type', fontsize=24, pad=20)  
  
# 颜色条设置  
cbar = heatmap_ax.collections[0].colorbar  
cbar.set_label('SHAP Value', fontsize=20)  
cbar.ax.tick_params(labelsize=18)  
plt.savefig("3.pdf", format='pdf', bbox_inches='tight', dpi=1200)  
plt.tight_layout()  
plt.show()

picture.image

针对第一个测试样本,绘制一个热图,展示其各元素在不同矿床类型预测中的 SHAP 解释值,用于直观理解该样本特征对各类别判断的影响,如果利用力图或者瀑布图进行单样本解读绘制的话就可以对每个类别都有一个对应的可视化可能就会糅杂,热图就合并在一起进行展示了

✨ 该文章案例 ✨

picture.image

在上传至交流群的文件中,像往期文章一样,将对案例进行逐步分析,确保读者能够达到最佳的学习效果。内容都经过详细解读,帮助读者深入理解模型的实现过程和数据分析步骤,从而最大化学习成果。

同时,结合提供的免费AI聚合网站进行学习,能够让读者在理论与实践之间实现融会贯通,更加全面地掌握核心概念。

✨ 购买介绍 ✨

本节介绍到此结束,有需要学习数据分析和Python机器学习相关的朋友欢迎到淘宝店铺:Python机器学习AI,或添加作者微信deep_ML联系,购买作者的公众号合集。截至目前为止,合集已包含近300多篇文章,购买合集的同时,还将提供免费稳定的AI大模型使用,包括但不限于ChatGPT、Deepseek、Claude等。

更新的内容包含数据、代码、注释和参考资料。 作者仅分享案例项目,不提供额外的答疑服务。项目中将提供详细的代码注释和丰富的解读,帮助您理解每个步骤 。 购买前请咨询,避免不必要的问题。

✨ 群友反馈 ✨

picture.image

✨ 淘宝店铺 ✨

picture.image

请大家打开淘宝扫描上方的二维码,进入店铺,获取更多Python机器学习和AI相关的内容,或者添加作者微信deep_ML联系 避免淘宝客服漏掉信息 ,希望能为您的学习之路提供帮助!

往期推荐

机器学习在临床数据分析中的应用:从数据预处理到Web应用实现的完整流程教学

期刊配图:基于SHAP算法的驱动因子相互作用贡献矩阵图

期刊配图:PCA、t-SNE与UMAP三种降维方法简化高维数据的展示应用对比

XGBoost模型优化:基于相关系数剔除多重共线性与穷举法进行特征选择

Geographical-XGBoost:一种基于梯度提升树的空间局部回归的新集成模型实现

因果推断:利用EconML实现双重机器学习估计条件平均处理效应 (CATE)

期刊复现:基于部分线性模型的双重机器学习方法

期刊复现:基于XGBoost模型的网页工具SHAP力图解释单样本预测结果

期刊配图:nature cities通过ALE(累积局部效应)解析特征对模型影响

期刊复现:结合因果推断与SHAP可解释性的机器学习实现应用流程

期刊配图:一区SCI常用数据缺失率展示图可视化

因果推断:注册行为对后续消费影响的因果推断分析案例

nature communications:基于LightGBM与随机森林结合的多模型特征选择方法

因果推断与机器学习结合:探索酒店预订取消的影响因素

picture.image

如果你对类似于这样的文章感兴趣。

欢迎关注、点赞、转发~

个人观点,仅供参考

0
0
0
0
关于作者

文章

0

获赞

0

收藏

0

相关资源
火山引擎大规模机器学习平台架构设计与应用实践
围绕数据加速、模型分布式训练框架建设、大规模异构集群调度、模型开发过程标准化等AI工程化实践,全面分享如何以开发者的极致体验为核心,进行机器学习平台的设计与实现。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论