背景
在上一篇推文中,介绍了该算法在回归模型上的一个实现——Nature新算法:准确的小数据预测与表格基础模型TabPFN回归实现及其模型解释,本次实现,将TabPFN模型应用于模拟的二分类不平衡数据集
这个二分类数据集具有明显的不平衡性,常规算法可能在此类数据上表现不佳,对比一下在默认参数下该模型与XGBoost模型的一个性能差异。接下来详细讲解如何使用TabPFN进行分类模型的训练与预测,并提供对模型的可解释性分析
代码实现
数据读取分析
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['axes.unicode_minus'] = False
import warnings
# 忽略所有警告
warnings.filterwarnings("ignore")
df = pd.read_excel('2025-2-14公众号Python机器学习AI.xlsx')
value_counts = df['y'].value_counts()
plt.figure(figsize=(6, 4), dpi=1200)
value_counts.plot(kind='bar', color=['skyblue', 'salmon'])
plt.title('Distribution of Survival vs Death', fontsize=16)
plt.xlabel('Survival Status', fontsize=12)
plt.ylabel('Count', fontsize=12)
plt.xticks([0, 1], ['Survival (0)', 'Death (1)'], rotation=0)
plt.savefig("6.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()
读取数据集,绘制柱状图展示目标变量(y)中“生存(0)”与“死亡(1)”的分布情况。通过观察柱状图,可以发现数据样本存在不平衡,0类(生存)样本的数量大约是1类(死亡)样本的两倍
训练集、测试集划分
from sklearn.model_selection import train_test_split
# 划分特征和目标变量
X = df.drop(['y'], axis=1)
y = df['y']
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,
random_state=42, stratify=df['y'])
将数据集划分为特征(X)和目标变量(y),并通过train_test_split函数将数据分为训练集和测试集,其中测试集占30%,并保持目标变量y的分布一致
XGBoost模型构建
import xgboost as xgb
# 初始化XGBoost分类模型
model_xgb = xgb.XGBClassifier()
# 训练模型
model_xgb.fit(X_train, y_train)
# 进行预测
y_pred = model_xgb.predict(X_test)
# 输出模型报告,查看评价指标
print(classification_report(y_test, y_pred))
初始化一个XGBoost分类模型,使用训练集数据进行训练,然后在测试集上进行预测,并输出分类报告以查看模型的评价指标,从分类报告来看,在XGBoost模型的默认参数下,模型对于类别0(生存)的预测较好,但对于类别1(死亡)的召回率较低(0.56),显示出一定的过拟合现象,模型在训练集上表现较好,而在测试集上未能有效泛化
# 计算混淆矩阵
cm = confusion_matrix(y_test, y_pred)
# 绘制混淆矩阵的热力图
plt.figure(figsize=(8, 6),dpi=1200)
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=model_xgb.classes_, yticklabels=model_xgb.classes_)
plt.xlabel('Predicted Labels', fontsize=12, fontweight='bold')
plt.ylabel('True Labels', fontsize=12, fontweight='bold')
plt.title('Confusion Matrix for XGBoost Classifier', fontsize=14, fontweight='bold')
plt.savefig("3.pdf", format='pdf', bbox_inches='tight')
plt.show()
# 计算训练集和测试集上的预测概率
y_train_pred_prob = model_xgb.predict_proba(X_train)[:, 1] # 获取正类的概率
y_test_pred_prob = model_xgb.predict_proba(X_test)[:, 1] # 获取正类的概率
# 计算ROC曲线
fpr_train, tpr_train, _ = roc_curve(y_train, y_train_pred_prob)
fpr_test, tpr_test, _ = roc_curve(y_test, y_test_pred_prob)
# 计算AUC
roc_auc_train = auc(fpr_train, tpr_train)
roc_auc_test = auc(fpr_test, tpr_test)
# 绘制ROC曲线
plt.figure(figsize=(8, 6),dpi=1200)
plt.plot(fpr_train, tpr_train, color='blue', label='Train ROC curve (AUC = %0.2f)' % roc_auc_train)
plt.plot(fpr_test, tpr_test, color='red', label='Test ROC curve (AUC = %0.2f)' % roc_auc_test)
plt.plot([0, 1], [0, 1], color='gray', linestyle='--')
# 设置加粗的字体
plt.xlabel('False Positive Rate', fontsize=12, fontweight='bold')
plt.ylabel('True Positive Rate', fontsize=12, fontweight='bold')
plt.title('ROC Curve for XGBoost Classifier', fontsize=14, fontweight='bold')
plt.legend(loc='lower right', prop={'weight': 'bold'})
plt.savefig("4.pdf", format='pdf', bbox_inches='tight')
plt.show()
接下来也计算并绘制了XGBoost模型的混淆矩阵和ROC曲线,展示了模型在训练集和测试集上的表现。通过混淆矩阵可以看到XGBoost在测试集上的分类结果不均衡,特别是在类别1的预测上表现较差;在ROC曲线中,训练集的AUC值为1.00,而测试集的AUC为0.84,进一步表明模型在训练集上过拟合。接下来,将引入TabPFN分类模型,观察它在该任务上的表现,以期获得更好的泛化能力
TabPFN
模型构建
from tabpfn import TabPFNClassifier
# 初始化TabPFN分类模型
model = TabPFNClassifier()
# 训练模型
model.fit(X_train, y_train)
from sklearn.metrics import classification_report
# 预测测试集
y_pred = model.predict(X_test)
# 输出模型报告,查看评价指标
print(classification_report(y_test, y_pred))
使用TabPFN分类模型对数据进行训练和测试,并输出分类报告。与XGBoost模型相比,TabPFN模型在默认参数下表现更好,分类报告中的准确率(0.83)和F1分数(类别1为0.76)均优于XGBoost。接下来,我们将通过混淆矩阵和ROC曲线进一步评估TabPFN模型的性能,并与XGBoost进行对比
通过混淆矩阵和ROC曲线可以发现,在默认参数下,TabPFN模型的表现明显优于XGBoost,并且相较于XGBoost,过拟合得到了显著的优化。需要注意的是,这里的混淆矩阵和ROC曲线并非直接通过库函数实现,而是经过了美化处理,以达到期刊配图的标准,完整 代码与数据集获取:如需获取本文的源代码和数据集,请添加作者微信联系
接下来,我们将利用LIME(局部可解释模型-agnostic解释器)对TabPFN模型进行单样本解释,LIME可以帮助揭示该黑盒模型在特定样本上的预测依据,从而增强模型的可解释性,具体的可以参考往期文章——期刊配图:模型可解释性工具LIME的实现及其优劣点
LIME模型解释
from lime.lime_tabular import LimeTabularExplainer
# 初始化 LIME 解释器
explainer = LimeTabularExplainer(
training_data=np.array(X_train), # 训练数据
feature_names=X.columns.tolist(), # 特征名称
class_names=['survival', 'death'], # 分类标签(根据实际情况修改)
mode='classification' # 模式为分类
)
初始化LIME解释器,用于对TabPFN分类模型进行解释,传入训练数据、特征名称和分类标签,设置模式为分类任务
# 从测试集中选取一个样本
test_instance = X_test.iloc[0]
从测试集中选取第一个样本,供LIME解释器使用进行单样本解释
# 生成样本解释
exp = explainer.explain_instance(
data_row=test_instance, # 测试样本数据
predict_fn=model.predict_proba # 使用模型的预测概率方法
) # num_features=? 通过设置 num_features 参数来控制解释时最多考虑多少个特征
使用LIME解释器生成对选定测试样本的解释,通过调用模型的预测概率方法,并可以通过num_features参数设置在解释时最多考虑的特征数量,如果没有指定num_features参数,LIME会自动选择并考虑最重要的特征进行解释
# 显示解释
exp.show_in_notebook(show_table=True) # 在 Notebook 中显示解释表
exp.save_to_file("lime_explanation.html") # 保存解释到 HTML 文件
使用LIME解释器在Jupyter Notebook中显示模型对单个样本的解释结果,并将该解释保存为HTML文件。结果展示了各个特征的阈值条件和对应的特征值,同时显示了“生存”和“死亡”预测概率以及影响预测的关键特征
exp_data = exp.as_list()
lime_df = pd.DataFrame(exp_data, columns=['Feature', 'Contribution'])
# 按贡献值排序
lime_df = lime_df.sort_values(by='Contribution', ascending=False)
# 绘制条形图,正贡献为绿色,负贡献为红色
plt.barh(lime_df['Feature'], lime_df['Contribution'], color=lime_df['Contribution'].apply(lambda x: 'green' if x > 0 else 'red'))
plt.xlabel("Contribution", fontsize=14, fontweight='bold')
plt.title("Local Explanation for Class Death", fontsize=16, fontweight='bold')
plt.xticks(fontsize=12, fontweight='bold')
plt.yticks(fontsize=12, fontweight='bold')
plt.savefig("5.pdf", format='pdf', bbox_inches='tight',dpi=1200)
plt.show()
将LIME的解释结果转换为DataFrame,并按特征的贡献值排序,绘制了一个条形图,展示了每个特征对“死亡”类别预测的贡献,正贡献用绿色表示,负贡献用红色表示
通过LIME解释图和条形图可以看到,针对该测试样本,特征X_9 <= 0.00对预测“死亡”类别的贡献最大,其负贡献较大(-0.26),而X_39的特征值较高(482.00)对“死亡”类别有正贡献(0.05)。其他特征如X_38 > 25.00和X_30 <= 12.00等都对模型的输出产生了一定影响,但贡献较小,整体来看,模型将“死亡”类的预测归因于特定的特征阈值条件,这些特征值和预测概率被图表直观展示出来
往期推荐
期刊配图:SHAP可视化改进依赖图+拟合线+边缘密度+分组对比
期刊配图:SHAP蜂巢图与柱状图多维组合解读特征对模型的影响
期刊配图:回归模型对比如何精美可视化训练集与测试集的评价指标
期刊配图:如何同时可视化多个回归模型在训练集与测试集上的预测效果
期刊配图:SHAP可视化进阶蜂巢图与特征重要性环形图的联合展示方法
如果你对类似于这样的文章感兴趣。
欢迎关注、点赞、转发~
个人观点,仅供参考