import torch
from torch.nn.functional import scaled_dot_product_attention

from flash_attn import flash_attn_varlen_func

# Parameter setup
batch_size = 2
seq_lengths = [128, 256]  # lengths of the two sequences
nheads = 16
headdim = 32
dropout_p = 0.0
causal = True  # whether to apply a causal mask
scale = None  # scaling factor, defaults to 1 / sqrt(headdim)

# Generate random q, k, v tensors for each sequence
qs = []
ks = []
vs = []
for seqlen in seq_lengths:
    q = torch.randn(seqlen, nheads, headdim, requires_grad=True, dtype=torch.bfloat16, device="cuda")  # (L, nheads, headdim)
    k = torch.randn(seqlen, nheads, headdim, requires_grad=True, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(seqlen, nheads, headdim, requires_grad=True, dtype=torch.bfloat16, device="cuda")
    qs.append(q)
    ks.append(k)
    vs.append(v)

# Concatenate the q, k, v of all sequences along the token dimension
q_total = torch.cat(qs, dim=0)  # (total_q, nheads, headdim)
k_total = torch.cat(ks, dim=0)
v_total = torch.cat(vs, dim=0)

# Cumulative sequence lengths, used to index into the packed tensors
cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(torch.tensor(seq_lengths, dtype=torch.int32), dim=0)
cu_seqlens_k = cu_seqlens_q.clone()
print('cu_seqlens_q: ', cu_seqlens_q)

# Maximum sequence lengths
max_seqlen_q = max(seq_lengths)
max_seqlen_k = max(seq_lengths)

# Pass in an arbitrary softmax_scale
softmax_scale = 0.2

# Call flash_attn_varlen_func
out_flash = flash_attn_varlen_func(
    q_total,
    k_total,
    v_total,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=dropout_p,
    softmax_scale=softmax_scale,
    causal=causal,
)

# Compute each sequence with the naive implementation and concatenate the outputs
outputs_naive = []
for i in range(batch_size):
    q = qs[i]  # (L_i, nheads, headdim)
    k = ks[i]
    v = vs[i]
    # scaled_dot_product_attention expects the sequence dimension second to last,
    # so move the head dimension in front: (nheads, L_i, headdim)
    out = scaled_dot_product_attention(
        q.transpose(0, 1),
        k.transpose(0, 1),
        v.transpose(0, 1),
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=causal,
        scale=softmax_scale,
    ).transpose(0, 1)  # back to (L_i, nheads, headdim)
    outputs_naive.append(out)

# Concatenate the outputs of the naive implementation
out_naive = torch.cat(outputs_naive, dim=0)  # (total_q, nheads, headdim)

print('out_naive st: ', out_naive.flatten()[:10])
print('out_flash st: ', out_flash.flatten()[:10])
print('=' * 20)
print('out_naive en: ', out_naive.flatten()[-10:])
print('out_flash en: ', out_flash.flatten()[-10:])

# Check that the two implementations agree
assert torch.allclose(out_flash, out_naive, atol=1e-2), "Outputs do not match!"
print("Test passed")
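Before moving on, it is worth spelling out what the cumulative lengths encode. In this example cu_seqlens_q is [0, 128, 384]: entries i and i+1 bracket the rows that belong to sequence i inside the packed (total_q, nheads, headdim) tensor. A small sketch (reusing the variables from the script above) that slices each sequence's output back out of the packed result:

for i in range(batch_size):
    start, end = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
    out_i = out_flash[start:end]  # (seq_lengths[i], nheads, headdim)
    assert out_i.shape[0] == seq_lengths[i]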
This test passes. The two examples above of calling the high-level interface should give us a reasonably clear picture of how the Flash Attention API is invoked. Next, let's look at how this Flash Attention interface is implemented. We don't need to dive all the way into the CUDA kernels; it's enough to grasp the overall call chain and clear up the question raised at the beginning of the article.
0x4. The high-level interface in flash_attn_interface.py
The flash-attention library implements the Flash Attention computation in CUDA and exposes the varlen_fwd interface to Python through a Torch binding; flash_attn_varlen_func is a further wrapper around varlen_fwd. The implementation of flash_attn_varlen_func can be found in https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py. With the backward-related logic stripped out, it looks like this:
def _flash_attn_varlen_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,  # cumulative sequence lengths of the Q sequences
    cu_seqlens_k: torch.Tensor,  # cumulative sequence lengths of the K sequences
    max_seqlen_q: int,  # maximum Q sequence length
    max_seqlen_k: int,  # maximum K sequence length
    dropout_p: float,  # dropout probability
    softmax_scale: float,  # softmax scaling factor
    causal: bool,  # whether to apply a causal mask
    window_size_left: int = -1,  # left size of the sliding window
    window_size_right: int = -1,  # right size of the sliding window
    softcap: float = 0.0,  # soft cap on the attention scores (0.0 disables it)
    alibi_slopes: Optional[torch.Tensor] = None,  # slopes for the ALiBi positional bias
    return_softmax: bool = False,  # whether to return the softmax result
    block_table: Optional[torch.Tensor] = None,  # block table for a paged KV cache
    leftpad_k: Optional[torch.Tensor] = None,  # left padding of the K sequences
    seqused_k: Optional[torch.Tensor] = None,  # number of keys actually used per batch element
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # Make sure the input tensors are contiguous in memory
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    # Call the CUDA forward implementation
    out, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
        q, k, v,
        None,  # optional pre-allocated output tensor out_ (not provided here)
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_k,
        leftpad_k,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # zero_tensors (unused here)
        causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        None,  # random number generator (not provided)
    )
    return out, softmax_lse, S_dmask, rng_state


# FlashAttnVarlenQKVPackedFunc implements PyTorch's autograd.Function interface
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,  # context object used to stash information for the backward pass
        qkv,  # packed QKV tensor
        cu_seqlens,  # cumulative sequence lengths
        max_seqlen,  # maximum sequence length
        dropout_p,  # dropout probability
        softmax_scale,  # softmax scaling factor
        causal,  # whether to apply a causal mask
        window_size,  # sliding window size
        softcap,  # soft cap on the attention scores
        alibi_slopes,  # ALiBi slopes
        deterministic,  # whether to use the deterministic implementation
        return_softmax,  # whether to return the softmax result
    ):
        # If no scaling factor is given, default to 1/sqrt(head_dim)
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        # Split out Q, K, V and detach them so no extra backward graph is built
        q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach()
        # Original head size
        head_size_og = q.size(2)
        # Pad the head size up to a multiple of 8 if necessary
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
        # Call the forward function
        out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q, k, v,
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        # Strip the padding to restore the original head size
        out = out_padded[..., :head_size_og]
        # Also return the softmax results if requested
        return out if not return_softmax else (out, softmax_lse, S_dmask)


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
):
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
    )
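Note that flash_attn_varlen_func dispatches to FlashAttnVarlenFunc, while the FlashAttnVarlenQKVPackedFunc class listed above backs the packed-QKV variant of the same interface. As a rough sketch of how the two relate (this assumes the tensors from the test script earlier, and that flash_attn_varlen_qkvpacked_func is importable from flash_attn as the public wrapper around this class):

from flash_attn import flash_attn_varlen_qkvpacked_func

# Pack Q, K and V along a new dim=1: (total, 3, nheads, headdim).
# This only works when Q and K/V share the same lengths and number of heads.
qkv = torch.stack([q_total, k_total, v_total], dim=1)
out_packed = flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens_q,   # a single cu_seqlens is reused for both Q and K
    max_seqlen_q,
    dropout_p=dropout_p,
    softmax_scale=softmax_scale,
    causal=causal,
)
# Up to numerical noise this should match the unpacked call
assert torch.allclose(out_packed, out_flash, atol=1e-2)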
This code clearly shows the call chain behind flash_attn_varlen_func. The next step is to look at how the flash_attn_cuda.varlen_fwd interface dispatches the actual computation.
0x5. The initial dispatch logic of flash_attn_cuda.varlen_fwd
Start at https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp#L1518:
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashAttention";
    m.def("fwd", &mha_fwd, "Forward pass");
    m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
    m.def("bwd", &mha_bwd, "Backward pass");
    m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
    m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
}
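If you want to double-check from the Python side which compiled symbol varlen_fwd resolves to, here is a quick sketch (it assumes a flash-attention 2.x install, where flash_attn_interface.py imports the compiled extension as flash_attn_2_cuda):

import flash_attn_2_cuda as flash_attn_cuda

# pybind11 attaches the docstring registered above ("Forward pass (variable length)")
print(flash_attn_cuda.varlen_fwd.__doc__)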
This shows that the flash_attn_cuda.varlen_fwd interface is bound to the C++ function mha_varlen_fwd, which is where the dispatch logic of the Flash Attention forward pass starts to become visible.
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, where total_q is the sum of the sequence lengths in the batch
               const at::Tensor &k,  // total_k x num_heads_k x head_size, where total_k is the sum of the sequence lengths in the batch; num_blocks x page_block_size x num_heads_k x head_size if a block_table is given
               const at::Tensor &v,  // total_k x num_heads_k x head_size, where total_k is the sum of the sequence lengths in the batch; num_blocks x page_block_size x num_heads_k x head_size if a block_table is given
               c10::optional<at::Tensor> &out_,  // total_q x num_heads x head_size, where total_q is the sum of the sequence lengths in the batch
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               c10::optional<at::Tensor> &seqused_k,  // b. If given, only this many keys are used for each batch element
               c10::optional<const