正文
作者:IBM: Tuan Hoang Trong, Alexei Karve, Yan Koyfman, Linsong Chu, Divya Kumari, Shweta Salaria, Robert Walkup, Praneet Adusumilli, Nirmit Desai, Raghu Ganti, Seetharami Seelam
Meta: Less Wright, Wei Feng, Vasiliy Kuznetsov, Driss Guesseous
在本博客中,我们将展示如何在保持损失和评估基准一致性的同时,相比
FSDP1 bf16训练
实现高达50%的吞吐量提升。我们通过利用FSDP2、DTensor和torch.compile与torchao的float8线性层更新(计算)以及float8 all_gathers进行权重通信来实现这一提升。我们展示了这些改进在Meta LLaMa模型架构的不同规模上的效果,从1.8B小型模型一直到405B大型模型,使训练速度比以往更快。
我们使用Meta Llama3架构展示这些改进,并在两个规模上进行模型质量研究:8B模型规模的100B tokens训练和70B模型规模的50B tokens训练,这提供了float8和bf16训练损失曲线的精确比较。我们证明了与
bf16
相比,这些模型训练运行的损失曲线达到了相同的损失收敛。此外,我们使用FineWeb-edu数据集训练了一个3B模型到1T tokens,并运行标准评估基准以确保模型质量完整且与bf16运行相当。
在IBM研究院,我们计划采用这些功能进行数据消融实验,以提高在给定GPU预算内可以执行的实验数量。从长远来看,我们将通过更大规模的模型运行来展示
float8
训练的端到端可行性。
什么是Float8?
float8
训练格式是由NVIDIA、ARM和Intel在2022年的一篇论文(https://arxiv.org/abs/2209.05433)中提出的,该论文证明了使用更低精度float8进行训练的可行性,且不会牺牲模型质量。随着NVIDIA Hopper系列等新型GPU的推出,由于原生float8张量核心支持,FP8训练变得可行,有望实现超过2倍的训练吞吐量提升。实现这一承诺面临一些挑战:(i) 在
float8
中启用核心模型操作如
matmul
和
attention
,
(ii) 在分布式框架中启用
float8
训练,
(iii) 在
float8
中启用GPU之间的权重通信。虽然NVIDIA库启用了
float8
matmul
,但后两项是在FSDP2和torchao的最新更新中提供的。