专栏名称: GiantPandaLLM
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaLLM  ·  图解Vllm ... ·  昨天  
GiantPandaLLM  ·  【博客转载】Row-Major VS ... ·  2 天前  
51好读  ›  专栏  ›  GiantPandaLLM

Flex Attention API 应用 Notebook 代码速览

GiantPandaLLM  · 公众号  · 3D  · 2024-10-13 18:24

正文

请到「今天看啥」查看全文


我们将暂时离开主题,描述两个关键概念,这些概念对于理解如何获得FlexAttention的最大性能优势非常重要。flex_attention的完整API如下:

flex_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Optional[Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    block_mask: Optional[torch.nn.attention.flex_attention.BlockMask] = None,
    scale: Optional[float] = None,
)

你可能会好奇为什么我们需要同时使用 score_mod block_mask

  • 当你想在注意力权重矩阵中修改分数值时,应该使用 score_mod 函数。
  • 当你想在注意力权重矩阵中掩码分数值时,应该使用 mask_mod 函数,这些分数值独立于分数值本身,仅依赖于位置信息。

注意:任何 block_mask 也可以用 score_mod 表示,但kernel的性能将不是最优的。

让我们通过因果注意力来突出差异。

使用score_mod的实现:

def causal_bias(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

每当你编写一个 score_mod 函数,该函数对某些元素传递原始分数,而对其他元素设置为 -inf 时,你应该可能使用 mask_mod

使用 mask_mod 的实现:

def casual_mask(b,h,q_idx, kv_idx):
    return q_idx >= kv_idx

正如你所见,它们看起来非常相似,都返回标量张量。关键的区别在于:

  • mask_mods 返回布尔张量,其中 True 表示应该计算该分数,而 False 表示我们想要掩码该分数。
  • mask_mods 不接受 score 参数,因为它们在计算过程中不允许依赖实际值。

当我同时使用 score_mod 和 mask_mod 时会发生什么?

score_mod 函数将应用于每个未被掩码的元素。

我有一个 mask mod 函数,如何创建一个 BlockMask?

问得好,读者!除了 flex_attention,我们还提供了一个主要的 API。

create_block_mask(
    mask_mod (Callable): mask_mod function.
    B (int): Batch size.
    H (int): Number of heads.
    Q_LEN (int): Sequence length of query.
    KV_LEN (int): Sequence length of key/value.
    device (str): Device to run the mask creation on.
    KV_BLOCK_SIZE (int): Block size of block mask for each query.
    Q_BLOCK_SIZE (int): Block size of block mask for each key/value.
    _compile (bool): Whether to compile the mask creation.
)

因此,对于上述示例,调用flex_attention的最优性能方式是:

causal_block_mask = create_block_mask(causal_mask, B, H, M, N)
flex_attention(query, key, value, block_mask = causal_block_mask)

B,H,Q_LEN,KV_LEN 分别是 batch_size、num_heads、query_sequence_length 和 key_sequence_length。

为什么两者都有?

纯粹是为了性能。因果掩码实际上非常稀疏。只有注意力分数的下三角部分是重要的。如果不生成BlockMask,我们将需要做两倍的工作!下面我们将比较这两种实现的性能差异。

分数修改示例

让我们探索可以使用FlexAttention API的各种分数修改示例。

图例:我们将打印这些score_mod + mask_fns的稀疏性表示。

任何块的缺失意味着它被完全掩码,实际上不需要计算最终的注意力输出

  • ██ 这个块计算所有查询和键token之间的完全注意力
  • ░░ 这个块部分掩码,一些查询token关注一些键token,但一些被掩码为-inf

全注意力

应用一个“无操作”的分数修改。保持注意力分数不变。

def noop(score, b, h, q_idx, kv_idx):
    return score

test_mask(noop, print_mask=True)

执行后的输出为:

Results for noop:
+---------------+----------------+-------------------+----------------+-------------------+
| Operation     |   FW Time (ms) |   FW FLOPS (TF/s) |   BW Time (ms) |   BW FLOPS (TF/s) |
+===============+================+===================+================+===================+
| causal FA2    |        14.6478 |            150.13 |        41.1986 |            133.44 |
+---------------+----------------+-------------------+----------------+-------------------+
| F.sdpa + mask |        58.8032 |             74.79 |       125.07   |             87.91 |
+---------------+----------------+-------------------+----------------+-------------------+
| flexattention |        27.3449 |            160.84 |        94.4015 |            116.47 |
+---------------+----------------+-------------------+----------------+-------------------+

Block Mask:
None







请到「今天看啥」查看全文