专栏名称: GiantPandaLLM
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaLLM  ·  【博客转载】C++/CUDA Data ... ·  昨天  
GiantPandaLLM  ·  【博客转载】CUDA Kernel ... ·  2 天前  
GiantPandaLLM  ·  [Triton编程][基础]vLLM ... ·  2 天前  
51好读  ›  专栏  ›  GiantPandaLLM

[Triton编程][基础]vLLM Triton Merge Attention States K...

GiantPandaLLM  · 公众号  · 3D  · 2025-06-12 19:52

正文

请到「今天看啥」查看全文


https://zhuanlan.zhihu.com/p/695799736)

本文内容包括以下部分:

  • 0x00 前言

  • 0x01 Merge Attention States 简介

  • 0x02 PyTorch实现

  • 0x03 Triton 基础算子

  • 0x04 Triton 算子分析

  • 0x05 NCU Profile分析

  • 0x06 性能评估

  • 0x07 总结

0x01 Merge Attention States 简介

本小节简单介绍一下Merge Attention States的概念。Merge Attention States在FlashInfer: https://www.arxiv.org/pdf/2501.01005 的论文中2.2 Attention Composition小节中出现,然后在vLLM的Triton MLA实现中也被使用到。

Merge Attention States

我们知道,Attention的计算是可以分块的。Block-Parallel Transformer (BPT)表明,对于相同的query以及不同的key/value,Attention Output(O)可以通过同时保留每个块的O及其缩放比例LSE来进行组合。其实就是,在decode阶段,我们们通常面临的是query很小,比如1,但是key和value很长,seqlen长度。因此,对于长序列,可以考虑对key/value先分块,每个块各自计算自己的Attention结果,记录块对应的LSE,最后通过缩放比例来合并。这就是所谓的 ”Merge Attention States “。这种用法,在Chunked-Prefill、Prefix-Cache和Split-KV的场景都会有意义。设 q 为一个query, 为一个索引集(也就是tokens)。 LSE,log-exp-sum 可以定义为:

其实,Merge Attention States要做的事情很简单,就是对两个分块的Attention进行最终的校准。

0x02 PyTorch实现

首先,来简单写一个PyTorch版本的,方便后边和CUDA、Triton算子对数值精度。

# Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005# can be used to combine partial attention results (in the split-KV case)def merge_attn_states_torch(        output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]        prefix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]        prefix_lse: torch.Tensor,  # [NUM_HEADS, NUM_TOKENS]        suffix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]        suffix_lse: torch.Tensor,  # [NUM_HEADS, NUM_TOKENS]        output_lse: Optional[torch.Tensor] = None,  # [NUM_HEADS, NUM_TOKENS]):    p_lse = prefix_lse    s_lse = suffix_lse    # inf -> -inf 这里是为了避免inf值导致output为NAN, exp(inf)=nan, exp(-inf)=0    p_lse[p_lse == torch.inf] = -torch.inf    s_lse[s_lse == torch.inf] = -torch.inf    # max_lse [NUM_HEADS, NUM_TOKENS]    max_lse = torch.maximum(p_lse, s_lse)    # 减去最大值,safe softmax常规操作    p_lse = p_lse - max_lse    s_lse = s_lse - max_lse    p_lse_exp = torch.exp(p_lse)    s_lse_exp = torch.exp(s_lse)    out_se = (p_lse_exp + s_lse_exp)






请到「今天看啥」查看全文