文献复现——优化SHAP依赖图拟合曲线与交点标注的新应用

机器学习关系型数据库云安全

picture.image

背景

在这篇文章中,将带读者深入探讨SHAP值解释图的优化与可视化手段,并结合之前的研究及应用——复现SCI文章 SHAP 依赖图可视化以增强机器学习模型的可解释性,展示如何通过在图中引入拟合曲线以及标注SHAP值为0时的交点,进一步提升对机器学习模型解释性的理解,本文的灵感主要来源于对《建成环境对街道活力的非线性影响和交互效应》这篇研究的解读与延展

picture.image

picture.image

拟合曲线的引入与优化

为了更好地揭示特征与目标变量之间的复杂非线性关系,在SHAP散点图中引入了LOWESS拟合曲线,这条曲线是通过局部加权回归法生成的,能够平滑数据点之间的变化,帮助直观地捕捉数据的趋势走向

SHAP值为0时的交点标注

在对SHAP解释图的进一步优化中,特别添加了SHAP值为0时拟合曲线的交点标注。这一点非常重要,因为SHAP值为0时意味着该特征在该点附近对模型的预测结果没有显著影响。标注这一交点,可以帮助识别出特征值在哪些区间对目标变量的影响是从无到有或从正到负的转变

代码实现

数据集加载


          
import pandas as pd
          
import numpy as np
          
import matplotlib.pyplot as plt
          
from sklearn.model_selection import train_test_split
          
plt.rcParams['font.family'] = 'Times New Roman'
          
plt.rcParams['axes.unicode_minus'] = False
          
import warnings
          
warnings.filterwarnings("ignore")
          
df = pd.read_csv('Dataset.csv')
          
# 划分特征和目标变量
          
X = df.drop(['target'], axis=1)
          
y = df['target']
          
# 划分训练集和测试集
          
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, 
          
                                                    random_state=42, stratify=df['target'])
          
df.head()
      

picture.image

首先,需要加载数据集并将其划分为特征 X 和目标变量 y,然后进行训练集和测试集的划分。目标变量是我们要预测的值,X 是输入的特征,这是一个分类任务,目标是预测患者是否患有心脏病。虽然是分类任务,但无论是分类问题还是回归问题,SHAP 依赖图的使用方式和原理是相同的,都可以用来解释模型中各个特征对预测结果的贡献

训练机器学习模型


          
from sklearn.ensemble import GradientBoostingClassifier
          
from sklearn.model_selection import GridSearchCV
          

          
# GBT模型参数
          
params_gbt = {
          
    'learning_rate': 0.02,            # 学习率,控制每一步的步长,用于防止过拟合。典型值范围:0.01 - 0.1
          
    'max_depth': 3,                   # 树的深度,控制模型复杂度
          
    'random_state': 42,               # 随机种子,用于重现模型的结果
          
    'subsample': 0.7,                 # 每次迭代时随机选择的样本比例,用于增加模型的泛化能力
          
}
          

          
# 初始化Gradient Boosting分类模型
          
model_gbt = GradientBoostingClassifier(**params_gbt)
          

          
# 定义参数网格,用于网格搜索
          
param_grid = {
          
    'n_estimators': [100, 200, 300],  # 树的数量
          
    'max_depth': [3, 4, 5],               # 树的深度
          
    'learning_rate': [0.01, 0.1],   # 学习率
          
}
          

          
# 使用GridSearchCV进行网格搜索和k折交叉验证
          
grid_search = GridSearchCV(
          
    estimator=model_gbt,
          
    param_grid=param_grid,
          
    scoring='neg_log_loss',  # 评价指标为负对数损失
          
    cv=5,                    # 5折交叉验证
          
    n_jobs=-1,               # 并行计算
          
    verbose=1                # 输出详细进度信息
          
)
          

          
# 训练模型
          
grid_search.fit(X_train, y_train)
          

          
# 使用最优参数训练模型
          
best_model = grid_search.best_estimator_
      

