复现SCI文章 SHAP 依赖图可视化以增强机器学习模型的可解释性

大模型机器学习关系型数据库

picture.image

背景

在机器学习领域,理解各个特征对模型输出的贡献至关重要,尤其是在像环境科学和生物学这样的重要领域中,SHAP是一种强大的解释工具,能够帮助直观地展示特征对模型预测结果的影响,一项研究《基于可解释机器学习模型的浮游植物生物量预测及关键影响因素识别》中,研究人员使用了 SHAP 依赖图来可视化环境因素如何影响模型预测

picture.image

picture.image

本文将通过医学数据,使用 Python 演示如何复现 SHAP 依赖图,并详细解释连续性特征对模型预测结果的影响

什么是 SHAP 依赖图?

SHAP 依赖图用于可视化单个特征对机器学习模型预测结果的影响,具体来说,x 轴是特征值,y 轴是 SHAP 值(度量特征对预测结果的重要性),这些图可以直观地显示出某个特征是对模型预测起正向还是负向作用

代码实现

数据集加载


          
import pandas as pd
          
import numpy as np
          
import matplotlib.pyplot as plt
          
from sklearn.model_selection import train_test_split
          
plt.rcParams['font.family'] = 'Times New Roman'
          
plt.rcParams['axes.unicode_minus'] = False
          
import warnings
          
warnings.filterwarnings("ignore")
          
df = pd.read_csv('Dataset.csv')
          
# 划分特征和目标变量
          
X = df.drop(['target'], axis=1)
          
y = df['target']
          
# 划分训练集和测试集
          
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, 
          
                                                    random_state=42, stratify=df['target'])
          
df.head()
      

picture.image

首先,需要加载数据集并将其划分为特征 X 和目标变量 y,然后进行训练集和测试集的划分。目标变量是我们要预测的值,X 是输入的特征,这是一个分类任务,目标是预测患者是否患有心脏病。虽然是分类任务,但无论是分类问题还是回归问题,SHAP 依赖图的使用方式和原理是相同的,都可以用来解释模型中各个特征对预测结果的贡献

训练机器学习模型


          
from sklearn.ensemble import GradientBoostingClassifier
          
from sklearn.model_selection import GridSearchCV
          

          
# GBT模型参数
          
params_gbt = {
          
    'learning_rate': 0.02,            # 学习率,控制每一步的步长,用于防止过拟合。典型值范围:0.01 - 0.1
          
    'max_depth': 3,                   # 树的深度,控制模型复杂度
          
    'random_state': 42,               # 随机种子,用于重现模型的结果
          
    'subsample': 0.7,                 # 每次迭代时随机选择的样本比例,用于增加模型的泛化能力
          
}
          

          
# 初始化Gradient Boosting分类模型
          
model_gbt = GradientBoostingClassifier(**params_gbt)
          

          
# 定义参数网格,用于网格搜索
          
param_grid = {
          
    'n_estimators': [100, 200, 300],  # 树的数量
          
    'max_depth': [3, 4, 5],               # 树的深度
          
    'learning_rate': [0.01, 0.1],   # 学习率
          
}
          

          
# 使用GridSearchCV进行网格搜索和k折交叉验证
          
grid_search = GridSearchCV(
          
    estimator=model_gbt,
          
    param_grid=param_grid,
          
    scoring='neg_log_loss',  # 评价指标为负对数损失
          
    cv=5,                    # 5折交叉验证
          
    n_jobs=-1,               # 并行计算
          
    verbose=1                # 输出详细进度信息
          
)
          

          
# 训练模型
          
grid_search.fit(X_train, y_train)
          

          
# 使用最优参数训练模型
          
best_model = grid_search.best_estimator_
      

这里使用了梯度提升树(GBT),这是一个强大且常用的机器学习算法,通过网格搜索进行参数优化

计算 SHAP 值


          
import shap
          
explainer = shap.TreeExplainer(best_model)
          
# 计算shap值为numpy.array数组
          
shap_values_numpy = explainer.shap_values(X)
          
# 计算shap值为Explanation格式
          
shap_values_Explanation = explainer(X)
      

模型训练完毕后,可以使用 shap 包来计算 SHAP 值,SHAP 值用于衡量特定特征对模型输出的影响,这里分别通过 explainer.shap_values(X) 计算 SHAP 值为数组格式以便自定义绘制,和通过 explainer(X) 计算为 Explanation 格式,直接使用 SHAP 自带的绘图函数进行可视化

SHAP自带绘图函数实现依赖图

默认参数下绘制


          
# 绘制 'age' 特征的SHAP依赖图
          
shap.dependence_plot('age', shap_values_Explanation.values, X, show=False)
          
plt.savefig("SHAP Dependence Plot_1.pdf", format='pdf',bbox_inches='tight',dpi=1200)
      

picture.image

图展示了 age(年龄) 特征对模型预测结果的 SHAP 值的依赖关系,说明不同年龄段如何影响模型的预测

  • X 轴(age): 表示年龄的取值范围,从 30 到 75 岁
  • Y 轴(SHAP value for age): 表示年龄对模型预测的影响。 SHAP 值为正时,表示该年龄段增加了模型预测的概率; SHAP 值为负时,表示该年龄段降低了预测的概率

从图中可以看到:

  • 年龄在 50 到 60 岁之间 对模型预测结果有显著的正面影响,SHAP 值较高,说明模型在这个年龄段倾向于预测目标事件的发生
  • 70 岁左右,SHAP 值开始变为负数,意味着在这个年龄段,模型预测发生的概率降低
  • 颜色代表了 thal(地中海贫血类型) 这一交互特征的影响,红色表示更高的值,蓝色表示较低的值, 可以看到,thal 的不同取值对 SHAP 值的分布有一定影响,尤其是在 SHAP 值较大的区域,红色点较为集中

展示了年龄对模型预测的非线性影响,同时揭示了另一个特征(thal)如何与年龄共同作用,影响预测结果,然而,与文献中的图表样式相比,仍存在一些细微的差别

绘制无颜色条的年龄 SHAP 依赖图


          
# 绘制 'age' 特征的 SHAP 依赖图,不显示颜色条
          
shap.dependence_plot('age', shap_values_Explanation.values, X, interaction_index=None, show=False)
          
# 添加 SHAP=0 的横线
          
plt.axhline(y=0, color='black', linestyle='-.', linewidth=1)
          
plt.savefig("SHAP Dependence Plot_2.pdf", format='pdf',bbox_inches='tight',dpi=1200)
          
plt.show()
      

picture.image

在这里,通过设置 interaction_index=None 可以关闭颜色条,不显示交互特征的影响。 不过,该函数目前没有内置参数可以直接在 SHAP 值为 0 的位置添加一条横线。 为了实现这一功能,可以利用 matplotlib 的 plt.axhline() 方法,在绘制依赖图后手动添加横线

接下来,还可以通过 explainer.shap_values(X) 格式绘制这个shap依赖图,以便实现自定义绘图

自定义绘图

将 SHAP 值转换为 DataFrame 格式以便于自定义绘图


          
shap_values_df = pd.DataFrame(shap_values_numpy, columns=X.columns)
          
shap_values_df.head()
      

picture.image

单个shap依赖图绘制


          
# 绘制散点图,x轴是'age'特征,y轴是SHAP值
          
plt.figure(figsize=(6, 4),dpi=1200)
          
plt.scatter(df['age'], shap_values_df['age'], s=10)
          
# 添加shap=0的横线
          
plt.axhline(y=0, color='black', linestyle='-.', linewidth=1)
          
plt.xlabel('Age', fontsize=12)
          
plt.ylabel('SHAP value for\nAge', fontsize=12) 
          
ax = plt.gca()
          
ax.spines['top'].set_visible(False)
          
ax.spines['right'].set_visible(False)
          
plt.savefig("SHAP Dependence Plot_3.pdf", format='pdf',bbox_inches='tight')
          
plt.show()
      

picture.image

