加快大模型训练、减少显存消耗的很多技巧

火山方舟小程序RTC

“ 今天文章发晚了,编辑了好久。。。

picture.image

summary
No.Method耗时 (s)显存 (GB)耗时节省比例 (%)显存节省比例 (%)误差
1hf QLoRA59416.71.0202
2Reduce data upcasting46515.521.7%7.2%1.0203
3Bitsandbytes 类型修正42415.38.9%1.3%1.0208
4pytorch自带的attention计算41814.91.4%2.6%1.0214
5pytorch自带的attention计算 causal = True38414.98.1%0.0%1.0219
6Xformers3539.18.1%38.9%1.021
7Flash Attention 23539.10.0%0.0%1.0215
8triton实现rope32697.6%1.1%1.0211
9triton实现layer norm31693.1%0.0%1.021
10triton实现交叉熵3157.40.4%17.8%1.021
11MLP lora反向传播优化3026.84.0%8.1%1.0222
12QKV lora反向传播优化2976.81.7%0.0%1.0217
  1. Reduce data upcasting

处理一些qlora训练里边,upcasting到float32的操作,可以节省7.2%的VRAM,让训练时间减少21.7%,不会带来误差。

  1. Bitsandbytes 类型修正

Bitsandbytes中使用float16,所以需要进行额外的内存复制,打个补丁,将其转换为bfloat16。处理这个问题,可以节省9%的时间。

  1. pytorch自带的attention计算


        
          
from torch.nn.functional import scaled_dot_product_attention  

      

可以使用Pytorch 本身的fast attention,可以节省1.4%的时间

4.5.6 Causal Masking, Xformers, Flash Attention 2

通过使用Causal Masking 传参,差不多快了8.1%;使用Xformers,快了8.1%,并且节省了39%的显存。使用Flash Attention v2没有明显的变化了,因为Xformers内部调用了Flash Attention v2。

7.triton实现ROPE

通过用triton实现RoPE位置编码,可以节省7.6%的训练时间。但是需要实现ROPE的导数。RoPE可以重写为旋转矩阵R和原始矩阵Q之间的矩阵乘法。如果改成这样,导数就是简单的R的转置。然后可以看到R的转置只是相乘的sin值的负值,所以RoPE的导数就是它本身,只是带了个负号。

  1. triton实现RMS Layernorm

RMS Layernorm的导数比较复杂。如上通过链式法则并推导会比较复杂。改成triton实现,可以降低3.1%的训练耗时。

  1. triton实现交叉熵

使用对数技巧,其中 x = exp(log(x)) 来推导导数。logsumexp函数的导数是softmax函数。可以使显存消耗减少17%。

10, 11. lora反向传播优化

注意括号!!!,可以通过合适的括号让Lora微调过程大幅减少实际计算FLOP数量!PyTorch 的自动微分会从末尾向开头反向传播。通过将多个操作融合为一个,并通过正确的链式矩阵乘法进行括号化,可以减少实际的FLOP数量。

Pytorch的自动微分,首先要计算X.T和dW的乘积。假设X的大小为(bsz, seq_len, d)。然后我们将X reshape成大小为(m, d)的张量,其中m为bsz * seq_len,d在7b模型中,为4096

dW的大小为(m, h),其中h是MLP中间层的大小。对于7b模型,它是11008,B.T是LoRA权重,大小为(h, r),其中r是LoRA矩阵的秩,一般是8~64。

如上,原生的flops 为 (h * d)(m + r)

如果用括号之后:

加速比例为

通常 r 很小,8~64。m 非常大,比如bs大小为 4,长度为 4096,那么 m = 4 * 4096 = 16384。这使得 (m + r) 几乎取决于m。所以近似得到:

可以约去m,相当于是

对于7b模型h = 11008、 d = 4096 、 r = 16, 可以得到比例为: 186.58

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
字节跳动大数据容器化构建与落地实践
随着字节跳动旗下业务的快速发展,数据急剧膨胀,原有的大数据架构在面临日趋复杂的业务需求时逐渐显现疲态。而伴随着大数据架构向云原生演进的行业趋势,字节跳动也对大数据体系进行了云原生改造。本次分享将详细介绍字节跳动大数据容器化的演进与实践。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论