专栏名称: GiantPandaLLM
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaLLM  ·  【博客转载】Row-Major VS ... ·  昨天  
51好读  ›  专栏  ›  GiantPandaLLM

梳理下Flash Attention的dispatch逻辑(文末免费送书)

GiantPandaLLM  · 公众号  · 3D  · 2024-11-12 12:00

正文

请到「今天看啥」查看全文


)

# 参数设置
batch_size = 2
seq_lengths = [ 128 , 256 ] # 两个序列的长度
nheads = 16
headdim = 32
dropout_p = 0.0
causal = True # 是否使用因果性掩码
scale = None # 缩放因子,默认为 1 / sqrt(headdim)

# 为每个序列生成随机的 q, k, v 张量
qs = []
ks = []
vs = []
for seqlen in seq_lengths:
q = torch.randn(seqlen, nheads, headdim, requires_grad= True , dtype=torch.bfloat16, device= "cuda" ) # (L, nheads, headdim)
k = torch.randn(seqlen, nheads, headdim, requires_grad= True , dtype=torch.bfloat16, device= "cuda" )
v = torch.randn(seqlen, nheads, headdim, requires_grad= True , dtype=torch.bfloat16, device= "cuda" )
qs.append(q)
ks.append(k)
vs.append(v)

# 将所有序列的 q, k, v 拼接起来
q_total = torch.cat(qs, dim= 0 ) # (total_q, nheads, headdim)
k_total = torch.cat(ks, dim= 0 )
v_total = torch.cat(vs, dim= 0 )

# 计算累积序列长度,用于索引
cu_seqlens_q = torch.zeros(batch_size + 1 , dtype=torch.int32, device= "cuda" )
cu_seqlens_q[ 1 :] = torch.cumsum(torch.tensor(seq_lengths, dtype=torch.int32), dim= 0 )
cu_seqlens_k = cu_seqlens_q.clone()

print( 'cu_seqlens_q: ' , cu_seqlens_q)

# 最大序列长度
max_seqlen_q = max(seq_lengths)
max_seqlen_k = max(seq_lengths)

# 任意传入一个softmax_scale
softmax_scale = 0.2

# 调用 flash_attn_varlen_func 函数
out_flash = flash_attn_varlen_func(
q_total,
k_total,
v_total,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
)

# 使用朴素实现对每个序列进行计算,并将输出拼接起来
outputs_naive = []
for i in range(batch_size):
q = qs[i] # (L_i, nheads, headdim)
k = ks[i]
v = vs[i]
out = scaled_dot_product_attention(
q,
k,
v,
attn_mask= None ,
dropout_p=dropout_p,
is_causal=causal,
scale=softmax_scale
) # 输出形状为 (L_i, nheads, headdim)
outputs_naive.append(out)

# 将朴素实现的输出拼接起来
out_naive = torch.cat(outputs_naive, dim= 0 ) # (total_q, nheads, headdim)



print( 'out_naive st: ' , out_naive.flatten()[: 10 ])
print( 'out_flash st: ' , out_flash.flatten()[: 10 ])
print( '=' * 20 )
print( 'out_naive en: ' , out_naive.flatten()[ -10 :])
print( 'out_flash en: ' , out_flash.flatten()[ -10 :])

# 比较两个实现的输出是否一致
assert torch.allclose(out_flash, out_naive, atol= 1e-2 ), "Outputs do not match!"

print( "测试通过" )

这个测试是可以通过的,相信通过上面2个对上层接口调用的例子可以让我们对Flash Attention的接口调用有比较清晰的认识。下面我们可以关注一下Flash Attention这个借口的实现,我们不需要深入到cuda实现中,只需要把握一下整体的调用逻辑,搞清楚文章开头抛出的问题即可。

0x4. flash_attn_interface.py中的上层接口

flash-attention 库中使用 cuda 实现了Flash Attention的计算,然后通过 Torch Binding 将 varlen_fwd 这个接口暴露给Python,而 flash_attn_varlen_func 则是对 varlen_fwd 的进一步封装,我们可以在 https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py 中查看到 flash_attn_varlen_func 这个接口的实现。去掉了反向相关的逻辑,如下所示:

