主要观点总结
本文探索了Megatron中实现计算通信overlap的方法,具体涉及Megatron的dp、tp和pp部分,特别是tp部分(即megatron sp-tp)。文章介绍了在tp中各个步骤的计算和通信流程,以及如何通过p2p ring exchange、pipeline chunk等方法实现计算和通信的串行overlap。此外,文章还介绍了如何通过设置计算流和通信流实现并行overlap,即bulk overlap。最后,文章总结了本文的主要内容和参考资源。
关键观点总结
关键观点1: Megatron中计算通信overlap的探索
介绍Megatron中计算通信overlap的重要性和背景。
关键观点2: Megatron的dp、tp和pp部分简介
概述Megatron中这三个部分的基本功能和在计算通信中的作用。
关键观点3: tp中的计算和通信流程
详细描述tp中各个步骤的计算和通信流程,包括all-gather、reduce-scatter等。
关键观点4: 串行overlap的实现方法
介绍如何通过p2p ring exchange和pipeline chunk等方法实现计算和通信的串行overlap。
关键观点5: 并行overlap的实现方法
介绍如何通过设置计算流和通信流实现并行overlap,即bulk overlap。
关键观点6: 总结
概括文章的主要内容和结论,以及参考资源。
正文
假设我们采取的是最朴素的,没有任何overlap的策略,那么红框中的计算流程应该是下图这样的,这里假设tp_size = 2:
如上图所示,我们有2张gpu(tp_size = 2):
-
在all-gather开始前,gpu0上存储着输入A0和模型分块B0,gpu1上存储着输入A1和模型分块B1。这里的B就对应着上图中的fc1。
-
在朴素的all-gather中,我们先对输入A矩阵做all-gather,之后两张卡上的数据都变成[A0, A1]
-
然后再各自个和B矩阵(fc1)相乘,得到最终的结果。 不难发现,这里我们需要先等输入数据A到齐,然后才可以开始计算,也就是没有实现任何的计算通信overlap。
针对这张图,我们额外说明一点:例如[A0, A1]这样的形式,不代表A一定就是按照列切割的,只代表我们以分块的视角看待A。而Enisum可理解为一种自适应式的矩阵乘。因此我们要根据实际应用的场景来理解这张图,后文同理。
2.2 all-gather overlap p2p
现在我们引入计算通信overlap,流程如下图所示:
-
在最开始阶段,gpu0上存放着输入A0和模型分块B0,gpu1上存放着输入A1和模型分块B1。
-
-
在gpu0上,我们先把A0发送到gpu1,于此同时开始做gemm(A0, B0),以便得到C00,实现计算通讯overlap
-
在gpu1上,我们先把A1发送到gpu0,于此同时开始做gemm(A1, B1),以便得到C11,实现计算通讯overlap
-
等gpu0计算完C00,并收到A1后,它就可以继续gemm(A1, B0),以便得到C10;gpu1也是同理
-
在overlap下,我们无需等到输入数据all-gather到齐后再进行计算,这样就可以减少整体的运行时间。
以上展示了2卡情况下的all-gather overlap,在多卡情况下也是同理,整体流程如下图所示:
-
partition即为卡,iteration则为每轮迭代,每轮迭代里包含了计算-通信的overlap。partition中的Di表示目前正在使用哪块输入做计算。
-
从图中我们可以发现,这里采取的是p2p ring exchange的方式,也就是每张卡只和自己相邻的2张卡做数据的收-发。
-
例如,在iteration0上时,每张卡做计算时,都用自己维护的那份数据做计算,所以这里Di和partition_i的下标是一一对应的。同时,每张卡会和相邻的2张卡做数据收发。例如partition2会把自己的数据D2发送给partition1,并从partition3上接受D3。
-
再如,在iteration1上时,partition2就用自己收到的D3做计算了,同时它准备把D3发送给partition1,并从partition3上接收D0。以此类推。
相关的代码实践在TE仓库的CommOverlapP2PBase类下,大家可以自行阅读。注意代码里的A=weight, B=input,后文也是同理。
https://github.com/NVIDIA/TransformerEngine/blob/c9ea6be92948e1ec553037f1a04900617b9f7f6b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp#L561
三、tp_comm_overlap_rs