正文
172 1024.0 128.0 1.0 21.088000 56.848001 32.000002
173 1024.0 128.0 2.0 21.984000 55.551998 35.583999
174 1024.0 128.0 4.0 23.024000 55.808000 42.367999
175 1024.0 128.0 8.0 24.992000 56.127999 54.111999
176 1024.0 256.0 1.0 25.072001 52.480001 66.431999
177 1024.0 256.0 2.0 25.264001 52.576002 67.199998
178 1024.0 256.0 4.0 26.848000 52.416001 70.703998
179 1024.0 256.0 8.0 29.120000 57.663999 81.055999
这个也是我入坑SGLang开源的头几个贡献。
0x2.2 biased_grouped_topk 的 fuse kernel 优化
我们在SGLang中针对biased_grouped_topk(https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/moe/topk.py#L144)引入了一个fuse cuda kernel的优化,将十多个算子fuse成一个cuda kernel大大改进了grouped topk这块的性能。之前也写了一篇blog介绍这个优化:
图解DeepSeek V3 biased_grouped_topk cuda融合算子fused_moe_gate kernel
,感兴趣可以查看,这里就不多赘述了。
kernel对应的PR见:https://github.com/sgl-project/sglang/pull/4530
DeepSeek V3/R1 推理性能端到端提升是5%-8%左右。
0x2.3 将Shared Experts和Route Experts融合
具体细节可以参考我之前写的这篇blog,这个优化花了我挺多时间去测试的,也是一个比较solid提升,
分享一个DeepSeek V3和R1中 Shared Experts和普通Experts融合的一个小技巧
。下面是端到端的性能提升图:
0x2.4 Triton Fused MoE Retuning
升级PyTorch Triton版本之后社区小伙伴发现重新Tuning fused MoE kernel之后可以带来明显的性能提升,例如在Triton 3.2.0中如果继续延用Triton 3.1.0 tuning的config,则性能反而会下降,但如果重新Tuning则可以取得相比于Triton 3.1.0更好的性能。https://github.com/sgl-project/sglang/pull/5716 & https://github.com/sgl-project/sglang/pull/5740 ,通过重新Tuning Fused MoE kernel,在DeepSeek V3/R1上取得了性能提升。
我还测试了一下Triton 3.2.0升级为Triton 3.3.0的提升:
基本也是符合这里的Retuning结论的。然后 https://github.com/vllm-project/vllm/pull/17934#issuecomment-2868822690 这里的一个 micro benchmark 性能测试也佐证了这一点。
0x2.5 Fuse routed scaling factor in topk_reduce kernel
把expert计算完之后最后乘以routed_scaling_factor的逻辑fuse到Fused MoE模块的最后那个topk_reduce_sum kernel中,具体可见:https://github.com/sgl-project/sglang/pull/6220 ,端到端提升如下:
0x2.6 一些额外的探索
在SGLang中也探索了一下TP模式下的基于Cutlass Grouped GEMM和DEEPGEMM的fused moe kernel实现,并跑通了正确性测试和性能测试,但是在TP8模式下相比于DeepSeek V3/R1的Triton Fused MoE kernel没有看到性能提升效果,这里就不详细介绍了。
0x3. Attention Backend的优化
0x3.1 Flash Attention V3 Backend
来自LinkedIn的优化,详情可见
在 SGLang 中实现 Flash Attention 后端 - 基础和 KV 缓存
,端到端吞吐提升结果如下所示: