Nature新算法:准确的小数据预测与表格基础模型TabPFN回归实现及其模型解释

大模型机器学习数据库

picture.image

背景

picture.image

表格数据,作为一种广泛存在于各行各业的数据类型,始终在数据科学中占据着重要地位。从医学、金融到气候变化、电子商务,表格数据无处不在。然而,尽管深度学习在图像、语音等领域取得了巨大成功,但在表格数据的应用上,传统的机器学习方法(如决策树、随机森林、梯度提升等)依然是主流。这些方法虽然在许多场景中表现优秀,但它们在面对小规模数据集时常常力不从心,尤其是在需要大量超参数调优和特征工程的情况下

在这种背景下,2025年1月发表在《Nature》上的一篇名为《Accurate predictions on small data with a tabular foundation model》的研究,为我们带来了突破性的进展。研究提出了 TabPFN (Tabular Prior-data Fitted Network),一个基于变压器(Transformer)的表格数据基础模型。TabPFN通过创新性的“上下文学习”(ICL)机制,能够在仅需数秒钟的时间内,对最多10,000个样本、500个特征的表格数据进行预测,并且无需像传统方法那样进行大量的超参数调优

picture.image

TabPFN的最大亮点在于其训练方式:通过在数百万个合成数据集上进行预训练,TabPFN学会了处理数据缺失、噪声和不重要特征等常见挑战,从而能有效应对各种现实世界中的复杂任务。与传统方法相比,TabPFN不仅能在较少的数据上取得更高的准确度,还能显著提高预测速度,在表格数据建模方面展现出了令人惊叹的潜力,通过这一创新技术,TabPFN不仅为表格数据的处理提供了更加高效的解决方案,也为科学研究和决策制定领域带来了新的机遇

在实际应用 TabPFN 模型时,首先我们需要借助官方提供的基础库和教程进行设置和操作。你可以参考以下链接获取相关资料:

  • Prior Labs 回归任务教程:

https://priorlabs.ai/getting\_started/install/#\_\_tabbed\_1\_3 这是一个详细的回归任务的实现指南,帮助你了解如何使用 TabPFN 进行回归任务

  • TabPFN GitHub 仓库:

https://github.com/PriorLabs/TabPFN?tab=readme-ov-file 提供了完整的代码实现和详细的文档说明,适合用来快速启动并部署模型

接下来,我们将在自己的数据集上实际运用 TabPFN 模型,进行回归任务的训练与预测并进行模型解释。在这个过程中,我们将加载自己的表格数据,并将其输入到 TabPFNRegressor 中进行模型训练和预测

代码实现

数据读取


          
import pandas as pd
          
import numpy as np
          
import matplotlib.pyplot as plt
          
plt.rcParams['font.family'] = 'Times New Roman'
          
plt.rcParams['axes.unicode_minus'] = False
          
import warnings
          
# 忽略所有警告
          
warnings.filterwarnings("ignore")
          
df = pd.read_excel('2025-2-12公众号Python机器学习AI.xlsx')
      

导入必要的库,读取一个名为 '2025-2-12公众号Python机器学习AI.xlsx' 的 Excel 文件,并设置了图表的字体和忽略警告,以便后续的数据处理和可视化操作

数据相关性

picture.image

对我们的数据进行一个初步的探索,绘制一个相关性矩阵可视化,最终观察到SR(目标变量)与其他变量的相关性情况,且数据不存在严重的多重共线性,一般利用相关性系数判断多重共线性时,若自变量之间的相关系数绝对值大于0.8或0.9(这是一个阈值并没有一定固定),则可能存在严重的多重共线性,完整 代码与数据集获取:如需获取本文的源代码和数据集,请添加作者微信联系

模型构建


          
from sklearn.model_selection import train_test_split, KFold
          
X = df.drop(['SR'],axis=1)
          
y = df['SR']
          
# 划分训练集和测试集
          
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, 
          
                                                    random_state=42)
          

          
from tabpfn import TabPFNRegressor
          
# 初始化TabPFN回归模型
          
model = TabPFNRegressor()
          

          
# 训练模型
          
model.fit(X_train, y_train)
      

使用默认参数构建 TabPFN 回归模型 ,首先对数据进行训练集和测试集的划分(SR 作为目标变量),然后初始化并训练 TabPFNRegressor 模型,以默认超参数进行拟合

TabPFNRegressor 模型提供多个可配置的参数,以适应不同的数据集和计算环境,以下是几个关键参数

  • model_path=None:默认使用内置的预训练模型,而不指定自定义路径
  • device="cpu":默认使用 CPU 进行计算,如需 GPU 加速,可设置为 "cuda"
  • inference_precision="float32":默认推理精度为 float32,可调整为 float16 以减少计算量
  • normalize_y=True:默认对目标变量 y 进行归一化,以提高模型稳定性
  • seed=0:默认随机种子为 0,以确保结果的可复现性

默认参数保证了模型的稳定性和兼容性,同时可以根据需求进行调整,以优化性能或适配不同的计算资源,这也是它的一个优点,可以发现它并不像树模型一样存在很多参数(如树深度、树数量、分裂节点等),也就是文献提到的 无需像传统方法那样进行大量的超参数调优

模型性能评价


          
from sklearn import metrics
          

          
# 预测
          
y_pred_train = model.predict(X_train)
          
y_pred_test = model.predict(X_test)
          

          
y_pred_train_list = y_pred_train.tolist()
          
y_pred_test_list = y_pred_test.tolist()
          

          
# 计算训练集的指标
          
mse_train = metrics.mean_squared_error(y_train, y_pred_train_list)
          
rmse_train = np.sqrt(mse_train)
          
mae_train = metrics.mean_absolute_error(y_train, y_pred_train_list)
          
