Meta 发表关于multi-token prediction的突破性新论文,以实现更好、更快的 LLM

弹性计算MySQL机器学习

欢迎关注我的公众号“ NLP前沿 ”,日更最新论文/博客速读,周更AI领域近一周发生的那些事儿 。欢迎投稿! 行文仓促,有理解错误,欢迎指正

“ 前几天的文章,放假鸽了几天;看到这个文章,第一印象应该就是把美杜莎推理加速的活用在了预训练上,但是文章中很多实验,不妨细看一下。


        
          
https://arxiv.org/pdf/2404.19737  

      

原理篇类似于medusa decode,如下图,相比于常规的next token prediction任务,多了几个头,在解码的时候可以只取第一个,就回归到常规的clm的推理方式,取多个头,那就是medusa的加速方式。因为clm训练的attention mask本身就是下三角的,不管预测几个都不会出现信息泄露,所以整体上代码改动也比较小。

picture.image

picture.image

实现上有个小技巧,因为头的shape是(hidden_size, vocab_size),如果直接走所有头预测,然后损失加起来反向传播。这个显存占用比较大,会降低显存利用率,论文中特地提到可以通过头遍历,一个头一个头来,如下图:

picture.image

最后看看一些实验对比结果:

  1. 6个尺寸的模型,模型越小,越比不上常规next-token-prediction训练出来的模型,但是模型越大,越受益multi-token-predictionpicture.image
  2. 使用7b模型,头的数量n测试1,2,4,6,8,在mbpp,humaneval,n=4始终优于baseline,但是具体的n的取值,很可能与数据集的分布有关。如果同一份数据训练多个epoch,仍然有一定的优势,但是n=4已经不是最佳的头数了。

picture.image

  1. 在code任务上,使用multi-token prediction训练的模型,微调可以使用ntp任务,或者mtp任务,都能获得比baseline更好的效果;但是对于自然语言选择题的几个基准,使用mtp就优势不大了;对于摘要任务,使用mtp,n=2,n=4都能获得更好的结果picture.image

picture.image

一些延申的实验

  1. mtp在低参数量下,更容易学习到组合关系,比如某2个token AB,在上文出现,后文出现A,mtp训练的模型在30M以下更容易出现,预测B的能力。picture.image
  2. mtp可以提高模型的推理能力以及泛化性能picture.image
  3. 为什么mtp有效?如下图一个序列,可能5 -> A 比较难,其他1,2,3,4,5,与 a,b 的ntp任务都很简单,mtp可以从3开始就预测a,4预测a,b,可以促使模型在5->A做出更正确的决策

picture.image

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论