专栏名称: GiantPandaLLM
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaLLM  ·  【博客转载】Row-Major VS ... ·  昨天  
51好读  ›  专栏  ›  GiantPandaLLM

【翻译】在FSDP2中开启Float8 All-Gather

GiantPandaLLM  · 公众号  · 3D  · 2024-11-02 22:08

正文

请到「今天看啥」查看全文


# wrapping each TransformerBlock, then root model
# the UX is the same across float8 model and bfloat16 model
for transformer_block in model.layers.values():
    fully_shard(transformer_block)
fully_shard(model)

# training loop
# ...
optimizer.step()
# all-reduce AMAX for Float8Linear.weight
precompute_float8_dynamic_scale_for_fsdp(model)

用于float8张量子类的FSDP2扩展 : 我们在bfloat16模型和float8模型中保持相同的FSDP2用户体验,因为我们在FSDP2扩展中实现了float8类型转换。float8线性模块的权重是一个知道如何转换为float8的张量子类。我们可以自定义all-gather前后的类型转换逻辑,如下图所示。

  • fsdp_pre_all_gather (代码(https://github.com/pytorch-labs/float8_experimental/blob/0aca10aced1c4b3abdf00960d83316732cb08ed1/float8_experimental/fsdp_utils.py#L166)) : 根据最新的复制AMAX/缩放因子(需要all-reduce)将bfloat16权重转换为float8权重。注意这里的bfloat16权重是按1/NGPU分片的。由于我们通过all-reduce在所有rank上获得复制的AMAX和缩放因子,在all-gather之前将分片的bfloat16参数转换为float8等同于先all-gather bfloat16参数然后再转换为float8。
  • fsdp_post_all_gather (代码(https://github.com/pytorch-labs/float8_experimental/blob/0aca10aced1c4b3abdf00960d83316732cb08ed1/float8_experimental/fsdp_utils.py#L196)) : 从all-gather的float8数据和复制的缩放因子构建Float8Tensor,以便在前向和反向中进行float8计算。

性能深入分析

我们讨论float8中的关键优化,以达到相比bfloat16 1.50倍 的加速。

Float8计算 + Bfloat16 All-Gather (1.40倍加速, 代码(https://github.com/pytorch-labs/float8_experimental/blob/0aca10aced1c4b3abdf00960d83316732cb08ed1/float8_experimental/float8_linear.py#L439-L452)): 当用Float8Linear替换nn.Linear时,可以保持bfloat16权重不变。我们只需将Float8Linear当作普通的nn.Linear处理,并在FSDP2中执行bfloat16 all-gather(流22)。Float8Linear.forward负责bfloat16到float8的类型转换和float8矩阵乘法(流7)。这种方法实现了1.40倍的加速,是展示float8计算重要性的有力基准。然而,它浪费了50%的带宽来传输bfloat16参数,而这些参数最终会在前向过程中被转换为float8。

带独立AMAX All-Reduce的Float8 All-Gather (在1.40倍基础上+0.02倍, 代码(https://github.com/pytorch/torchtitan/blob/0f70507f1350679428ea64f90bc5a7db17b9c103/torchtitan/float8_linear.py#L96)): 我们在all-gather之前执行float8类型转换以节省50%带宽(流22)。因此,Float8Linear.forward可以直接使用float8权重而无需类型转换(流7)。然而,float8类型转换需要一个全局AMAX(abs(max)的最大值),所以我们需要在N个rank之间all-reduce部分AMAX(一个标量)(流22和35)。每个float8参数需要1次all-reduce。这些小的all-reduce操作降低了整体性能。

组合AMAX AllReduce (在1.42倍基础上+0.08倍, 代码(https://github.com/pytorch/torchtitan/blob/0f70507f1350679428ea64f90bc5a7db17b9c103/torchtitan/float8_linear.py#L107)): 我们在优化器步骤之后对所有float8参数执行单次all-reduce。因此,我们避免了在FSDP钩子内部的小型all-reduce操作(流47)。我们通过一次性计算所有float8参数的AMAX实现了水平融合。

NCCL和Float8计算之间的SM竞争 : 根据NCCL版本和GPU总SM数量,有时float8计算(流7)中会出现气泡。float8计算(sm90_xmm)和float8 all-gather(ncclDevKernel)都在争夺SM资源。理想情况是始终优先考虑第k层的float8计算而不是第k+1层的float8 all-gather。在这种情况下,如果NCCL使用更少的SM进行较慢的通信或float8计算使用更少的SM。我们发现在基准测试期间将NCCL_MAX_CTAS(https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-max-ctas)设置为16或8对解决竞争很有帮助。







请到「今天看啥」查看全文