正文
我们将暂时离开主题,描述两个关键概念,这些概念对于理解如何获得FlexAttention的最大性能优势非常重要。flex_attention的完整API如下:
flex_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
score_mod: Optional[Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]] = None,
block_mask: Optional[torch.nn.attention.flex_attention.BlockMask] = None,
scale: Optional[float] = None,
)
你可能会好奇为什么我们需要同时使用
score_mod
和
block_mask
。
-
当你想在注意力权重矩阵中修改分数值时,应该使用
score_mod
函数。
-
当你想在注意力权重矩阵中掩码分数值时,应该使用
mask_mod
函数,这些分数值独立于分数值本身,仅依赖于位置信息。
注意:任何
block_mask
也可以用
score_mod
表示,但kernel的性能将不是最优的。
让我们通过因果注意力来突出差异。
使用score_mod的实现:
def causal_bias(score, b, h, q_idx, kv_idx):
return torch.where(q_idx >= kv_idx, score, -float("inf"))
每当你编写一个
score_mod
函数,该函数对某些元素传递原始分数,而对其他元素设置为 -inf 时,你应该可能使用
mask_mod
。
使用
mask_mod
的实现:
def casual_mask(b,h,q_idx, kv_idx):
return q_idx >= kv_idx
正如你所见,它们看起来非常相似,都返回标量张量。关键的区别在于:
-
mask_mods
返回布尔张量,其中
True
表示应该计算该分数,而
False
表示我们想要掩码该分数。
-
mask_mods
不接受
score
参数,因为它们在计算过程中不允许依赖实际值。
当我同时使用 score_mod 和 mask_mod 时会发生什么?
score_mod 函数将应用于每个未被掩码的元素。
我有一个 mask mod 函数,如何创建一个 BlockMask?
问得好,读者!除了 flex_attention,我们还提供了一个主要的 API。
create_block_mask(
mask_mod (Callable): mask_mod function.
B (int): Batch size.
H (int): Number of heads.
Q_LEN (int): Sequence length of query.
KV_LEN (int): Sequence length of key/value.
device (str): Device to run the mask creation on.
KV_BLOCK_SIZE (int): Block size of block mask for each query.
Q_BLOCK_SIZE (int): Block size of block mask for each key/value.
_compile (bool): Whether to compile the mask creation.
)
因此,对于上述示例,调用flex_attention的最优性能方式是:
causal_block_mask = create_block_mask(causal_mask, B, H, M, N)
flex_attention(query, key, value, block_mask = causal_block_mask)
B,H,Q_LEN,KV_LEN 分别是 batch_size、num_heads、query_sequence_length 和 key_sequence_length。
为什么两者都有?
纯粹是为了性能。因果掩码实际上非常稀疏。只有注意力分数的下三角部分是重要的。如果不生成BlockMask,我们将需要做两倍的工作!下面我们将比较这两种实现的性能差异。
分数修改示例
让我们探索可以使用FlexAttention API的各种分数修改示例。
图例:我们将打印这些score_mod + mask_fns的稀疏性表示。
任何块的缺失意味着它被完全掩码,实际上不需要计算最终的注意力输出
-
██ 这个块计算所有查询和键token之间的完全注意力
-
░░ 这个块部分掩码,一些查询token关注一些键token,但一些被掩码为-inf
全注意力
应用一个“无操作”的分数修改。保持注意力分数不变。
def noop(score, b, h, q_idx, kv_idx):
return score
test_mask(noop, print_mask=True)
执行后的输出为:
Results for noop:
+---------------+----------------+-------------------+----------------+-------------------+
| Operation | FW Time (ms) | FW FLOPS (TF/s) | BW Time (ms) | BW FLOPS (TF/s) |
+===============+================+===================+================+===================+
| causal FA2 | 14.6478 | 150.13 | 41.1986 | 133.44 |
+---------------+----------------+-------------------+----------------+-------------------+
| F.sdpa + mask | 58.8032 | 74.79 | 125.07 | 87.91 |
+---------------+----------------+-------------------+----------------+-------------------+
| flexattention | 27.3449 | 160.84 | 94.4015 | 116.47 |
+---------------+----------------+-------------------+----------------+-------------------+
Block Mask:
None