“ 今天文章发晚了,编辑了好久。。。
No. | Method | 耗时 (s) | 显存 (GB) | 耗时节省比例 (%) | 显存节省比例 (%) | 误差 |
---|---|---|---|---|---|---|
1 | hf QLoRA | 594 | 16.7 | 1.0202 | ||
2 | Reduce data upcasting | 465 | 15.5 | 21.7% | 7.2% | 1.0203 |
3 | Bitsandbytes 类型修正 | 424 | 15.3 | 8.9% | 1.3% | 1.0208 |
4 | pytorch自带的attention计算 | 418 | 14.9 | 1.4% | 2.6% | 1.0214 |
5 | pytorch自带的attention计算 causal = True | 384 | 14.9 | 8.1% | 0.0% | 1.0219 |
6 | Xformers | 353 | 9.1 | 8.1% | 38.9% | 1.021 |
7 | Flash Attention 2 | 353 | 9.1 | 0.0% | 0.0% | 1.0215 |
8 | triton实现rope | 326 | 9 | 7.6% | 1.1% | 1.0211 |
9 | triton实现layer norm | 316 | 9 | 3.1% | 0.0% | 1.021 |
10 | triton实现交叉熵 | 315 | 7.4 | 0.4% | 17.8% | 1.021 |
11 | MLP lora反向传播优化 | 302 | 6.8 | 4.0% | 8.1% | 1.0222 |
12 | QKV lora反向传播优化 | 297 | 6.8 | 1.7% | 0.0% | 1.0217 |
- Reduce data upcasting
处理一些qlora训练里边,upcasting到float32的操作,可以节省7.2%的VRAM,让训练时间减少21.7%,不会带来误差。
- Bitsandbytes 类型修正
Bitsandbytes中使用float16,所以需要进行额外的内存复制,打个补丁,将其转换为bfloat16。处理这个问题,可以节省9%的时间。
- pytorch自带的attention计算
from torch.nn.functional import scaled_dot_product_attention
可以使用Pytorch 本身的fast attention,可以节省1.4%的时间
4.5.6 Causal Masking, Xformers, Flash Attention 2
通过使用Causal Masking 传参,差不多快了8.1%;使用Xformers,快了8.1%,并且节省了39%的显存。使用Flash Attention v2没有明显的变化了,因为Xformers内部调用了Flash Attention v2。
7.triton实现ROPE
通过用triton实现RoPE位置编码,可以节省7.6%的训练时间。但是需要实现ROPE的导数。RoPE可以重写为旋转矩阵R和原始矩阵Q之间的矩阵乘法。如果改成这样,导数就是简单的R的转置。然后可以看到R的转置只是相乘的sin值的负值,所以RoPE的导数就是它本身,只是带了个负号。
- triton实现RMS Layernorm
RMS Layernorm的导数比较复杂。如上通过链式法则并推导会比较复杂。改成triton实现,可以降低3.1%的训练耗时。
- triton实现交叉熵
使用对数技巧,其中 x = exp(log(x)) 来推导导数。logsumexp函数的导数是softmax函数。可以使显存消耗减少17%。
10, 11. lora反向传播优化
注意括号!!!,可以通过合适的括号让Lora微调过程大幅减少实际计算FLOP数量!PyTorch 的自动微分会从末尾向开头反向传播。通过将多个操作融合为一个,并通过正确的链式矩阵乘法进行括号化,可以减少实际的FLOP数量。
Pytorch的自动微分,首先要计算X.T和dW的乘积。假设X的大小为(bsz, seq_len, d)。然后我们将X reshape成大小为(m, d)的张量,其中m为bsz * seq_len,d在7b模型中,为4096
dW的大小为(m, h),其中h是MLP中间层的大小。对于7b模型,它是11008,B.T是LoRA权重,大小为(h, r),其中r是LoRA矩阵的秩,一般是8~64。
如上,原生的flops 为 (h * d)(m + r)
如果用括号之后:
加速比例为
通常 r 很小,8~64。m 非常大,比如bs大小为 4,长度为 4096,那么 m = 4 * 4096 = 16384。这使得 (m + r) 几乎取决于m。所以近似得到:
可以约去m,相当于是
对于7b模型h = 11008、 d = 4096 、 r = 16, 可以得到比例为: 186.58