背景
在机器学习模型的解释中,SHAP是一种重要的工具,能够帮助深入了解各个特征对模型预测的贡献,对于连续变量,SHAP的可视化通常可以采用依赖图、主效应图或交互效应图,如图C所示,这些方式能够清晰展示变量的趋势和模型的响应,然而,对于分类变量的SHAP值,这种方法往往无法直观展现其分布和影响,为了解决这一问题,本文将介绍一种基于箱图的可视化方式,如图D所示,这种方法能够有效表达分类变量的SHAP值分布特点,为模型解释和特征分析提供新视角
代码实现
模型构建
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['axes.unicode_minus'] = False
import warnings
warnings.filterwarnings("ignore")
df = pd.read_excel('2024-12-10公众号Python机器学习AI.xlsx')
# 划分特征和目标变量
X = df.drop(['y'], axis=1)
y = df['y']
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
random_state=42, stratify=df['y'])
from xgboost import XGBClassifier
# 使用 XGBoost 建模
model_xgb = XGBClassifier(n_estimators=100,
max_depth=3,
use_label_encoder=False,
random_state=8,
eval_metric='logloss') # 需要指定 eval_metric,避免警告
model_xgb.fit(X_train, y_train)
加载包含连续变量和多分类变量的数据集,通过划分训练集和测试集后,使用XGBoost建立一个基础的分类模型,用于预测目标变量(二分类)
shap值计算
import shap
explainer = shap.TreeExplainer(model_xgb)
shap_values = explainer.shap_values(X)
shap_values_df = pd.DataFrame(shap_values, columns=X.columns)
shap_values_df.head()
使用SHAP解释器对XGBoost模型进行解释,计算每个特征对预测的SHAP值,并将结果保存为DataFrame格式,以便进一步分析和可视化,这里针对的是整体数据集进行shap分析,一般只针对测试集即可,这里是作者测试集样本较少,所以对整体进行计算,可视化更美观
分类变量SHAP值的箱线图与散点图可视化
import seaborn as sns
x2_categories = X['X_2'].reset_index(drop=True) # 分类变量
shap_x2_values = shap_values_df['X_2'].reset_index(drop=True) # SHAP 值
plot_data = pd.DataFrame({
'Category': x2_categories,
'SHAP Value': shap_x2_values
})
plt.figure(figsize=(8, 6))
sns.boxplot(data=plot_data, x='Category', y='SHAP Value', showfliers=False, color='lightblue')
sns.stripplot(data=plot_data, x='Category', y='SHAP Value', color='lightgray', jitter=False, size=4)
plt.axhline(y=0, color='gray', linestyle='--', linewidth=1)
plt.title("X_2")
plt.ylabel("SHAP Value")
plt.xlabel("")
plt.savefig('1.pdf', format='pdf', bbox_inches='tight', dpi=1200)
plt.show()
将分类变量X_2的值和对应的SHAP值进行可视化,X_2是一个二分类变量(类别为0和1),通过箱线图和散点图结合的方式,展示分类变量X_2的SHAP值分布,用以观察不同类别对模型预测的影响。X轴表示X_2的两个类别(0和1),Y轴表示其SHAP值,反映该特征对预测结果的正负贡献。箱线图显示每个类别的SHAP值分布范围和中位数,散点图则展示每个样本的具体值及其密度。虚线(y=0)帮助判断SHAP值的正负影响。通过对比发现,类别0的SHAP值多为负且分布较广,而类别1的SHAP值集中在正向区域,说明分类变量X_2的类别0对预测结果主要起负向作用且影响波动较大,而类别1对预测结果主要起稳定的正向作用整体上,图形直观揭示了X_2对模型预测的重要性及其类别差异
分类变量SHAP值的分布及其目标变量的颜色标记可视化
在前面基础上增加了目标变量y的颜色标记,通过映射不同目标类别(0为紫色,1为黄色),将散点图按目标类别分组显示,用于明确颜色与目标类别的对应关系。这种改动使得图表不仅展示分类变量X_2的SHAP值分布,同时展示不同目标类别在SHAP值上的分布特征,增强图表的可读性和分析价值,也就是文献中的图D,代码与数据集获取:如需获取本文的源代码和数据集,请添加作者微信联系
分类变量X_13的SHAP值分布与样本量标注可视化
colors = ['#7D2E8E', '#00AEE9', '#F8766D']
plot_data = pd.DataFrame({
'Category': X['X_13'], # 替换为 X_13 类别
'SHAP Value': shap_values_df['X_13'] # 替换为 X_13 对应的 SHAP 值
})
plt.figure(figsize=(8, 6), dpi=1200)
sns.boxplot(
x='Category', y='SHAP Value', data=plot_data, width=0.4, palette=colors,
boxprops=dict(facecolor='white', edgecolor='black', linewidth=1.2), # 白色填充
medianprops=dict(color='black', linewidth=1.5), # 设置中位数线样式
showfliers=False # 隐藏离群值
)
# 绘制带有分组颜色的散点图
sns.stripplot(x='Category', y='SHAP Value', data=plot_data, jitter=True, palette=colors, size=4, alpha=0.8)
# 计算每个类别的样本量
sample_counts = plot_data['Category'].value_counts().sort_index()
for i, (label, count) in enumerate(sample_counts.items()):
plt.text(i, plot_data['SHAP Value'].min() - 0.7, f'n = {count}', ha='center', va='center', fontsize=10, style='italic')
# 添加 x=0 的灰色竖线
plt.axhline(y=0, color='gray', linestyle='--', linewidth=1)
plt.xlabel('')
plt.ylabel('SHAP Value')
plt.title('X_13')
plt.savefig('3.pdf', format='pdf', bbox_inches='tight')
plt.show()
与前面图一的区别在于,这里针对三分类变量X_13绘制SHAP值分布图,而不是二分类变量,新增三种颜色来区分不同类别的箱线图和散点图,并在图中标注每个类别的样本数量,此外,箱线图的样式更为细致,包括白色填充和黑色边框,以及标注中位数的黑色线条。这种改进更直观地展示三分类变量的SHAP值分布特征及其样本量差异,适用于多分类变量的分析场景
分类变量X_13的SHAP值分布及组间显著性分析
图4在图3的基础上增加了Mann-Whitney U检验和显著性标注,直观展示分类变量X_13不同类别之间SHAP值分布差异的统计显著性,为模型解释提供了更深入的统计支撑,如果p值很小(如图中p = 7.45e-46),说明该类别之间的SHAP值分布差异是显著的,意味着模型对于这些类别的预测影响有明确的区别,例如类别2对模型的正向贡献显著高于类别0和1,代码与数据集获取:如需获取本文的源代码和数据集,请添加作者微信联系
往期推荐
复现SCI文章 SHAP 依赖图可视化以增强机器学习模型的可解释性
复现 Nature 图表——基于PCA的高维数据降维与可视化实践及其扩展
复现Nature图表——基于PCA降维与模型预测概率的分类效果可视化
如果你对类似于这样的文章感兴趣。
欢迎关注、点赞、转发~
个人观点,仅供参考