Nature新算法:准确的小数据预测与表格基础模型TabPFN分类实现及其模型解释

机器学习算法数据库

picture.image

背景

picture.image

在上一篇推文中,介绍了该算法在回归模型上的一个实现——Nature新算法:准确的小数据预测与表格基础模型TabPFN回归实现及其模型解释,本次实现,将TabPFN模型应用于模拟的二分类不平衡数据集

这个二分类数据集具有明显的不平衡性,常规算法可能在此类数据上表现不佳,对比一下在默认参数下该模型与XGBoost模型的一个性能差异。接下来详细讲解如何使用TabPFN进行分类模型的训练与预测,并提供对模型的可解释性分析

代码实现

数据读取分析


          
import pandas as pd
          
import numpy as np
          
import matplotlib.pyplot as plt
          
plt.rcParams['font.family'] = 'Times New Roman'
          
plt.rcParams['axes.unicode_minus'] = False
          
import warnings
          
# 忽略所有警告
          
warnings.filterwarnings("ignore")
          
df = pd.read_excel('2025-2-14公众号Python机器学习AI.xlsx')
          

          
value_counts = df['y'].value_counts()
          
plt.figure(figsize=(6, 4), dpi=1200)
          
value_counts.plot(kind='bar', color=['skyblue', 'salmon'])
          
plt.title('Distribution of Survival vs Death', fontsize=16)
          
plt.xlabel('Survival Status', fontsize=12)
          
plt.ylabel('Count', fontsize=12)
          
plt.xticks([0, 1], ['Survival (0)', 'Death (1)'], rotation=0)
          
plt.savefig("6.pdf", format='pdf', bbox_inches='tight')
          
plt.tight_layout()
          
plt.show()
      

picture.image

读取数据集,绘制柱状图展示目标变量(y)中“生存(0)”与“死亡(1)”的分布情况。通过观察柱状图,可以发现数据样本存在不平衡,0类(生存)样本的数量大约是1类(死亡)样本的两倍

训练集、测试集划分


          
from sklearn.model_selection import train_test_split
          
# 划分特征和目标变量
          
X = df.drop(['y'], axis=1)
          
y = df['y']
          
# 划分训练集和测试集
          
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, 
          
                                                    random_state=42, stratify=df['y'])
      

将数据集划分为特征(X)和目标变量(y),并通过train_test_split函数将数据分为训练集和测试集,其中测试集占30%,并保持目标变量y的分布一致

XGBoost模型构建


          
import xgboost as xgb
          
# 初始化XGBoost分类模型
          
model_xgb = xgb.XGBClassifier()
          

          
# 训练模型
          
model_xgb.fit(X_train, y_train)
          

          
# 进行预测
          
y_pred = model_xgb.predict(X_test)
          

          
# 输出模型报告,查看评价指标
          
print(classification_report(y_test, y_pred))
      

picture.image

初始化一个XGBoost分类模型,使用训练集数据进行训练,然后在测试集上进行预测,并输出分类报告以查看模型的评价指标,从分类报告来看,在XGBoost模型的默认参数下,模型对于类别0(生存)的预测较好,但对于类别1(死亡)的召回率较低(0.56),显示出一定的过拟合现象,模型在训练集上表现较好,而在测试集上未能有效泛化


          
# 计算混淆矩阵
          
cm = confusion_matrix(y_test, y_pred)
          

          
# 绘制混淆矩阵的热力图
          
plt.figure(figsize=(8, 6),dpi=1200)
          
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=model_xgb.classes_, yticklabels=model_xgb.classes_)
          
plt.xlabel('Predicted Labels', fontsize=12, fontweight='bold')
          
plt.ylabel('True Labels', fontsize=12, fontweight='bold')
          
plt.title('Confusion Matrix for XGBoost Classifier', fontsize=14, fontweight='bold')
          
plt.savefig("3.pdf", format='pdf', bbox_inches='tight')
          
plt.show()
      

picture.image


          
# 计算训练集和测试集上的预测概率
          
y_train_pred_prob = model_xgb.predict_proba(X_train)[:, 1]  # 获取正类的概率
          
y_test_pred_prob = model_xgb.predict_proba(X_test)[:, 1]    # 获取正类的概率
          

          
# 计算ROC曲线
          
fpr_train, tpr_train, _ = roc_curve(y_train, y_train_pred_prob)
          
fpr_test, tpr_test, _ = roc_curve(y_test, y_test_pred_prob)
          

          
# 计算AUC
          
roc_auc_train = auc(fpr_train, tpr_train)
          
roc_auc_test = auc(fpr_test, tpr_test)
          

          
# 绘制ROC曲线
          
plt.figure(figsize=(8, 6),dpi=1200)
          
plt.plot(fpr_train, tpr_train, color='blue', label='Train ROC curve (AUC = %0.2f)' % roc_auc_train)
          
plt.plot(fpr_test, tpr_test, color='red', label='Test ROC curve (AUC = %0.2f)' % roc_auc_test)
          
plt.plot([0, 1], [0, 1], color='gray', linestyle='--')
          

          
# 设置加粗的字体
          
plt.xlabel('False Positive Rate', fontsize=12, fontweight='bold')
          
plt.ylabel('True Positive Rate', fontsize=12, fontweight='bold')
          
plt.title('ROC Curve for XGBoost Classifier', fontsize=14, fontweight='bold')
          
plt.legend(loc='lower right', prop={'weight': 'bold'})
          
plt.savefig("4.pdf", format='pdf', bbox_inches='tight')
          
plt.show()
      

picture.image

接下来也计算并绘制了XGBoost模型的混淆矩阵和ROC曲线,展示了模型在训练集和测试集上的表现。通过混淆矩阵可以看到XGBoost在测试集上的分类结果不均衡,特别是在类别1的预测上表现较差;在ROC曲线中,训练集的AUC值为1.00,而测试集的AUC为0.84,进一步表明模型在训练集上过拟合。接下来,将引入TabPFN分类模型,观察它在该任务上的表现,以期获得更好的泛化能力

TabPFN

模型构建


          
from tabpfn import TabPFNClassifier
          

          
# 初始化TabPFN分类模型
          
model = TabPFNClassifier()
          

          
# 训练模型
          
model.fit(X_train, y_train)
          

          
from sklearn.metrics import classification_report
          
# 预测测试集
          
y_pred = model.predict(X_test)
          

          
# 输出模型报告,查看评价指标
          
print(classification_report(y_test, y_pred))
      

picture.image

使用TabPFN分类模型对数据进行训练和测试,并输出分类报告。与XGBoost模型相比,TabPFN模型在默认参数下表现更好,分类报告中的准确率(0.83)和F1分数(类别1为0.76)均优于XGBoost。接下来,我们将通过混淆矩阵和ROC曲线进一步评估TabPFN模型的性能,并与XGBoost进行对比

picture.image

picture.image

通过混淆矩阵和ROC曲线可以发现,在默认参数下,TabPFN模型的表现明显优于XGBoost,并且相较于XGBoost,过拟合得到了显著的优化。需要注意的是,这里的混淆矩阵和ROC曲线并非直接通过库函数实现,而是经过了美化处理,以达到期刊配图的标准,完整 代码与数据集获取:如需获取本文的源代码和数据集,请添加作者微信联系

接下来,我们将利用LIME(局部可解释模型-agnostic解释器)对TabPFN模型进行单样本解释,LIME可以帮助揭示该黑盒模型在特定样本上的预测依据,从而增强模型的可解释性,具体的可以参考往期文章——期刊配图:模型可解释性工具LIME的实现及其优劣点

LIME模型解释


          
from lime.lime_tabular import LimeTabularExplainer
          
# 初始化 LIME 解释器
          
explainer = LimeTabularExplainer(
          
    training_data=np.array(X_train),         # 训练数据
          
    feature_names=X.columns.tolist(),        # 特征名称
          
    class_names=['survival', 'death'],              # 分类标签(根据实际情况修改)
          
    mode='classification'                    # 模式为分类
          
)
      

