点击下方
卡片
,关注“
慢慢学AIGC
”
前言
2 月 21 日,DeepSeek 在 X 上预告了未来一周即将开源 5 个 repo,分享他们在 DeepSeek 模型开发过程中积累的一些构建模块。
今天(2 月 24 日,周一),FlashMLA 如约而至。
接下来我们一探究竟。
编译 & 安装
Github:https://github.com/deepseek-ai/FlashMLA
代码结构很简单,一个 Python 文件 flash_mla_interface.py 为接口,具体内容在 csrc 目录下由 C++/CUDA 实现。
编译和安装步骤只需要一步:
python setup.py install
编译成功后可以看到版本信息为
flash-mla--1.0.0+414a2f3
上述输出会随着社区迭代而发生变化。需要复现本文结果可以保持相同版本。
性能评测
在 PyTorch 中使用 FlashMLA 比较简单:
from flash_mla import get_mla_metadata, flash_mla_with_kvcache
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
for i in range(num_layers):
...
o_i, lse_i = flash_mla_with_kvcache(
q_i, kvcache_i, block_table, cache_seqlens, dv,
tile_scheduler_metadata, num_splits, causal=True,
)
...
在 H100 上运行代码中的 benchmark,结果如下:
# python tests/test_flash_mla.py
b=128, s_q=1, mean_sk=4096, h_q=16, h_kv=1, d=576, dv=512, causal=True, varlen=False
0.215 ms, 85 TFLOPS, 2832 GB/s
b=128, s_q=1, mean_sk=4096, h_q=16, h_kv=1, d=576, dv=512, causal=True, varlen=True
0.230 ms, 80 TFLOPS, 2662 GB/s
b=128, s_q=2, mean_sk=4096, h_q=16, h_kv=1, d=576, dv=512, causal=True, varlen=False
0.218 ms, 168 TFLOPS, 2815 GB/s
b=128, s_q=2, mean_sk=4096, h_q=16, h_kv=1, d=576, dv=512, causal=True, varlen=True
0.241 ms, 157 TFLOPS, 2632 GB/s
b=128, s_q=1, mean_sk=4096, h_q=32, h_kv=1, d=576, dv=512, causal=True, varlen=False
0.217 ms, 168 TFLOPS, 2819 GB/s
b=128, s_q=1, mean_sk=4096, h_q=32, h_kv=1, d=576, dv=512, causal=True, varlen=True
0.246 ms, 156 TFLOPS, 2620 GB/s
b=128, s_q=2, mean_sk=4096, h_q=32, h_kv=1, d=576, dv=512, causal=True, varlen=False
0.223 ms, 328 TFLOPS, 2794 GB/s
b=128, s_q=2, mean_sk=4096, h_q=32, h_kv=1, d=576, dv=512, causal=True, varlen=True
0.248 ms, 293 TFLOPS, 2493 GB/s
b=128, s_q=1, mean_sk=4096, h_q=64, h_kv=1, d=576, dv=512, causal=True, varlen=False
0.223 ms, 328 TFLOPS, 2794 GB/s
b=128, s_q=1, mean_sk=4096, h_q=64, h_kv=1, d=576, dv=512, causal=True, varlen=True
0.245 ms, 292 TFLOPS, 2486 GB/s
b=128, s_q=2, mean_sk=4096, h_q=64, h_kv=1, d=576, dv=512, causal=True, varlen=False
0.282 ms, 518 TFLOPS, 2268 GB/s
b=128, s_q=2, mean_sk=4096, h_q=64, h_kv=1, d=576, dv=512, causal=True, varlen=True
0.304 ms, 500 TFLOPS, 2187 GB/s
b=128, s_q=1, mean_sk=4096, h_q=128, h_kv=1, d=576, dv=512, causal=True, varlen=False
0.285 ms, 512 TFLOPS, 2245 GB/s
b=128, s_q=1, mean_sk=4096, h_q=128, h_kv=1, d=576, dv=512, causal=True, varlen=True
0.304 ms, 486 TFLOPS, 2129 GB/s
b=128, s_q=2, mean_sk=4096, h_q=128, h_kv=1, d=576, dv=512, causal=True, varlen=False
0.535 ms, 546 TFLOPS, 1261 GB/s
b=128, s_q=2, mean_sk=4096, h_q=128, h_kv=1, d=576, dv=512, causal=True, varlen=True
0.537 ms, 561 TFLOPS, 1293 GB/s
b=128, s_q=1, mean_sk=8192, h_q=16, h_kv=1, d=576, dv=512, causal=True, varlen=False
0.405 ms, 90 TFLOPS, 2995 GB/s
b=128, s_q=1, mean_sk=8192, h_q=16, h_kv=1, d=576, dv=512, causal=True, varlen=True
0.418 ms, 87 TFLOPS, 2903 GB/s
b=128, s_q=2, mean_sk=8192, h_q=16, h_kv=1, d=576, dv=512, causal=True, varlen=False
0.408 ms, 179 TFLOPS, 2985 GB/s
b=128, s_q=2, mean_sk=8192, h_q=16, h_kv=1, d=576, dv=512, causal=True, varlen=True
0.455 ms, 172 TFLOPS, 2863 GB/s
b=128, s_q=1, mean_sk=8192, h_q=32, h_kv=1, d=576, dv=512, causal=True, varlen=False
0.408 ms, 179 TFLOPS, 2984 GB/s
b=128, s_q=1, mean_sk=8192, h_q=32, h_kv=1, d=576, dv=512, causal=True, varlen=True
0.410 ms, 170 TFLOPS, 2833 GB/s
b=128, s_q=2, mean_sk=8192, h_q=32, h_kv=1, d=576, dv=512, causal=True, varlen=False
0.413 ms, 354 TFLOPS, 2971 GB/s
b=128, s_q=2, mean_sk=8192, h_q=32, h_kv=1, d=576, dv=512, causal=True, varlen=True
0.425 ms, 327 TFLOPS, 2750 GB/s
b=128, s_q=1, mean_sk=8192, h_q=64, h_kv=1, d=576, dv=512, causal=True, varlen=False
0.413 ms, 354 TFLOPS, 2970 GB/s
b=128, s_q=1, mean_sk=8192, h_q=64, h_kv=1, d=576, dv=512, causal=True, varlen=True
0.428 ms, 327 TFLOPS, 2745 GB/s
b=128, s_q=2, mean_sk=8192, h_q=64, h_kv=1, d=576, dv=512, causal=True, varlen=False
0.526 ms, 555 TFLOPS, 2363 GB/s
b=128, s_q=2, mean_sk=8192, h_q=64, h_kv=1, d=576, dv=512, causal=True, varlen=True
0.561 ms, 536 TFLOPS, 2282 GB/s
b=128, s_q=1, mean_sk=8192, h_q=128, h_kv=1, d=576, dv=512, causal=True, varlen=False
0.540 ms, 541 TFLOPS, 2303 GB/s
b=128, s_q=1, mean_sk=8192, h_q=128, h_kv=1, d=576, dv=512, causal=True, varlen=True
0.568 ms, 550 TFLOPS, 2337 GB/s
b=128, s_q=2, mean_sk=8192, h_q=128, h_kv=1, d=576, dv=512, causal=True, varlen=False
0.999 ms, 584 TFLOPS, 1280 GB/s
b=128, s_q=2, mean_sk=8192, h_q=128, h_kv=1, d=576, dv=512, causal=True, varlen=True
0.984 ms, 582 TFLOPS, 1277 GB/s
上面结果可以看出,FlashMLA 最高性能达到近 3000GB/s(访存密集型算子) 和 580+ TFLOPS(计算密集型算子)。
查阅 H100 Datasheet,可以发现实测结果距离硬件峰值算力 1979 TFLOPS 还有差距,但已经很接近峰值访存带宽 3350 GB/s。
另外,FlashMLA 目前提供的代码只支持 bf16,而 V3 论文中的 FP8 Kernel 实现没有随着 FlashMLA 一起发布,期待接下来几天的精彩内容。
MLA 架构细节
在 DeepSeek V2 中首次提出 MLA(多头潜注意力)。MLA 通过将键值(KV)缓存显著压缩为潜在向量,确保了高效的推理。通过配备低秩键值联合压缩,MLA 实现了比 MHA 更好的性能,但需要显著更少的 KV 缓存。
论文链接:https://arxiv.org/pdf/2405.04434
DeepSeek V3/R1 模型和 V2、V2.5 的 Attention 模块均基于 MLA, FlashMLA 的性能优化对这些模型都能同样适用。
扫描下方
二维码
,关注“
慢慢学AIGC
”