正文
时,我们用的是
这个矩阵的local max和global sum。
所以,在使用softmax的情况下,我们无法对
做简单的累加。
那么,在分块的情况下,我们到底要采取什么方式更新
呢?
本质上来说,ring attention采用的是和flash attention V2非常相近的
更新方式
,具体可以参见我之前对Flash Attention V2的解读:
-
-
Flash Attention V1
:这篇文章第四部分详细介绍了朴素attention->safe softmax -> 分块safe softmax的整个过程,并用递归法证明了
的更新方式,这个证明方法同样可以类推到V2的
上。
注意,这里不了解
的具体更新方式并不影响下文的阅读。所以,本文不再对
的更新细节和数学推导做更多论述。
但这边我们额外再关注一点:
Ring Attention和Flash Attention V2的
的更新方式非常相近,但不完全相同。
为了更好阐述这一点,我们先来看Flash Attention V2中
的更新方式:
上图展示的是Flash Attention V2的fwd算法过程,第10行展示了
的更新方式。同时注意到,当我们把outer loop和inner loop全部做完后,在第12行我们又对
做了一次更新,且这个更新是一次性的,同时更新公式中的
和global sum相关。Flash Attention V2为什么要这么做呢?因为:
-
首先,你当然可以把第12行的更新放到第10行中去做
。也就是对于某个分块
,我们在逐步更新它对应的
时,我们要考虑到目前为止得到的global sum信息。
什么叫“目前为止得到global sum”信息呢?
例如,当你计算出
时,你会根据它得到一个sum;当你算出
时,你会根据它和
再次得到一个sum;当你算完全部的S分块时,你得到的sum就是真正的global sum了。
所以尽管在这里我们没有给出详细的数学推导,从直觉上也不难理解,我们可以选择在第10行内用“目前为止得到的global sum”做迭代更新,也可以选择在第12行用最终的global sum做一个一次性的更新。Flash Attention V1和ring Attention选择把第12行放入第10行中做,而Flash Attention V2选择把两者拆开。
-
而把第12行更新从第10行中拆出来的主要原因,是为了在gpu中尽量减少非矩阵乘法的计算量
。这是因为在现代gpu中(比如NV GPU)非矩阵乘法的计算比矩阵乘法慢约16倍。以NV A100来说,fp16/bf16的矩阵乘法计算理论上的最大吞吐是312 TFLOPs/s,但是非矩阵乘法运算仅为19.5TFLOPs/s
好,到目前为止,我们已经知道在分块的情况下,如何在单GPU上进行Attention计算了,接下来,我们就把这个计算过程拆分到多gpu上,来看看ring attention中的ring是如何运作的。
三、多gpu:环状通信