代码通过网格搜索 (GridSearchCV) 对Gradient Boosting分类模型的超参数进行优化,并通过5折交叉验证选择出最优的模型参数

shap值计算整理


          
import shap
          
explainer = shap.TreeExplainer(best_model)
          
# 计算shap值为numpy.array数组
          
shap_values_numpy = explainer.shap_values(X)
          

          
shap_values_df = pd.DataFrame(shap_values_numpy, columns=X.columns)
          
shap_values_df.head()
      

picture.image

计算模型的SHAP值,并将其转换为DataFrame格式,方便后续进行自定义绘图分析

基础绘图


          
# 绘制散点图,x轴是'age'特征,y轴是SHAP值
          
plt.figure(figsize=(6, 4),dpi=1200)
          
plt.scatter(df['age'], shap_values_df['age'], s=10)
          
# 添加shap=0的横线
          
plt.axhline(y=0, color='black', linestyle='-.', linewidth=1)
          
plt.xlabel('Age', fontsize=12)
          
plt.ylabel('SHAP value for\nAge', fontsize=12) 
          
ax = plt.gca()
          
ax.spines['top'].set_visible(False)
          
ax.spines['right'].set_visible(False)
          
plt.savefig("SHAP Dependence Plot_1.pdf", format='pdf',bbox_inches='tight')
          
plt.show()
      

picture.image

绘制了一个基础的SHAP依赖图,其中x轴代表特征“age”(年龄),y轴代表该特征的SHAP值,即年龄对模型预测的影响大小,散点图展示了不同年龄的SHAP值,黑色虚线表示SHAP值为0的基准线,表示在该点年龄对预测没有显著正负影响,此图帮助直观地理解特征“age”对模型预测结果的影响方向和程度,当然更具体的解释参考文章——复现SCI文章 SHAP 依赖图可视化以增强机器学习模型的可解释性

通过拟合曲线与交点标注绘图


          
import seaborn as sns
          
from scipy.optimize import fsolve
          

          
# 绘制散点图
          
plt.figure(figsize=(8, 6), dpi=300)
          
plt.scatter(df['age'], shap_values_df['age'], s=20, label='SHAP values', alpha=0.7)
          

          
# 添加LOWESS拟合曲线
          
sns.regplot(x=df['age'], y=shap_values_df['age'], scatter=False, lowess=True, color='lightcoral', label='LOWESS Curve')
          

          
# 使用 LOWESS 数据生成拟合曲线
          
lowess_data = sns.regplot(x=df['age'], y=shap_values_df['age'], scatter=False, lowess=True, color='lightcoral')
          
line = lowess_data.get_lines()[0]  # 拟合线条对象
          
x_fit = line.get_xdata()  # LOWESS 拟合线的 x 轴数据
          
y_fit = line.get_ydata()  # LOWESS 拟合线的 y 轴数据
          

          
# 找出所有与 y=0 相交的 x 值
          
def find_zero_crossings(x_fit, y_fit):
          
    crossings = []
          
    for i in range(1, len(y_fit)):
          
        if (y_fit[i-1] < 0 and y_fit[i] > 0) or (y_fit[i-1] > 0 and y_fit[i] < 0):
          
            # 使用插值法找到 x_fit 和 y_fit 中 y 值接近 0 的 x 值
          
            crossing = fsolve(lambda x: np.interp(x, x_fit, y_fit), x_fit[i])[0]
          
            crossings.append(crossing)
          
    return crossings
          

          
x_intercepts = find_zero_crossings(x_fit, y_fit)
          

          
# 在图中标注所有的 x_intercepts
          
for x_intercept in x_intercepts:
          
    plt.axvline(x=x_intercept, color='blue', linestyle='--', label=f'Intersection at Age = {x_intercept:.2f}')
          
    plt.text(x_intercept, 0.2, f'Age = {x_intercept:.2f}', color='blue', fontsize=10, verticalalignment='bottom')
          

          
# 添加shap=0的横线
          
