转载自:Coggle数据科学
关注我们,一起学习
在表格数据领域,梯度提升决策树(GBDT)如 XGBoost、LightGBM 和 CatBoost 长期占据统治地位。尽管近年来涌现出许多复杂的表格深度学习模型(如基于 Attention 的架构),但它们往往存在两个缺陷:
- 训练极慢 :相比 GBDT 耗时成倍增加。
- 极度依赖调参 :在默认参数下表现通常弱于 GBDT。
本文的研究目标是:通过元学习(Meta-learning)的思想,寻找在不同数据集上都能表现稳健的“万能”默认参数,并证明一个优化得当的简单 MLP(即 RealMLP)在效率和精度上均能与 GBDT 匹敌。
https://arxiv.org/pdf/2407.04491
Better by Default: Strong Pre-Tuned MLPs and Boosted Trees on Tabular Data
https://github.com/dholzmueller/pytabkit
作者认为,MLP 在表格数据上的弱势并非架构本身的问题,而是缺乏针对表格特性的“全套配套方案”。RealMLP 引入了一系列“Tricks”组合:
- 数据预处理(Preprocessing) :
- 鲁棒缩放与平滑裁剪(Smooth Clipping)
- 数值嵌入(Numerical Embeddings)
测试标准设置
研究者并未将所有数据集混为一谈,而是根据用途将其划分为三个独立的基准池:
- 元训练集 :包含 118 个来自 UCI 库的中等规模数据集。这是“调参”的战场,用于寻找那组“万能”的默认参数。
- 元测试集 :包含 90 个来自 AutoML 和 OpenML-CTR23 的数据集。这些数据在特征维度、类别数量和样本量上更为“极端”,用于测试默认参数的 泛化能力(Out-of-distribution) 。
- Grinsztajn 标杆 :这是一个严格筛选的基准,专门用于运行那些计算昂贵的 Baseline 模型(限制在 10k 样本以内),确保对比的权威性。
如何衡量一个模型在几十个任务上的综合表现?简单的算术平均值往往会被个别极难的任务(误差很大)带偏。作者提出使用 平移几何平均误差(Shifted Geometric Mean Error, ) 。
其背后的逻辑极具深度:
- 相对改进优于绝对改进 :在机器学习中,将错误率从 2% 降到 1% 的难度和价值,通常远高于从 42% 降到 41%。几何平均值能更好地捕捉这种百分比层面的优化。
- 数值稳定性 :为了防止某些数据集出现 0 错误率导致几何平均值变为 0,作者引入了平移项 。
RealMLP:深度学习表格模型
RealMLP 的成功并非源于某种复杂的注意力机制,而是源于对数据预处理、架构设计、初始化和训练策略的极致优化。作者将其称为一组“Tricks”的有机结合。
表格数据最棘手的问题之一是离群点(Outliers) 。
- 鲁棒缩放(Robust Scaling)
:RealMLP 放弃了易受极端值影响的标准差缩放,转而使用基于分位数的
RobustScaler逻辑。
- 平滑裁剪(Smooth Clipping) :引入了一个非线性函数 。
- 深度洞察 :传统的硬裁剪(Hard Clipping)会导致梯度消失,而这种平滑裁剪在输入超过 时会优雅地饱和。这既保留了梯度流,又限制了离群点对神经元的过度冲击。
RealMLP 在标准 MLP 基础上增加了几个关键组件,使其具备了类似于特征工程的能力:
- PBLD 数值嵌入 : 传统的神经网络直接输入标量,而 RealMLP 使用了 周期性偏差线性 DenseNet (PBLD) 嵌入。
- 可学习缩放层(Scaling Layer) : 在第一层线性层之前,引入了一个对角权重矩阵。这相当于一个 自动特征选择器 ,模型可以自主学习给哪些特征“加权”,给哪些“关小音量”。
- 神经切线参数化(NTP) : 为了防止特征维度过大导致梯度爆炸,RealMLP 采用了 NTP 参数化,根据输入维度自动调整学习率权重。
- 参数化激活函数 : 模型使用了类似 PReLU 的变体:。
通过图 1(c) 的消融实验,我们可以清晰地看到每一步改进带来的增益。最显著的性能跃升来自:
- 鲁棒缩放与平滑裁剪 (解决了数据质量问题)。
- 数值嵌入与缩放层 (增强了特征表达)。
- 周期性学习率调度 (优化了收敛质量)。
与GBDT对比的实验
尽管 XGBoost、LightGBM 和 CatBoost 都有各自的官方默认参数,但这些参数往往并非针对所有场景最优。作者通过元学习(Meta-learning)方法,为这三大神器重新定义了“预调优默认值(TD)”。
- 样本采样(Subsampling)是关键 :在所有调优后的默认参数中,行采样(Row Subsampling)被证明几乎总是有效的,而列采样则较少被使用。
- 任务敏感性 :回归任务通常需要比分类任务更深的树结构。
- CatBoost 的速度优化 :Bernoulli Bootstrap 在保持竞争力的同时,比传统的 Bayesian 方式更快。
作者在包含 90 个数据集的元测试集(Meta-test)以及著名的 Grinsztajn 榜单上进行了全方位对比。
实验表明,RealMLP-TD 在中大型数据集上的表现不仅超越了大多数神经网络(如 ResNet, FT-Transformer),而且在分类和回归任务的 SGM 误差指标上,足以与经过调优的 GBDT 家族并驾齐驱。
将 RealMLP 的预处理和嵌入技巧应用到 TabR (一种基于检索的表格模型)后,产生的 RealTabR-D 在回归任务中表现异常出色,甚至在多个基准测试中夺魁。
这是本论文最具工程指导意义的发现。作者对比了两种策略:
- 单算法 HPO :针对一个算法(如 XGBoost)进行 50 步超参数搜索。
- Best-TD / Ensemble-TD :直接运行 XGB、LGBM、CatBoost 和 RealMLP 的默认调优版本,然后选最好的或进行集成。
结果显示: 使用 Best-TD (即从几个强力默认模型中选最优)的训练速度比单算法 HPO 更快 ,且最终准确率更高或相当 。
交流群:点击“联系作者”--备注“研究方向-公司或学校”
欢迎|论文宣传|合作交流
往期推荐
微软 RecAI:使用LLM助力生成式推荐(多项工作均已开源)
图片
长按关注,更多精彩
图片
点个在看你最好看
