关注我们,一起学习
标题:Beyond KAN: Introducing KarSein for Adaptive High-Order Feature Interaction Modeling in CTR Prediction
论文链接:https://arxiv.org/pdf/2408.08713
代码:https://github.com/Ancientshi/KarSein
学校:悉尼科技大学
1 引言
传统的特征交互方法通常根据先验知识预定义最大交互阶数,这限制模型学习的充分性,且高阶的特征交互也会增加计算开销。为实现自适应的高阶特征交互且保留模型学的充分性,本文提出 KarSein (Kolmogorov-Arnold Represented Sparse Efficient Interaction Network)结构,主要贡献为:
1.探索KAN在CTR预估中的应用,发现其局限性并提出KarSein采用引导符号回归来解决KAN网络在自发学习乘法关系方面的挑战
2.KarSein保持强大的全局解释性,同时能够去除冗余特征,从而形成稀疏的网络结构
2 KAN用于CTR预估
2.1 乘法关系学习
本文发现 Kolmogorov-Arnold 表示定理特别适合表示与乘法关系相结合的特征交互,给定两个基础域的特征和,衍生的二阶特征交互可以被表示为:
其中, , ,,,,由此可看出,KAN可以通过符号回归对特征交互进行建模,与DNN相比具有更高的可解释性和效率。本文发现KAN的符号回归能力对对结构初始化和正则化设置很敏感。在更一般的场景中,KAN 很难自发地执行正确的符号回归来学习高阶乘法特征交互。这说明KAN 与 DNN 类似,在自主学习乘性特征交互方面的能力本质上仍然有限,因此,仅靠 KAN 的自发学习可能无法产生最佳的 CTR 预测结果。
2.2 原生KAN用于CTR
使用结构为 64−64−1 的原生KAN网络在MovieLens-1M数据集上训练,得到AUC为0.8273,而DNN模型AUC为0.8403。这表明原生的KAN用于CTR预估会得到次优的结果,不过作者进一步发现,在优化后的 KAN 中许多网络连接贡献很小,使用网络简化技术,KAN 网络从初始的 64−64−1 个神经元修剪为 1−1− 1 ,这表明每个特征一个激活函数可能就足够了,从而消除了在同一特征上使用多个激活函数的成本,这就是 KAN 中的参数冗余
3 KARSEIN
3.1 网络结构
如上图,KarSein 模型由多个堆叠的 KarSein 交互层构成。该层以一组嵌入向量作为输入,并生成高阶特征交互。特征交互由嵌入向量的三个步骤组成:可选的成对乘法、可学习的激活变换、线性组合。
首先描述下第L层的KarSein交互层,其输入的维度为输出的维度为。对于第L层,令表示输入矩阵,表示输出矩阵,其中第0层输入且表示由嵌入向量堆叠成的输入,同理最后一层输出因为最终输出神经元为1
可选的成对乘法 :本文的方法中,在和之间进行特征交互,此过程涉及计算成对特征的哈达玛积,然后将计算结果和连接得到新的矩阵来替代作为后续步骤的输入。通常仅在前两个交互层中使用这个可选的成对乘法步骤就足以指导模型的学习来结合乘法关系。由网络结构图可知,成对乘法生成二阶特征,通过高阶激活和线性变换,这些特征演变为更复杂的多元高阶交互如三变量的6阶交互。虽然f6不能直接捕获所有三变量交互,但堆叠额外的 KarSein 交互层可以有效解决这个问题。例如,在第二个 KarSein 层中,经过进一步激活,产生更复杂的交互,如它涵盖了所有三个变量之间的特征交互,这样模型就可以成功学习更丰富的乘法关系。
激活变换 :将网格大小为g、阶数为k的B样条曲线的基函数表示为,对于第L层的输入矩阵首先激活每一行得到,定义可学习的权重矩阵, 接着通过激活变换得到
线性组合 :定义权重矩阵,为了建模特征交互,对激活的嵌入向量进行线性组合并表示为,为增强模型的表现力引入额外的残差连接。对X使用SILU(·)激活函数并定义另一个权重矩阵对激活的嵌入执行线性变换,因此最终的输出形式为
集成隐式交互 :该交互侧重于按位级特征交互,如图所示采用并行网络架构,将向量交互和按位交互的建模分开,两个网络共享相同的嵌入层。
3.2 CTR预估
KarSein结构的显示特征交互输出为,隐式特征交互输出为,将其组合即得到最终的网络输出
其中Wo表示回归参数。训练过程中,KAN 网络通过对激活函数的参数应用 L1 正则化、对激活值应用熵正则化来表现出稀疏性。本文模型继承这一特性并提高了效率。具体地,没有对激活函数的参数应用 L1 正则化、对中间输入输出特征的激活后值应用熵正则化,而是将 L1 和熵正则化结合到 KarSein 交互层的线性组合步骤中,以消除冗余的隐藏神经元。
对于第L层的KarSein交互层的两个权重参数应用L1正则,接着计算熵正则,形式为
其中H代表计算熵,整体的训练目标为
4 实验结果
交流群:点击“联系 作者”--备注“研究方向-公司或学校”
欢迎|论文宣传|合作交流
往期推荐
腾讯 | MTMT: 促进用户增长的多干预多任务uplift模型
CIKM2024 | LightGODE: 挑战传统图推荐范式, 基于轻量级图ODE推荐算法
长按关注,更多精彩
点个在看你最好看