正文
https://zhuanlan.zhihu.com/p/695799736)
本文内容包括以下部分:
0x01 Merge Attention States 简介
本小节简单介绍一下Merge Attention States的概念。Merge Attention States在FlashInfer:
https://www.arxiv.org/pdf/2501.01005
的论文中2.2 Attention Composition小节中出现,然后在vLLM的Triton MLA实现中也被使用到。
Merge Attention States
我们知道,Attention的计算是可以分块的。Block-Parallel Transformer (BPT)表明,对于相同的query以及不同的key/value,Attention Output(O)可以通过同时保留每个块的O及其缩放比例LSE来进行组合。其实就是,在decode阶段,我们们通常面临的是query很小,比如1,但是key和value很长,seqlen长度。因此,对于长序列,可以考虑对key/value先分块,每个块各自计算自己的Attention结果,记录块对应的LSE,最后通过缩放比例来合并。这就是所谓的
”Merge Attention States
“。这种用法,在Chunked-Prefill、Prefix-Cache和Split-KV的场景都会有意义。设
q
为一个query,
为一个索引集(也就是tokens)。
LSE,log-exp-sum
可以定义为:
其实,Merge Attention States要做的事情很简单,就是对两个分块的Attention进行最终的校准。
0x02 PyTorch实现
首先,来简单写一个PyTorch版本的,方便后边和CUDA、Triton算子对数值精度。
def merge_attn_states_torch(
output: torch.Tensor,
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
suffix_output: torch.Tensor,
suffix_lse: torch.Tensor,
output_lse: Optional[torch.Tensor] = None,
):
p_lse = prefix_lse
s_lse = suffix_lse
p_lse[p_lse == torch.inf] = -torch.inf
s_lse[s_lse == torch.inf] = -torch.inf
max_lse = torch.maximum(p_lse, s_lse)
p_lse = p_lse - max_lse
s_lse = s_lse - max_lse
p_lse_exp = torch.exp(p_lse)
s_lse_exp = torch.exp(s_lse)
out_se = (p_lse_exp + s_lse_exp)