plt.axhline(y=0, color='black', linestyle='-.', linewidth=1, label='SHAP = 0')
          

          
# 添加图例
          
plt.legend()
          

          
# 设置标签和标题
          
plt.xlabel('Age', fontsize=12)
          
plt.ylabel('SHAP value for\nAge', fontsize=12)
          
# 去除上和右的边框线
          
ax = plt.gca()
          
ax.spines['top'].set_visible(False)
          
ax.spines['right'].set_visible(False)
          
# 保存并显示图像
          
plt.savefig("SHAP Dependence Plot_with_Multiple_Intersections.pdf", format='pdf', bbox_inches='tight')
          
plt.show()
      

picture.image

在这幅图中,通过LOWESS拟合曲线和SHAP解释图来深入分析年龄(Age)对模型预测结果的影响。下面着重解释拟合线与交点的含义:

LOWESS拟合曲线

LOWESS曲线(红色曲线)是局部加权回归曲线,它用来平滑数据中的非线性趋势。在这幅图中,它表示了年龄对目标变量的平均影响趋势,从曲线中可以看出,随着年龄的变化,SHAP值也随之波动,通过这条拟合曲线,可以识别出不同年龄区间对模型预测的不同贡献

  • 在年龄较低时(约40岁以下),SHAP值为负,表示年龄对预测结果的负向影响较为明显
  • 之后,随着年龄的增加,SHAP值开始逐渐上升,到了大约53岁附近时,SHAP值变为正,说明该特征对模型开始有正向的影响
  • 过了约64岁之后,SHAP值再次呈现下降趋势,表明年龄对预测的正向影响逐渐减弱甚至变为负向

交点

图中用蓝色虚线标注了两个交点,分别表示SHAP值曲线与y=0的交点,这两个交点表示特定年龄时,SHAP值为零,即在这些点上,年龄对模型的影响由负向或正向逐渐转变

  • 第一个交点(Age = 53.92):这是年龄从负向影响转变为正向影响的点,当年龄大于53.92时,SHAP值开始为正,意味着年龄对模型的正向贡献逐渐增大
  • 第二个交点(Age = 64.35):这是年龄从正向影响转为负向影响的点,当年龄超过64.35时,SHAP值再次变为负,说明此时年龄对模型的预测影响逐渐减弱

通过拟合曲线和交点,可以更直观地理解特征“年龄”对模型预测结果的非线性影响,尤其是这些交点,它们揭示了年龄在特定区间中对目标变量的关键变化点,有助于理解模型如何处理年龄这个特征,以及如何做出更精准的解释

多特征SHAP依赖图:通过拟合曲线与交点分析模型特征贡献


          
# 定义绘制 SHAP 依赖图的函数
          
def plot_shap_dependence(feature_list, df, shap_values_df, file_name="SHAP_Dependence_Plots.pdf"):
          
    fig, axs = plt.subplots(2, 3, figsize=(12, 8), dpi=1200)
          
    plt.subplots_adjust(hspace=0.4, wspace=0.4)
          
    
          
    # 循环绘制每个特征的 SHAP 依赖图
          
    for i, feature in enumerate(feature_list):
          
        row = i // 3  # 行号
          
        col = i % 3   # 列号
          
        ax = axs[row, col]
          
        
          
        # 绘制散点图,x轴是特征值,y轴是SHAP值
          
        ax.scatter(df[feature], shap_values_df[feature], s=10, alpha=0.7)
          

          
        # 添加 LOWESS 拟合曲线
          
        sns.regplot(x=df[feature], y=shap_values_df[feature], scatter=False, lowess=True, color='lightcoral', ax=ax)
          

          
        # 使用 LOWESS 数据生成拟合曲线
          
        lowess_data = sns.regplot(x=df[feature], y=shap_values_df[feature], scatter=False, lowess=True, color='lightcoral', ax=ax)
          
        line = lowess_data.get_lines()[0]  # 拟合线条对象
          
        x_fit = line.get_xdata()  # LOWESS 拟合线的 x 轴数据
          
        y_fit = line.get_ydata()  # LOWESS 拟合线的 y 轴数据
          

          
        # 找出所有与 y=0 相交的 x 值
          
        def find_zero_crossings(x_fit, y_fit):
          
            crossings = []
          
            for i in range(1, len(y_fit)):
          
                if (y_fit[i-1] < 0 and y_fit[i] > 0) or (y_fit[i-1] > 0 and y_fit[i] < 0):
          
                    crossing = fsolve(lambda x: np.interp(x, x_fit, y_fit), x_fit[i])[0]
          
                    crossings.append(crossing)
          
            return crossings
          

          
        x_intercepts = find_zero_crossings(x_fit, y_fit)
          

          
        # 在图中标注所有的 x_intercepts
          
        for x_intercept in x_intercepts:
          
            ax.axvline(x=x_intercept, color='blue', linestyle='--')  # 标注虚线
          
            ax.text(x_intercept, 0.1, f'{x_intercept:.2f}', color='black', fontsize=10, verticalalignment='bottom')  # 将文本标注颜色改为淡红色
          

          
        # 添加shap=0的横线
          
        ax.axhline(y=0, color='black', linestyle='-.', linewidth=1)
          
        
          
        # 设置x和y轴标签
          
        ax.set_xlabel(feature, fontsize=12)
          
        ax.set_ylabel(f'SHAP value for\n{feature}', fontsize=12)
          
        
          
        # 隐藏顶部和右侧的脊柱
          
        ax.spines['top'].set_visible(False)
          
        ax.spines['right'].set_visible(False)
          
    
          
    # 隐藏最后一个空图表的坐标轴 (若画布未关闭)
          
    axs[1, 2].axis('off')
          
    
          
    # 保存为 PDF 文件
          
    plt.savefig(file_name, format='pdf', bbox_inches='tight')
          
    plt.show()
          

          
# 使用函数绘制 age、trestbps、chol、thalach、oldpeak 的 shap 依赖图
          
feature_list = ['age', 'trestbps', 'chol', 'thalach', 'oldpeak']
          
plot_shap_dependence(feature_list, df, shap_values_df)
      

picture.image

通过绘制多个特征的SHAP依赖图,结合LOWESS拟合曲线与交点标注,分析各特征对模型预测的影响,当然,也可以采用其他拟合曲线,而不仅限于LOWESS,这里主要是基于参考文献中所使用的LOWESS拟合曲线进行分析,这里的解释同前文Age解释原理相同

往期推荐

SCI图表复现:整合数据分布与相关系数的高级可视化策略

复现顶刊Streamlit部署预测模型APP

树模型系列:如何通过XGBoost提取特征贡献度

SHAP进阶解析:机器学习、深度学习模型解释保姆级教程

特征选择:Lasso和Boruta算法的结合应用

从基础到进阶:优化SHAP力图,让样本解读更直观

SCI图表复现:优化SHAP特征贡献图展示更多模型细节

多模型中的特征贡献度比较与可视化图解

基于SHAP值的 BorutaShap 算法在特征选择中的应用与优化

复现SCI文章 SHAP 依赖图可视化以增强机器学习模型的可解释性

用SHAP可视化解读数据特征的重要性:蜂巢图与特征关系图结合展示

picture.image

picture.image

picture.image

微信号|deep_ML

欢迎添加作者微信进入Python、ChatGPT群

进群请备注Python或AI进入相关群

无需科学上网、同步官网所有功能、使用无限制

如果你对类似于这样的文章感兴趣。

欢迎关注、点赞、转发~

个人观点,仅供参考

0
0
0
0
关于作者

文章

0

获赞

0

收藏

0

相关资源
字节跳动 XR 技术的探索与实践
火山引擎开发者社区技术大讲堂第二期邀请到了火山引擎 XR 技术负责人和火山引擎创作 CV 技术负责人,为大家分享字节跳动积累的前沿视觉技术及内外部的应用实践,揭秘现代炫酷的视觉效果背后的技术实现。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论