代码生成一个 SHAP 值依赖图,其中展示了特征 age 对模型输出的贡献,同时对图表进行了一些格式上的优化,比如隐藏不必要的边框线条、在 SHAP=0 处添加一条基准线,并最终将图像保存为高分辨率的 PDF 文件。相比于直接使用 shap.dependence_plot() 的默认作图方式,这种方法提供了更高的灵活性,特别是在定制化绘图方面,可以根据不同场景、需求对图表进行高度定制,从而提高可视化的效果和表达的准确性

多个sha

p依赖图绘制


          
# 定义绘制 SHAP 依赖图的函数
          
def plot_shap_dependence(feature_list, df, shap_values_df, file_name="SHAP_Dependence_Plots.pdf"):
          
    fig, axs = plt.subplots(2, 3, figsize=(12, 8), dpi=1200)
          
    plt.subplots_adjust(hspace=0.4, wspace=0.4)
          
    
          
    # 循环绘制每个特征的 SHAP 依赖图
          
    for i, feature in enumerate(feature_list):
          
        row = i // 3  # 行号
          
        col = i % 3   # 列号
          
        ax = axs[row, col]
          
        
          
        # 绘制散点图,x轴是特征值,y轴是SHAP值
          
        ax.scatter(df[feature], shap_values_df[feature], s=10)
          
        
          
        # 添加shap=0的横线
          
        ax.axhline(y=0, color='black', linestyle='-.', linewidth=1)
          
        
          
        # 设置x和y轴标签
          
        ax.set_xlabel(feature, fontsize=12)
          
        ax.set_ylabel(f'SHAP value for\n{feature}', fontsize=12)
          
        
          
        # 隐藏顶部和右侧的脊柱
          
        ax.spines['top'].set_visible(False)
          
        ax.spines['right'].set_visible(False)
          
    # 隐藏最后一个空图表的坐标轴 (若画布未关闭)
          
    axs[1, 2].axis('off')
          
    plt.savefig(file_name, format='pdf', bbox_inches='tight')
          
    plt.show()
          

          
# 使用函数绘制age、trestbps、chol、thalach、oldpeak的shap依赖图
          
feature_list = ['age', 'trestbps', 'chol', 'thalach', 'oldpeak']
          
plot_shap_dependence(feature_list, df, shap_values_df)
      

picture.image

这段代码定义一个函数 plot_shap_dependence,用于绘制给定特征列表的 SHAP 依赖图,生成 2 行 3 列的图表布局,并在 SHAP=0 处添加基准线,最后保存为高分辨率 PDF,该图的样式基本上与文献中的 SHAP 依赖图形式一致,包括散点图、SHAP 值为 0 的基准线、去掉顶部和右侧脊柱的简洁图形设计等

往期推荐

SCI图表复现:整合数据分布与相关系数的高级可视化策略

复现顶刊Streamlit部署预测模型APP

树模型系列:如何通过XGBoost提取特征贡献度

SHAP进阶解析:机器学习、深度学习模型解释保姆级教程

特征选择:Lasso和Boruta算法的结合应用

从基础到进阶:优化SHAP力图,让样本解读更直观

SCI图表复现:优化SHAP特征贡献图展示更多模型细节

多模型中的特征贡献度比较与可视化图解

基于SHAP值的 BorutaShap 算法在特征选择中的应用与优化

用SHAP可视化解读数据特征的重要性:蜂巢图与特征关系图结合展示

picture.image

picture.image

picture.image

微信号|deep_ML

欢迎添加作者微信进入Python、ChatGPT群

进群请备注Python或AI进入相关群

无需科学上网、同步官网所有功能、使用无限制

如果你对类似于这样的文章感兴趣。

欢迎关注、点赞、转发~

个人观点,仅供参考

0
0
0
0
关于作者

文章

0

获赞

0

收藏

0

相关资源
字节跳动 XR 技术的探索与实践
火山引擎开发者社区技术大讲堂第二期邀请到了火山引擎 XR 技术负责人和火山引擎创作 CV 技术负责人,为大家分享字节跳动积累的前沿视觉技术及内外部的应用实践,揭秘现代炫酷的视觉效果背后的技术实现。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论