专栏名称: GiantPandaLLM
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaLLM  ·  【博客转载】CUDA Coalesced ... ·  昨天  
GiantPandaLLM  ·  【博客转载】C++/CUDA Data ... ·  2 天前  
GiantPandaLLM  ·  【博客转载】CUDA Kernel ... ·  3 天前  
51好读  ›  专栏  ›  GiantPandaLLM

[vLLM实践][算子] vLLM算子开发流程: "保姆级"详细记录

GiantPandaLLM  · 公众号  · 3D  · 2025-06-03 21:58

正文

请到「今天看啥」查看全文


= prefix_output * p_scale + suffix_output * s_scale
return output , output_lse

0x03 Triton 基础算子

PyTorch实现的版本,当然性能是很低的,因为使用了很多的小op,以及对于Tensor进行了inplace的写操作。因此,vLLM中并不是直接使用PyTorch的实现,而是提供了一个基于Triton实现的kernel。完整代码链接:attention/ops/triton_merge_attn_states.py ( https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/triton_merge_attn_states.py) 。具体如下:

  • 数据load及inf处理

数据load及inf处理
  • safe-softmax:减去最大值

safe-softmax
  • 最后校准 :计算prefix_output和suffix_output各自的scale值,然后求两者的加权和作为最后的输出。

校准

我们看到Triton kernel做的事情和PyTorch实现的一样的,但是将所有的操作都fused到一个kernel中,online判断inf值(寄存器)而不是修改global memory中的值,性能一般来说会更高。这个kernel的调用逻辑如下:

Triton kernel的调用

vLLM里边的实现,给merge_attn_states_kernel,分配(num_tokens, num_query_heads)个thread block,每个block处理当前head的所有值,比如head_size=128,则这个block处理128个值。

0x04 Triton 算子分析

  • 基本分析

上小节我们知道,vLLM里边的实现,给merge_attn_states_kernel,分配(num_tokens, num_query_heads)个thread block,每个block处理当前head的所有值,比如head_size=128,则这个block处理128个值。但是,这样做,会出现一些问题。(1)当num_tokens、num_query_heads很大,而head_size很小(比如32)时,就会导致thread block数过大,每个block处理的数据量又过少,计算密度很小。而且,这种情况下,Triton也不一定能生成高效的kernel(下文会讲到);(2)Triton kernel在调用时会有一定CPU的overhead。

may have CPU overhead
  • Gen code(PTX)分析

这里记录一下一个简单有效的分析Triton kernel的方法(当然ncu,nsys用上就更好了)。通常,我们也想知道,到底Triton实际上生成了啥kernel,比如说,生成的kernel PTX是怎么样的,有没有用上向量化,有没有cp.async,合并访存到底做好了没有。这个时候,我们可以指定TRITON_CACHE_DIR环境变量,把Triton生成的中间IR文件给保存下来,进行分析。

exportTRITON_CACHE_DIR=$(pwd)/cache
pytest -s test_merge_attn_states.py# Triton生成的中间IR cache文件cache git:(dev) ✗ tree .
.
├── ALGAAi8N-ErdaDbXXL8N91RokvTI-e8O2oEwd0SL3N0
│   └── __triton_launcher.so
├── p4IOvvpWkyeVkuyW8j50rO-ANYlCc5AJOEr70sQD93A
│   ├── __grp__merge_attn_states_kernel.json
│   ├── merge_attn_states_kernel.cubin
│   ├── merge_attn_states_kernel.json
│   ├── merge_attn_states_kernel.llir
│   ├── merge_attn_states_kernel.ptx
│   ├── merge_attn_states_kernel.ttgir
│   └── merge_attn_states_kernel.ttir
└── q4oIpkjOtdHHfi8xBkm4jC4JWIk5AjKtN8WRkZb8MD8
    └── cuda_utils.so

这里边,我们主要关注merge_attn_states_kernel.ptx这个PTX文件就可以了。比如,对于当num_tokens=512和num_query_heads=16,head_size=32,生成的PTX部分如下:

        @%p8ld.global.b16{%rs3},[%rd16+0]; // 非向量化load
//......
@%p8ld.global.b16{%rs4},[%rd17+0];
//endinlineasm
.loc18530                         //triton_merge_attn_states.py:85:30
div.full.f32%r15,%r16,%r17;
//......
mov.b32%f49,%r15;
.loc18630                         //triton_merge_attn_states.py:86:30
        //......
mov.b32%r23,%f54;
//begininlineasm
cvt.rn.bf16.f32%rs6,%r23;
//endinlineasm
and.b32  %r30,%r25,96;
setp.eq.s32%p10,%r30,0;
//begininlineasm
@%p10st.global.b16[%rd18+0],{%rs6}; // 非向量化store

我们能看到,这种情况下,Triton并没有生成高效的向量化ld/st指令,而是使用ld.global.b16和st.global.b16。因此,如果我们自定义CUDA Kernel,并且手工确保合并访存的话,应该会有一定的性能收益。

0x05 CUDA 算子优化

因为merge_attn_states的逻辑本来就很简单,因此,可以很快就搓一个对应的CUDA实现,这类型kernel最主要的优化就是合并访存。首先,在 vllm/csrc/attention 中添加 merge_attn_states.cu :

merge_attn_states.cu

由于是新增加的cu文件,记得在CMakeList.txt中也添加一下:

修改CMakeList.txt

最终,实现的CUDA算子如下:NUM_THREADS目前默认值设置为128,即128个线程。

namespacevllm{// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005// can be used to combine partial attention results (in the split-KV case)template<typenamescalar_t,constuintNUM_THREADS>__global__voidmerge_attn_states_kernel(
    scalar_t*output,float*output_lse,constscalar_t*prefix_output,
    constfloat*prefix_lse,constscalar_t*suffix_output,
    constfloat*suffix_lse,constuintnum_tokens,constuintnum_heads,
    constuinthead_size){
usingpack_128b_t=uint4;
constuintpack_size=16/sizeof(scalar_t);
constuintthreads_per_head=head_size/pack_size;

constuintglobal_idx=blockIdx.






请到「今天看啥」查看全文