专栏名称: GiantPandaLLM
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaLLM  ·  图解Vllm ... ·  14 小时前  
GiantPandaLLM  ·  【博客转载】Row-Major VS ... ·  昨天  
51好读  ›  专栏  ›  GiantPandaLLM

【翻译】【PyTorch 奇技淫巧】FlexAttetion 基于Triton打造灵活度拉满的Att...

GiantPandaLLM  · 公众号  · 3D  · 2024-10-08 21:14

正文

请到「今天看啥」查看全文


Alibi 与相对位置编码类似,但有一个例外——它有一个通常预先计算的per-head因子。

alibi_bias = generate_alibi_bias() # [num_heads]

def alibi(score, b, h, q_idx, kv_idx):
    bias = alibi_bias[h] * (q_idx - kv_idx)
    return score + bias

这展示了 torch.compile 提供的一个有趣的灵活性——即使 alibi_bias 没有被显式地作为输入传递进来,我们也可以从中加载数据!生成的 Triton kernel将计算从 alibi_bias 张量中正确加载的数据并将其融合。请注意,即使重新生成 alibi_bias ,我们也不需要重新编译。

Soft-capping

Soft-capping 是一种在 Gemma2(https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) 和 Grok-1 中使用的技术,用于防止 logits 过度增长。在 FlexAttention 中,它看起来像这样:

softcap = 20
def soft_cap(score, b, h, q_idx, kv_idx):
    score = score / softcap
    score = torch.tanh(score)
    score = score * softcap
    return score

请注意,我们在这里还自动从正向pass生成反向pass。此外,尽管此实现语义上是正确的,但由于性能原因,我们可能希望在这种情况下使用 tanh 近似。有关更多详细信息,请参见 attention-gym(https://github.com/pytorch-labs/attention-gym/blob/main/attn_gym/mods/softcapping.py)。

Causal Mask

尽管双向注意力是最简单的,但《Attention is All You Need》论文和大多数LLM在仅解码器设置中使用注意力,其中每个token只能关注其之前的token。人们通常认为这是一个下三角掩码,通过 score_mod API,它可以表示为:

def causal_mask(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

基本上,如果查询token在键token之后,我们保留分数。否则,我们通过将其设置为-inf来将其掩码掉,从而确保它不会参与softmax计算。

然而,与其他修改相比,掩码是特殊的——如果某些内容被掩码掉,我们可以完全跳过其计算!在这种情况下,因果掩码大约有50%的稀疏性,因此如果不利用这种稀疏性,将会导致2倍的减速。尽管这个 score_mod 足以正确实现因果掩码,但要获得稀疏性的性能优势,还需要另一个概念—— mask_mod

Mask Mods

为了利用掩码带来的稀疏性,我们需要做更多的工作。具体来说,通过将 mask_mod 传递给 create_block_mask ,我们可以创建一个 BlockMask 。然后,FlexAttention 可以使用 BlockMask 来利用这种稀疏性!

mask_mod 的签名与 score_mod 非常相似——只是没有分数。特别是

# returns True if this position should participate in the computation
mask_mod(b, h, q_idx, kv_idx) => bool

请注意, score_mod mask_mod 更具表达力。然而,对于掩码操作,建议使用 mask_mod create_block_mask ,因为它们的性能更好。请参阅常见问题解答,了解为什么 score_mod mask_mod 是分开的。

现在,让我们看看如何使用 mask_mod 实现因果掩码。

Causal Mask

from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Because the sparsity pattern is independent of batch and heads, we'll set them to None (which broadcasts them) 
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=1024, KV_LEN=1024)
# In this case, we don't need a score_mod, so we won't pass any in.
# However, score_mod can still be combined with block_mask if you need the additional flexibility.
flex_attention(query, key, value, block_mask=block_mask)

请注意, create_block_mask 是一个 相对昂贵的操作 !尽管 FlexAttention 在更改时不需要重新编译,但如果你不注意缓存它,它可能会导致显著的减速(查看常见问题解答以获取最佳实践建议)。

尽管TFlops大致相同,但mask_mod版本的执行时间快了2倍!这表明我们可以利用BlockMask提供的稀疏性,而不会损失硬件效率。

Sliding Window + Causal

由Mistral(https://arxiv.org/abs/2310.06825)推广的滑动窗口注意力(也称为局部注意力)利用了最近token最有用的直觉。特别是,它允许query token仅关注最近的1024个token。这通常与因果注意力一起使用。







请到「今天看啥」查看全文