专栏名称: GiantPandaLLM
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaLLM  ·  【博客转载】Row-Major VS ... ·  昨天  
51好读  ›  专栏  ›  GiantPandaLLM

图解大模型训练系列:序列并行3,Ring Attention

GiantPandaLLM  · 公众号  · 3D  · 2024-11-08 19:20

正文

请到「今天看啥」查看全文


时,我们用的是 这个矩阵的local max和global sum。

  • 所以,在使用softmax的情况下,我们无法对 做简单的累加。

  • 那么,在分块的情况下,我们到底要采取什么方式更新 呢? 本质上来说,ring attention采用的是和flash attention V2非常相近的 更新方式 ,具体可以参见我之前对Flash Attention V2的解读:

    • Flash Attention V2 :这篇文章的1.2(1)部分详细介绍了 的更新方式

    • Flash Attention V1 :这篇文章第四部分详细介绍了朴素attention->safe softmax -> 分块safe softmax的整个过程,并用递归法证明了 的更新方式,这个证明方法同样可以类推到V2的 上。

    注意,这里不了解 的具体更新方式并不影响下文的阅读。所以,本文不再对 的更新细节和数学推导做更多论述。

    但这边我们额外再关注一点: Ring Attention和Flash Attention V2的 的更新方式非常相近,但不完全相同。 为了更好阐述这一点,我们先来看Flash Attention V2中 的更新方式:

    上图展示的是Flash Attention V2的fwd算法过程,第10行展示了 的更新方式。同时注意到,当我们把outer loop和inner loop全部做完后,在第12行我们又对 做了一次更新,且这个更新是一次性的,同时更新公式中的 和global sum相关。Flash Attention V2为什么要这么做呢?因为:

    • 首先,你当然可以把第12行的更新放到第10行中去做 。也就是对于某个分块 ,我们在逐步更新它对应的 时,我们要考虑到目前为止得到的global sum信息。 什么叫“目前为止得到global sum”信息呢? 例如,当你计算出 时,你会根据它得到一个sum;当你算出 时,你会根据它和 再次得到一个sum;当你算完全部的S分块时,你得到的sum就是真正的global sum了。 所以尽管在这里我们没有给出详细的数学推导,从直觉上也不难理解,我们可以选择在第10行内用“目前为止得到的global sum”做迭代更新,也可以选择在第12行用最终的global sum做一个一次性的更新。Flash Attention V1和ring Attention选择把第12行放入第10行中做,而Flash Attention V2选择把两者拆开。

    • 而把第12行更新从第10行中拆出来的主要原因,是为了在gpu中尽量减少非矩阵乘法的计算量 。这是因为在现代gpu中(比如NV GPU)非矩阵乘法的计算比矩阵乘法慢约16倍。以NV A100来说,fp16/bf16的矩阵乘法计算理论上的最大吞吐是312 TFLOPs/s,但是非矩阵乘法运算仅为19.5TFLOPs/s

    好,到目前为止,我们已经知道在分块的情况下,如何在单GPU上进行Attention计算了,接下来,我们就把这个计算过程拆分到多gpu上,来看看ring attention中的ring是如何运作的。

    三、多gpu:环状通信







    请到「今天看啥」查看全文