正文
=
prefix_output
*
p_scale
+
suffix_output
*
s_scale
return
output
,
output_lse
0x03 Triton 基础算子
PyTorch实现的版本,当然性能是很低的,因为使用了很多的小op,以及对于Tensor进行了inplace的写操作。因此,vLLM中并不是直接使用PyTorch的实现,而是提供了一个基于Triton实现的kernel。完整代码链接:attention/ops/triton_merge_attn_states.py
(
https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/triton_merge_attn_states.py)
。具体如下:
数据load及inf处理
safe-softmax
校准
我们看到Triton kernel做的事情和PyTorch实现的一样的,但是将所有的操作都fused到一个kernel中,online判断inf值(寄存器)而不是修改global memory中的值,性能一般来说会更高。这个kernel的调用逻辑如下:
Triton kernel的调用
vLLM里边的实现,给merge_attn_states_kernel,分配(num_tokens, num_query_heads)个thread block,每个block处理当前head的所有值,比如head_size=128,则这个block处理128个值。
0x04 Triton 算子分析
上小节我们知道,vLLM里边的实现,给merge_attn_states_kernel,分配(num_tokens, num_query_heads)个thread block,每个block处理当前head的所有值,比如head_size=128,则这个block处理128个值。但是,这样做,会出现一些问题。(1)当num_tokens、num_query_heads很大,而head_size很小(比如32)时,就会导致thread block数过大,每个block处理的数据量又过少,计算密度很小。而且,这种情况下,Triton也不一定能生成高效的kernel(下文会讲到);(2)Triton kernel在调用时会有一定CPU的overhead。
may have CPU overhead
这里记录一下一个简单有效的分析Triton kernel的方法(当然ncu,nsys用上就更好了)。通常,我们也想知道,到底Triton实际上生成了啥kernel,比如说,生成的kernel PTX是怎么样的,有没有用上向量化,有没有cp.async,合并访存到底做好了没有。这个时候,我们可以指定TRITON_CACHE_DIR环境变量,把Triton生成的中间IR文件给保存下来,进行分析。
export TRITON_CACHE_DIR = $( pwd ) /cache pytest -s test_merge_attn_states.py # Triton生成的中间IR cache文件 cache git: ( dev ) ✗ tree . . ├── ALGAAi8N-ErdaDbXXL8N91RokvTI-e8O2oEwd0SL3N0 │ └── __triton_launcher.so ├── p4IOvvpWkyeVkuyW8j50rO-ANYlCc5AJOEr70sQD93A │ ├── __grp__merge_attn_states_kernel.json │ ├── merge_attn_states_kernel.cubin │ ├── merge_attn_states_kernel.json │ ├── merge_attn_states_kernel.llir │ ├── merge_attn_states_kernel.ptx │ ├── merge_attn_states_kernel.ttgir │ └── merge_attn_states_kernel.ttir └── q4oIpkjOtdHHfi8xBkm4jC4JWIk5AjKtN8WRkZb8MD8 └── cuda_utils.so
这里边,我们主要关注merge_attn_states_kernel.ptx这个PTX文件就可以了。比如,对于当num_tokens=512和num_query_heads=16,head_size=32,生成的PTX部分如下:
@% p8 ld.global.b16 { % rs3 } , [ % rd16 + 0 ] ; // 非向量化load // ...... @% p8 ld.global.b16 { % rs4 } , [ % rd17 + 0 ] ; // end inline asm .loc 1 85 30 // triton_merge_attn_states.py : 85 : 30 div.full.f32 % r15 , % r16 , % r17 ; // ...... mov.b32 % f49 , % r15 ; .loc 1 86 30 // triton_merge_attn_states.py : 86 : 30 // ...... mov.b32 % r23 , % f54 ; // begin inline asm cvt.rn.bf16.f32 % rs6 , % r23 ; // end inline asm and.b32 % r30 , % r25 , 96 ; setp.eq.s32 % p10 , % r30 , 0 ; // begin inline asm @% p10 st.global.b16 [ % rd18 + 0 ], { % rs6 } ; // 非向量化store
我们能看到,这种情况下,Triton并没有生成高效的向量化ld/st指令,而是使用ld.global.b16和st.global.b16。因此,如果我们自定义CUDA Kernel,并且手工确保合并访存的话,应该会有一定的性能收益。
0x05 CUDA 算子优化
因为merge_attn_states的逻辑本来就很简单,因此,可以很快就搓一个对应的CUDA实现,这类型kernel最主要的优化就是合并访存。首先,在
vllm/csrc/attention
中添加
merge_attn_states.cu
:
merge_attn_states.cu
由于是新增加的cu文件,记得在CMakeList.txt中也添加一下:
修改CMakeList.txt
最终,实现的CUDA算子如下:NUM_THREADS目前默认值设置为128,即128个线程。
namespace vllm { // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 // can be used to combine partial attention results (in the split-KV case) template < typename scalar_t , const uint NUM_THREADS > __global__ void merge_attn_states_kernel ( scalar_t * output , float * output_lse , const scalar_t * prefix_output , const float * prefix_lse , const scalar_t * suffix_output , const float * suffix_lse , const uint num_tokens , const uint num_heads , const uint head_size ) { using pack_128b_t = uint4 ; const uint pack_size = 16 / sizeof ( scalar_t ); const uint threads_per_head = head_size / pack_size ; const uint global_idx = blockIdx .