正文
相乘,得到
q/k/v_chunk = (N/p, d)
(3)针对q/k/v_chunk,所有卡间做一次All2All通讯,使得每张卡拿到所有seq的某1个head的q/k/v_chunk
-
做这个All2All通讯前,每张卡上维护的
q/k/v_chunk = (N/p, d)
,可以理解成某个seq_chunk所有head的qkv值
-
做这个All2All通讯后,每张卡上维护的
q/k/v_chunk = (N, d/P)
,可以理解成所有seq的某个head的qkv值。
我们以q_chunk为例,来具体看All2All是怎么实现这一点的(下图根据ulysses源码进行绘制,做了一点简化)
-
如上图所示,这里我们假设有4块卡(4个head),则最终我们希望gpu0算head0的结果,gpu1算head1的结果...以此类推。我们用不同颜色的矩形表示计算不同head需要用到的q数据。
-
我们从上图的最左侧位于gpu0上的q0看起,它表示seq_chunk0的q结果,尺寸为(N/P, d)。不难理解,如果我们将q0沿着d维度切成P块,那么每一块就表示为了计算出对应的head所需要的q结果。其余gpu上的q_chunk也是类推。
-
现在我们执行All2All算法,你可以将它理解成是一种“转置式”地通信方法
:结合上图我们可以发现,各块卡第1列蓝色块现在都跑去gpu0,第2列绿色块现在都跑去gpu1...这就是我们说的“转置”的含义。
-
All2All结束后,我们还以gpu0为例,
它上面拥有P块(N/P, d/P)数据,表示所有seq在head0上的q结果,我们将其稍作reshape后,每块卡上最终维护的q_chunk就变成(N, d/P)。
每块卡上的k/v_chunk也是同理进行All2All通讯。
(4)每张卡拿到所有seq的某1个head的q/k/v_chunk后,我们正常执行Attention计算
,最终每张卡上产出结果
chunk,尺寸为
(N, d/P)
(5)针对
chunk,所有卡间再做1次All2All通讯,最终单卡上维护的P chunk尺寸又变回(N/P, d)
。
这个All2All过程可以理解成是先前描述的All2All的反操作,作用过程相似,这里不再赘述。
(6)
单张卡上拥有完整的
矩阵,我们将P chunk和它相乘,得到最后的输出O chunk,尺寸为
(N/P, d)
(7)
进入MLP层,由于在MLP层中,不涉及token和token之间的相关性计算,所以各seq_chunk块可以独自计算。
(8)
重复上述过程,直到算到Loss 为止。
-
这里我初步判定,每张卡上算出的Loss应该就是这块卡所维护的那个seq_chunk的Loss。因为我粗看了一遍ulysses的代码,发现目前它的核心是单独设计了一个能实现sp并行的DistributionAttention的模块,然后用这个模块替换掉之前的Attention Module,通过这样一个简单的替换实现了ulysses的基本功能。再考虑到seq_chunk在MLP计算时的独立性和数据并行的特性,最终单卡Loss应该就是seq_chunk Loss,这也意味着sp组的梯度需要做AllReduce通讯,这个我们放在后面对ulysses的通讯量分析中再说。
二、Megatron VS Ulysses
不难发现,Ulysses和Megatron在分布式计算attention上有某些相似之处:
-
Megatron通过tp
,显式地把Wq, Wk, Wv切分开,然后每张卡上计算所有seq的某个head的结果。
-
Ulysses通过sp+all2all
,在每张卡完整保存Wq, Wk, Wv的前提下,让每张卡上计算所有seq的某个head的结果。
那么在实现相似功能的情况下,
Ulysses提出的一个重要卖点是:我的通讯量低
。所以接下来,就让我们来详细分析这一点。
(⚠️⚠️⚠️:如果看到下文时,发现对通讯量、激活值等等计算有疑问的朋友,可以先看这篇写
Megatron SP的文章
。)
2.1 Megatron通讯量
上图展示了megatron tp + sp下的整体运作流程。