r2_train = metrics.r2_score(y_train, y_pred_train_list)
          

          
# 计算测试集的指标
          
mse_test = metrics.mean_squared_error(y_test, y_pred_test_list)
          
rmse_test = np.sqrt(mse_test)
          
mae_test = metrics.mean_absolute_error(y_test, y_pred_test_list)
          
r2_test = metrics.r2_score(y_test, y_pred_test_list)
          

          
print("训练集评价指标:")
          
print("均方误差 (MSE):", mse_train)
          
print("均方根误差 (RMSE):", rmse_train)
          
print("平均绝对误差 (MAE):", mae_train)
          
print("拟合优度 (R-squared):", r2_train)
          
print(f'-------------------------')
          
print("\n测试集评价指标:")
          
print("均方误差 (MSE):", mse_test)
          
print("均方根误差 (RMSE):", rmse_test)
          
print("平均绝对误差 (MAE):", mae_test)
          
print("拟合优度 (R-squared):", r2_test)
      

picture.image

picture.image

计算并输出了 TabPFN 回归模型训练集测试集 上的回归性能指标,包括 均方误差(MSE)均方根误差(RMSE)平均绝对误差(MAE)拟合优度(R²) 。从结果来看,训练集和测试集的 R² 分数分别为 0.941 和 0.917 ,说明模型在训练和测试集上均具有较高的拟合能力,同时 误差(MSE、RMSE、MAE)较小 ,表明模型的预测误差在可接受范围内,因此可以判断 模型性能表现良好,泛化能力较强,完整 代码与数据集获取:如需获取本文的源代码和数据集,请添加作者微信联系

模型解释


          
from sklearn.inspection import partial_dependence
          
from scipy.interpolate import splrep, splev
          

          
def plot_partial_dependence(model, X, features, grid_resolution=50, sample_size=100):
          
    """
          
    绘制一个或多个特征的偏依赖图(Partial Dependence Plot)。
          

          
    参数:
          
    - model: 训练好的机器学习模型
          
    - X: 特征数据 (DataFrame)
          
    - features: 要绘制的特征,可以是一个字符串(单个特征)或列表(多个特征)
          
    - grid_resolution: 网格分辨率(默认50)
          
    - sample_size: Rugplot中用于展示分布的样本数(默认100)
          
    """
          

          
plot_partial_dependence(model, X_test, 'infP')
      

picture.image

调用自定义的 plot_partial_dependence 函数,使用 TabPFN 回归模型 和测试数据集 X_test ,绘制特征 infP 的偏依赖图(PDP)和个体条件期望图(ICE),以展示 infP 对模型预测的影响

TabPFN 回归模型 中,特征 infP 的影响可以通过 偏依赖图(PDP)个体条件期望图(ICE) 来进行解释。偏依赖图展示了 infP 的变化如何影响模型的平均预测结果,这为我们提供了该特征对目标变量的全局影响。而个体条件期望图则揭示了每个样本在不同 infP 值下的预测变化,帮助我们理解模型在单个样本级别上的反应。通过这两种方法,我们能够更深入地了解 infP 对回归模型输出的作用,进一步提升模型的可解释性。这种解释能够帮助我们识别 infP 与目标变量之间的关系,是否存在线性或非线性影响,以及该特征对预测结果的重要性,当然这里的绘图方法是期刊配图形式把PDP和ICE在一个画布上进行了展示,参考往期文章——期刊配图:模型解释PDP可视化进阶置信区间+拟合曲线

从这张偏依赖图(PDP)可以看出,特征 infP 对模型预测结果的影响表现出一种非线性趋势。随着 infP 值的增加,模型的预测值( Partial Dependence )逐渐下降,表明 infP 的增大可能会导致目标变量的预测值减少,但是最后又有一定的平稳趋势,具体解释如下:

  • 图中展示的曲线平滑且略微呈下降趋势至平稳,说明 infP 的变化对模型输出产生了影响,但影响的程度随着 infP 的取值变化逐渐减弱
  • 上下灰色阴影区域表示个体条件期望(ICE)曲线的变化范围,可以看到在不同 infP 值下,个体预测的波动范围较小,这表明大部分样本的预测结果都遵循类似的模式
  • Rugplot 也显示了 infP 的分布情况, Rugplot 是位于图像底部的一条水平线条,它通过小短线展示了数据在 infP 特征上的分布。每个短线代表一个样本的特征值,它们沿着 infP 的轴分布,显示了样本在该特征上的取值

感兴趣的读者还可以自行尝试其他解释方法,如 SHAP 值LIME ,对该模型进行进一步的解释和分析,完整 代码与数据集获取:如需获取本文的源代码和数据集,请添加作者微信联系

往期推荐

SHAP值+模型预测概率解读机器学习模型的决策过程

聚类与解释的结合:利用K-Means聚类辅助SHAP模型解释并可视化

期刊配图:RFE结合随机森林与K折交叉验证的特征筛选可视化

期刊配图:变量重要性排序与顺序正向选择的特征筛选可视化

期刊配图:SHAP可视化改进依赖图+拟合线+边缘密度+分组对比

期刊配图:SHAP蜂巢图与柱状图多维组合解读特征对模型的影响

期刊配图:分类模型对比训练集与测试集评价指标的可视化分析

期刊配图:回归模型对比如何精美可视化训练集与测试集的评价指标

期刊配图:如何同时可视化多个回归模型在训练集与测试集上的预测效果

期刊配图:SHAP可视化进阶蜂巢图与特征重要性环形图的联合展示方法

期刊配图:基于t-sne降维与模型预测概率的分类效果可视化

期刊配图:多种机器学习算法在递归特征筛选中的性能变化图示

picture.image

如果你对类似于这样的文章感兴趣。

欢迎关注、点赞、转发~

个人观点,仅供参考

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论