def _flash_attn_varlen_forward(
    q: torch.Tensor,
    k: torch.Tensor, 
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,  # Q序列的累积长度
    cu_seqlens_k: torch.Tensor,  # K序列的累积长度
    max_seqlen_q: int,          # Q序列的最大长度
    max_seqlen_k: int,          # K序列的最大长度
    dropout_p: float,           # dropout概率
    softmax_scale: float,       # softmax缩放因子
    causal: bool,               # 是否使用因果掩码
    window_size_left: int = -1,  # 滑动窗口左侧大小
    window_size_right: int = -1# 滑动窗口右侧大小
    softcap: float = 0.0,       # softmax的上限值
    alibi_slopes: Optional[torch.Tensor] = None,  # ALiBi位置编码的斜率
    return_softmax: bool = False,  # 是否返回softmax结果
    block_table: Optional[torch.Tensor] = None,  # 分块表
    leftpad_k: Optional[torch.Tensor] = None,    # K序列左侧填充
    seqused_k: Optional[torch.Tensor] = None,    # K序列使用的长度
)
 -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

    # 确保输入张量是连续的内存布局
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    
    # 调用CUDA实现的前向传播函数
    out, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
        q, k, v,
        None,  # 原始掩码矩阵(未使用)
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_k,
        leftpad_k,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # 未使用的参数
        causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        None,  # 随机数生成器状态(未使用)
    )
    return out, softmax_lse, S_dmask, rng_state

# FlashAttnVarlenQKVPackedFunc类实现了PyTorch的自动微分接口
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,  # 上下文对象,用于保存反向传播需要的信息
        qkv,  # 打包的QKV张量
        cu_seqlens,  # 累积序列长度
        max_seqlen,  # 最大序列长度
        dropout_p,   # dropout概率
        softmax_scale,  # softmax缩放因子
        causal,      # 是否使用因果掩码
        window_size,  # 滑动窗口大小
        softcap,     # softmax上限值
        alibi_slopes,  # ALiBi位置编码斜率
        deterministic,  # 是否确定性计算
        return_softmax,  # 是否返回softmax结果
    )
:

        # 如果未指定缩放因子,使用默认的1/sqrt(head_dim)
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
            
        # 分离Q、K、V并detach,避免建立反向图
        q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach()
        
        # 获取原始head size
        head_size_og = q.size(2)
        
        # 如果head size不是8的倍数,进行padding
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [08 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [08 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [08 - head_size_og % 8])
            
        # 调用前向计算函数    
        out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q, k, v,
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        # 移除padding,恢复原始head size
        out = out_padded[..., :head_size_og]
        
        # 根据需要返回softmax结果
        return out if not return_softmax else (out, softmax_lse, S_dmask)

def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1-1),  # -1 means infinite context window
    softcap=0.0# 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
)
:

    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
    )

上面这段代码清晰展示了 flash_attn_varlen_func 这个接口的调用逻辑,接下来我们就可以去看一下 flash_attn_cuda.varlen_fwd 这个接口的具体dispatch逻辑了。

0x5. flash_attn_cuda.varlen_fwd的初步dispatch逻辑

首先来到这里:https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp#L1518 ,

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashAttention";
    m.def("fwd", &mha_fwd, "Forward pass");
    m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
    m.def("bwd", &mha_bwd, "Backward pass");
    m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
    m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
}

可以发现 flash_attn_cuda.varlen_fwd 接口对应了 mha_varlen_fwd 这个c++函数。从这里我们应该就可以看到flash attention forward的dispatch逻辑了。

std::vector<:tensor>
mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q为每个batch中序列长度的总和
               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k为每个batch中序列长度的总和,如果有block_table则为num_blocks x page_block_size x num_heads_k x head_size
               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k为每个batch中序列长度的总和,如果有block_table则为num_blocks x page_block_size x num_heads_k x head_size
               c10::optional<:tensor> &out_, // total_q x num_heads x head_size, total_q为每个batch中序列长度的总和
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               c10::optional<:tensor> &seqused_k, // b。如果提供了该参数,则每个batch元素只使用这么多个key
               c10::optional<const






请到「今天看啥」查看全文