初始化LIME解释器,用于对TabPFN分类模型进行解释,传入训练数据、特征名称和分类标签,设置模式为分类任务


          
# 从测试集中选取一个样本
          
test_instance = X_test.iloc[0]
      

从测试集中选取第一个样本,供LIME解释器使用进行单样本解释


          
# 生成样本解释
          
exp = explainer.explain_instance(
          
    data_row=test_instance,                  # 测试样本数据
          
    predict_fn=model.predict_proba           # 使用模型的预测概率方法
          
)  # num_features=? 通过设置 num_features 参数来控制解释时最多考虑多少个特征
      

使用LIME解释器生成对选定测试样本的解释,通过调用模型的预测概率方法,并可以通过num_features参数设置在解释时最多考虑的特征数量,如果没有指定num_features参数,LIME会自动选择并考虑最重要的特征进行解释


          
# 显示解释
          
exp.show_in_notebook(show_table=True)        # 在 Notebook 中显示解释表
          
exp.save_to_file("lime_explanation.html")    # 保存解释到 HTML 文件
      

picture.image

使用LIME解释器在Jupyter Notebook中显示模型对单个样本的解释结果,并将该解释保存为HTML文件。结果展示了各个特征的阈值条件和对应的特征值,同时显示了“生存”和“死亡”预测概率以及影响预测的关键特征


          
exp_data = exp.as_list()
          
lime_df = pd.DataFrame(exp_data, columns=['Feature', 'Contribution'])
          

          
# 按贡献值排序
          
lime_df = lime_df.sort_values(by='Contribution', ascending=False)
          

          
# 绘制条形图,正贡献为绿色,负贡献为红色
          
plt.barh(lime_df['Feature'], lime_df['Contribution'], color=lime_df['Contribution'].apply(lambda x: 'green' if x > 0 else 'red'))
          
plt.xlabel("Contribution", fontsize=14, fontweight='bold')
          
plt.title("Local Explanation for Class Death", fontsize=16, fontweight='bold')
          
plt.xticks(fontsize=12, fontweight='bold')
          
plt.yticks(fontsize=12, fontweight='bold')
          
plt.savefig("5.pdf", format='pdf', bbox_inches='tight',dpi=1200)
          
plt.show()
      

picture.image

将LIME的解释结果转换为DataFrame,并按特征的贡献值排序,绘制了一个条形图,展示了每个特征对“死亡”类别预测的贡献,正贡献用绿色表示,负贡献用红色表示

通过LIME解释图和条形图可以看到,针对该测试样本,特征X_9 <= 0.00对预测“死亡”类别的贡献最大,其负贡献较大(-0.26),而X_39的特征值较高(482.00)对“死亡”类别有正贡献(0.05)。其他特征如X_38 > 25.00和X_30 <= 12.00等都对模型的输出产生了一定影响,但贡献较小,整体来看,模型将“死亡”类的预测归因于特定的特征阈值条件,这些特征值和预测概率被图表直观展示出来

往期推荐

期刊配图:RFE结合随机森林与K折交叉验证的特征筛选可视化

期刊配图:变量重要性排序与顺序正向选择的特征筛选可视化

期刊配图:SHAP可视化改进依赖图+拟合线+边缘密度+分组对比

期刊配图:SHAP蜂巢图与柱状图多维组合解读特征对模型的影响

期刊配图:分类模型对比训练集与测试集评价指标的可视化分析

期刊配图:回归模型对比如何精美可视化训练集与测试集的评价指标

期刊配图:如何同时可视化多个回归模型在训练集与测试集上的预测效果

期刊配图:SHAP可视化进阶蜂巢图与特征重要性环形图的联合展示方法

期刊配图:基于t-sne降维与模型预测概率的分类效果可视化

期刊配图:多种机器学习算法在递归特征筛选中的性能变化图示

picture.image

如果你对类似于这样的文章感兴趣。

欢迎关注、点赞、转发~

个人观点,仅供参考

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论