正文
# wrapping each TransformerBlock, then root model # the UX is the same across float8 model and bfloat16 model for transformer_block in model.layers.values(): fully_shard(transformer_block) fully_shard(model)# training loop # ... optimizer.step()# all-reduce AMAX for Float8Linear.weight precompute_float8_dynamic_scale_for_fsdp(model)
用于float8张量子类的FSDP2扩展
: 我们在bfloat16模型和float8模型中保持相同的FSDP2用户体验,因为我们在FSDP2扩展中实现了float8类型转换。float8线性模块的权重是一个知道如何转换为float8的张量子类。我们可以自定义all-gather前后的类型转换逻辑,如下图所示。
fsdp_pre_all_gather (代码(https://github.com/pytorch-labs/float8_experimental/blob/0aca10aced1c4b3abdf00960d83316732cb08ed1/float8_experimental/fsdp_utils.py#L166))
: 根据最新的复制AMAX/缩放因子(需要all-reduce)将bfloat16权重转换为float8权重。注意这里的bfloat16权重是按1/NGPU分片的。由于我们通过all-reduce在所有rank上获得复制的AMAX和缩放因子,在all-gather之前将分片的bfloat16参数转换为float8等同于先all-gather bfloat16参数然后再转换为float8。
fsdp_post_all_gather (代码(https://github.com/pytorch-labs/float8_experimental/blob/0aca10aced1c4b3abdf00960d83316732cb08ed1/float8_experimental/fsdp_utils.py#L196))
: 从all-gather的float8数据和复制的缩放因子构建Float8Tensor,以便在前向和反向中进行float8计算。
性能深入分析
我们讨论float8中的关键优化,以达到相比bfloat16
1.50倍
的加速。
Float8计算 + Bfloat16 All-Gather
(1.40倍加速, 代码(https://github.com/pytorch-labs/float8_experimental/blob/0aca10aced1c4b3abdf00960d83316732cb08ed1/float8_experimental/float8_linear.py#L439-L452)): 当用Float8Linear替换nn.Linear时,可以保持bfloat16权重不变。我们只需将Float8Linear当作普通的nn.Linear处理,并在FSDP2中执行bfloat16 all-gather(流22)。Float8Linear.forward负责bfloat16到float8的类型转换和float8矩阵乘法(流7)。这种方法实现了1.40倍的加速,是展示float8计算重要性的有力基准。然而,它浪费了50%的带宽来传输bfloat16参数,而这些参数最终会在前向过程中被转换为float8。
带独立AMAX All-Reduce的Float8 All-Gather
(在1.40倍基础上+0.02倍, 代码(https://github.com/pytorch/torchtitan/blob/0f70507f1350679428ea64f90bc5a7db17b9c103/torchtitan/float8_linear.py#L96)): 我们在all-gather之前执行float8类型转换以节省50%带宽(流22)。因此,Float8Linear.forward可以直接使用float8权重而无需类型转换(流7)。然而,float8类型转换需要一个全局AMAX(abs(max)的最大值),所以我们需要在N个rank之间all-reduce部分AMAX(一个标量)(流22和35)。每个float8参数需要1次all-reduce。这些小的all-reduce操作降低了整体性能。
组合AMAX AllReduce
(在1.42倍基础上+0.08倍, 代码(https://github.com/pytorch/torchtitan/blob/0f70507f1350679428ea64f90bc5a7db17b9c103/torchtitan/float8_linear.py#L107)): 我们在优化器步骤之后对所有float8参数执行单次all-reduce。因此,我们避免了在FSDP钩子内部的小型all-reduce操作(流47)。我们通过一次性计算所有float8参数的AMAX实现了水平融合。
NCCL和Float8计算之间的SM竞争
: 根据NCCL版本和GPU总SM数量,有时float8计算(流7)中会出现气泡。float8计算(sm90_xmm)和float8 all-gather(ncclDevKernel)都在争夺SM资源。理想情况是始终优先考虑第k层的float8计算而不是第k+1层的float8 all-gather。在这种情况下,如果NCCL使用更少的SM进行较慢的通信或float8计算使用更少的SM。我们发现在基准测试期间将NCCL_MAX_CTAS(https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-max-ctas)设置为16或8对解决竞争很有帮助。