专栏名称: GiantPandaLLM
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaLLM  ·  【博客转载】Row-Major VS ... ·  17 小时前  
GiantPandaLLM  ·  【博客转载】CUDA Coalesced ... ·  2 天前  
GiantPandaLLM  ·  【博客转载】C++/CUDA Data ... ·  4 天前  
GiantPandaLLM  ·  【博客转载】CUDA Kernel ... ·  4 天前  
51好读  ›  专栏  ›  GiantPandaLLM

PyTorch 原生FP8训练进展

GiantPandaLLM  · 公众号  · 3D  · 2025-01-08 19:57

正文

请到「今天看啥」查看全文


作者:IBM: Tuan Hoang Trong, Alexei Karve, Yan Koyfman, Linsong Chu, Divya Kumari, Shweta Salaria, Robert Walkup, Praneet Adusumilli, Nirmit Desai, Raghu Ganti, Seetharami Seelam Meta: Less Wright, Wei Feng, Vasiliy Kuznetsov, Driss Guesseous

在本博客中,我们将展示如何在保持损失和评估基准一致性的同时,相比 FSDP1 bf16训练 实现高达50%的吞吐量提升。我们通过利用FSDP2、DTensor和torch.compile与torchao的float8线性层更新(计算)以及float8 all_gathers进行权重通信来实现这一提升。我们展示了这些改进在Meta LLaMa模型架构的不同规模上的效果,从1.8B小型模型一直到405B大型模型,使训练速度比以往更快。

我们使用Meta Llama3架构展示这些改进,并在两个规模上进行模型质量研究:8B模型规模的100B tokens训练和70B模型规模的50B tokens训练,这提供了float8和bf16训练损失曲线的精确比较。我们证明了与 bf16 相比,这些模型训练运行的损失曲线达到了相同的损失收敛。此外,我们使用FineWeb-edu数据集训练了一个3B模型到1T tokens,并运行标准评估基准以确保模型质量完整且与bf16运行相当。

在IBM研究院,我们计划采用这些功能进行数据消融实验,以提高在给定GPU预算内可以执行的实验数量。从长远来看,我们将通过更大规模的模型运行来展示 float8 训练的端到端可行性。

什么是Float8?

float8 训练格式是由NVIDIA、ARM和Intel在2022年的一篇论文(https://arxiv.org/abs/2209.05433)中提出的,该论文证明了使用更低精度float8进行训练的可行性,且不会牺牲模型质量。随着NVIDIA Hopper系列等新型GPU的推出,由于原生float8张量核心支持,FP8训练变得可行,有望实现超过2倍的训练吞吐量提升。实现这一承诺面临一些挑战:(i) 在 float8 中启用核心模型操作如 matmul attention , (ii) 在分布式框架中启用 float8 训练, (iii) 在 float8 中启用GPU之间的权重通信。虽然NVIDIA库启用了 float8 matmul ,但后两项是在FSDP2和torchao的最新更新中提供的。







请到「今天看啥」查看全文