专栏名称: GiantPandaLLM
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaLLM  ·  【博客转载】Row-Major VS ... ·  昨天  
51好读  ›  专栏  ›  GiantPandaLLM

图解Megatron TP中的计算通信overlap

GiantPandaLLM  · 公众号  · 3D  · 2025-01-04 22:00

主要观点总结

本文探索了Megatron中实现计算通信overlap的方法,具体涉及Megatron的dp、tp和pp部分,特别是tp部分(即megatron sp-tp)。文章介绍了在tp中各个步骤的计算和通信流程,以及如何通过p2p ring exchange、pipeline chunk等方法实现计算和通信的串行overlap。此外,文章还介绍了如何通过设置计算流和通信流实现并行overlap,即bulk overlap。最后,文章总结了本文的主要内容和参考资源。

关键观点总结

关键观点1: Megatron中计算通信overlap的探索

介绍Megatron中计算通信overlap的重要性和背景。

关键观点2: Megatron的dp、tp和pp部分简介

概述Megatron中这三个部分的基本功能和在计算通信中的作用。

关键观点3: tp中的计算和通信流程

详细描述tp中各个步骤的计算和通信流程,包括all-gather、reduce-scatter等。

关键观点4: 串行overlap的实现方法

介绍如何通过p2p ring exchange和pipeline chunk等方法实现计算和通信的串行overlap。

关键观点5: 并行overlap的实现方法

介绍如何通过设置计算流和通信流实现并行overlap,即bulk overlap。

关键观点6: 总结

概括文章的主要内容和结论,以及参考资源。


正文

请到「今天看啥」查看全文


假设我们采取的是最朴素的,没有任何overlap的策略,那么红框中的计算流程应该是下图这样的,这里假设tp_size = 2:

如上图所示,我们有2张gpu(tp_size = 2):

  • 在all-gather开始前,gpu0上存储着输入A0和模型分块B0,gpu1上存储着输入A1和模型分块B1。这里的B就对应着上图中的fc1。
  • 在朴素的all-gather中,我们先对输入A矩阵做all-gather,之后两张卡上的数据都变成[A0, A1]
  • 然后再各自个和B矩阵(fc1)相乘,得到最终的结果。 不难发现,这里我们需要先等输入数据A到齐,然后才可以开始计算,也就是没有实现任何的计算通信overlap。

针对这张图,我们额外说明一点:例如[A0, A1]这样的形式,不代表A一定就是按照列切割的,只代表我们以分块的视角看待A。而Enisum可理解为一种自适应式的矩阵乘。因此我们要根据实际应用的场景来理解这张图,后文同理。

2.2 all-gather overlap p2p

现在我们引入计算通信overlap,流程如下图所示:

  • 在最开始阶段,gpu0上存放着输入A0和模型分块B0,gpu1上存放着输入A1和模型分块B1。
  • 现在开始操作:
    • 在gpu0上,我们先把A0发送到gpu1,于此同时开始做gemm(A0, B0),以便得到C00,实现计算通讯overlap
    • 在gpu1上,我们先把A1发送到gpu0,于此同时开始做gemm(A1, B1),以便得到C11,实现计算通讯overlap
    • 等gpu0计算完C00,并收到A1后,它就可以继续gemm(A1, B0),以便得到C10;gpu1也是同理
  • 在overlap下,我们无需等到输入数据all-gather到齐后再进行计算,这样就可以减少整体的运行时间。

以上展示了2卡情况下的all-gather overlap,在多卡情况下也是同理,整体流程如下图所示:

  • partition即为卡,iteration则为每轮迭代,每轮迭代里包含了计算-通信的overlap。partition中的Di表示目前正在使用哪块输入做计算。

  • 从图中我们可以发现,这里采取的是p2p ring exchange的方式,也就是每张卡只和自己相邻的2张卡做数据的收-发。

  • 例如,在iteration0上时,每张卡做计算时,都用自己维护的那份数据做计算,所以这里Di和partition_i的下标是一一对应的。同时,每张卡会和相邻的2张卡做数据收发。例如partition2会把自己的数据D2发送给partition1,并从partition3上接受D3。

  • 再如,在iteration1上时,partition2就用自己收到的D3做计算了,同时它准备把D3发送给partition1,并从partition3上接收D0。以此类推。

相关的代码实践在TE仓库的CommOverlapP2PBase类下,大家可以自行阅读。注意代码里的A=weight, B=input,后文也是同理。

https://github.com/NVIDIA/TransformerEngine/blob/c9ea6be92948e1ec553037f1a04900617b9f7f6b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp#L561

三、tp_comm_overlap_rs







请到「今天看啥」查看全文