正文
Alibi 与相对位置编码类似,但有一个例外——它有一个通常预先计算的per-head因子。
alibi_bias = generate_alibi_bias() # [num_heads] def alibi (score, b, h, q_idx, kv_idx) : bias = alibi_bias[h] * (q_idx - kv_idx) return score + bias
这展示了
torch.compile
提供的一个有趣的灵活性——即使
alibi_bias
没有被显式地作为输入传递进来,我们也可以从中加载数据!生成的 Triton kernel将计算从
alibi_bias
张量中正确加载的数据并将其融合。请注意,即使重新生成
alibi_bias
,我们也不需要重新编译。
Soft-capping
Soft-capping 是一种在 Gemma2(https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) 和 Grok-1 中使用的技术,用于防止 logits 过度增长。在 FlexAttention 中,它看起来像这样:
softcap = 20 def soft_cap (score, b, h, q_idx, kv_idx) : score = score / softcap score = torch.tanh(score) score = score * softcap return score
请注意,我们在这里还自动从正向pass生成反向pass。此外,尽管此实现语义上是正确的,但由于性能原因,我们可能希望在这种情况下使用 tanh 近似。有关更多详细信息,请参见 attention-gym(https://github.com/pytorch-labs/attention-gym/blob/main/attn_gym/mods/softcapping.py)。
Causal Mask
尽管双向注意力是最简单的,但《Attention is All You Need》论文和大多数LLM在仅解码器设置中使用注意力,其中每个token只能关注其之前的token。人们通常认为这是一个下三角掩码,通过
score_mod
API,它可以表示为:
def causal_mask (score, b, h, q_idx, kv_idx) : return torch.where(q_idx >= kv_idx, score, -float("inf" ))
基本上,如果查询token在键token之后,我们保留分数。否则,我们通过将其设置为-inf来将其掩码掉,从而确保它不会参与softmax计算。
然而,与其他修改相比,掩码是特殊的——如果某些内容被掩码掉,我们可以完全跳过其计算!在这种情况下,因果掩码大约有50%的稀疏性,因此如果不利用这种稀疏性,将会导致2倍的减速。尽管这个
score_mod
足以正确实现因果掩码,但要获得稀疏性的性能优势,还需要另一个概念——
mask_mod
。
Mask Mods
为了利用掩码带来的稀疏性,我们需要做更多的工作。具体来说,通过将
mask_mod
传递给
create_block_mask
,我们可以创建一个
BlockMask
。然后,FlexAttention 可以使用
BlockMask
来利用这种稀疏性!
mask_mod
的签名与
score_mod
非常相似——只是没有分数。特别是
# returns True if this position should participate in the computation mask_mod(b, h, q_idx, kv_idx) => bool
请注意,
score_mod
比
mask_mod
更具表达力。然而,对于掩码操作,建议使用
mask_mod
和
create_block_mask
,因为它们的性能更好。请参阅常见问题解答,了解为什么
score_mod
和
mask_mod
是分开的。
现在,让我们看看如何使用
mask_mod
实现因果掩码。
Causal Mask
from torch.nn.attention.flex_attention import create_block_maskdef causal (b, h, q_idx, kv_idx) : return q_idx >= kv_idx# Because the sparsity pattern is independent of batch and heads, we'll set them to None (which broadcasts them) block_mask = create_block_mask(causal, B=None , H=None , Q_LEN=1024 , KV_LEN=1024 )# In this case, we don't need a score_mod, so we won't pass any in. # However, score_mod can still be combined with block_mask if you need the additional flexibility. flex_attention(query, key, value, block_mask=block_mask)
请注意,
create_block_mask
是一个
相对昂贵的操作
!尽管 FlexAttention 在更改时不需要重新编译,但如果你不注意缓存它,它可能会导致显著的减速(查看常见问题解答以获取最佳实践建议)。
尽管TFlops大致相同,但mask_mod版本的执行时间快了2倍!这表明我们可以利用BlockMask提供的稀疏性,而不会损失硬件效率。
Sliding Window + Causal
由Mistral(https://arxiv.org/abs/2310.06825)推广的滑动窗口注意力(也称为局部注意力)利用了最近token最有用的直觉。特别是,它允许query token仅关注最近的1024个token。这通常与因果注意力一起使用。