专栏名称: 大模型智能
机器学习算法、深度学习算法、自然语言处理等干货知识集中营
目录
相关文章推荐
北京日报  ·  中共中央举行纪念陈云同志诞辰120周年座谈会 ... ·  23 小时前  
药通社  ·  【议程首发】吸入制剂大会 ... ·  昨天  
北京日报  ·  韩军:朝鲜停止对韩广播 ·  2 天前  
51好读  ›  专栏  ›  大模型智能

快手二面拷打:训练100B模型要多少显存?

大模型智能  · 公众号  ·  · 2025-05-12 00:42

正文

请到「今天看啥」查看全文


临时变量(temporary)、未知数据(unknown)
  • 其他(框架): 自动梯度(autograd_detail)
  • 其中“未命名数据”来源可能是用户创建的一些临时变量,这些变量未参与图的计算过程,所以未被统计;或者是一些未被框架跟踪(tracing)到的数据。“自动梯度数据"是在反向传播求解梯度时产生的一些变量;
    我们在显存计算时会发现“为什么有时显存估算值和实际测量值相差较大?”
    其中一个可能的原因是:未知的数据太大。即显存中可估算值占比相对较小,其它不可估算值的数据占比较大,导致计算值和实际值差距较大( 误差可超过 30% ),比如估算得到的显存消耗为 50GB,而实际测试达到了 75GB。
    如下图是运行一个 LLM 模型采集的一些过程数据,可以看到 unknown 占比有时能达到 30%。
    图片

    不同时刻显存的占比变化


    02
    计算公式

    2.1 训练场景

    训练显存消耗(可估算部分)主要包括: 模型参数(Model)+ 优化器状态(Optimizer status)+梯度值(Gradient)+激活值(Activation)

    根据数值的变化,可将显存消耗分为静态/动态值。训练过程中,模型参数、优化器状态一般不会变化,这两部分归属于静态值;激活值、梯度值会随着计算过程发生变化,将它们归类到动态值。

    图片

    下面主要来看一下这四种类型值的估算方法:

    2.1.1 模型显存(Model Memory)

    模型自身所占用的显存大小与 参数量、参数类型 相关。常见类型 fp32、fp16/bf16、还有 int8、fp8 等。

    图片

    关于模型保存的大小估算方法: 存储 checkpoint(ckpt)时仅考虑模型本身,只要将显存上模型内容存储到磁盘中。

    举例:以 1B(billion)模型为例,若采用 fp32 类型将其存储在磁盘上,其大小为:

    图片

    1B 模型需要 3.725GB 存储空间,进一步近似认为 1B 4GB,可方便作存储的估算推导,如 LLama13b,大约需要 52GB 存储空间。

    注意:混合精度(Mixed-precision)最后存储的类型也是 fp32,公式也适合混合精度。

    2.1.2 优化器状态(Optimizer status)

    在 LLM 中常见的优化器是 Adam,优化器中每个参数需要一个 Momentum 和一个 Variance 状态参数,在混合精度训练中 Adam 还有一份 模型参数副本

    Adam 参数器状态值计算公式(单位 GB):

    图片

    其中(4+4+4)的内容:

    • 模型副本 4 Bytes
    • Momentum 参数 4 Bytes
    • Variance 参数 4 Bytes
    • 如果是 8 位优化器,则计算变为:

    • 模型副本 4 Bytes
    • Momentum 参数 1Byte
    • Variance 参数 1Byte
    图片

    2.1.3 梯度值(Gradient)

    梯度值与模型数据类型保持一致,计算如下(单位 GB):

    图片

    2.1.4 激活值(Activation)

    激活值的大小跟模型参数、重计算、并行策略等相关,这里我们参考 Megtron 论文里面给的计算公式,来求解激活值所占用的显存大小。

    图片

    2.2 训练的并行计算公式

    目前,单卡的物理显存基本不能满足大模型的训练需求,一般会采用模型并行方式来降低单卡显存消耗。

    常见的几种方法: TP/SP/PP/Zero/重计算 ,这些方法出现在 DeepSpeed、Megtron 等并行框架中,目标都是让 GPU 能够装下更大的模型。

    其中:

    • TP(TensorParallel): tensor 并行;
    • SP(SequenceParallel): 序列并行;
    • PP(PipelineParallel):






    请到「今天看啥」查看全文