专栏名称: GiantPandaLLM
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaLLM  ·  【博客转载】Row-Major VS ... ·  昨天  
51好读  ›  专栏  ›  GiantPandaLLM

CUDA-MODE 课程笔记 第13课:Ring Attention

GiantPandaLLM  · 公众号  · 3D  · 2024-09-29 23:26

正文

请到「今天看啥」查看全文


  • 用一个绿色方框表示,大小为 s × s
  • s 代表序列长度
  • 内存复杂度说明:
    • 原文:"Memory complexity of naive attention is quadratic with sequence length (score matrix & softmax output)."
    • 翻译:朴素注意力机制的内存复杂度与序列长度呈二次方关系(score矩阵和softmax输出)。

    这张Slides讨论了模型大小和上下文长度对每个token的FLOPS(浮点运算次数)缩放的影响。主要内容如下:

    • 标题:情况有多糟?每个Token的FLOPS缩放
    • 热力图:
      • 横轴:上下文长度(从2x到32768x)
      • 纵轴:模型大小(从7B到1TB)
      • 数值:每个单元格中的数字表示相对于4k上下文大小的FLOPS成本比率
    • 关键发现:
      • "令人惊讶的是: 随着模型大小的增加,成本比率反而降低 "
    • FLOPS计算公式:
      • FLOPS = 24sh² + 4s²h
      • (s=序列长度,h=隐藏维度)
      • 当h为常数时,复杂度为O(s²)
    • 结论:
      • "序列长度最终会成为瓶颈 - 但可能比你想象的要晚"

    来源:Ring Attention,附录D。上面的公式是针对FFN,这里bs=1。具体的公式推导看下图:

    这张slides描述了计算Softmax的挑战。Softmax操作需要在分数矩阵(score matrix)的完整行上进行,这个分数矩阵是通过 (Q是Query矩阵,K是Key矩阵的转置)计算得到的。Softmax的输出依赖于分母中的和,也就是所有输入值指数和的计算。为了在FlashAttention和RingAttention算法中应用Softmax,必须“分块”或“在线”地计算Softmax,即只处理部分和,这样可以更高效地计算出结果。

    这张Slides开始介绍如何通过Python中的PyTorch库定义和验证一个简单的Softmax函数,并逐步过渡到Log-Sum-Exp的更新。这里展示了如何用Python代码定义一个朴素的Softmax函数。这个函数接受一个PyTorch张量作为输入,并计算Softmax值。接下来,展示了如何将自定义的Softmax函数与官方的PyTorch torch.softmax() 函数进行比较。通过生成一个随机张量 x ,分别计算官方Softmax结果 a 和自定义版本 b 。使用 torch.allclose() 函数验证两个输出是否接近。

    slides标题提到“Naive & Numerical unstable”(朴素且数值不稳定),表示当前定义的朴素Softmax函数在某些输入情况下会出现问题。slides显示了一个具体的例子,代码使用了一个随机生成的PyTorch张量x,并将其乘以100传入到朴素的naive_softmax()函数中。结果输出中显示张量中的某些值变成了nan(Not a Number),这表明数值溢出或不稳定。

    我们的目标是将Softmax运算分块处理(breaking softmax() into chunks)。右侧文字指出,虽然可以将向量分块并分别计算Softmax,但最终问题是如何从分块结果 s1 s2 重构出完整的target结果。这也是下一步需要解决的核心问题。

    这张幻slides讲解了如何通过“sum exp”(指数和)撤销Softmax的归一化,从而将分块计算的结果合并。首先回顾了上一个slides中的问题:Softmax输出通过除以 x.exp().sum() 来归一化。为了将多个分块的结果合并,我们需要撤销这种归一化。

    Slides右侧的代码显示了如何通过分块的指数和来进行修正。 x1.exp().sum() x2.exp().sum() 分别计算两个分块 x1 x2 的指数和,命名为 se_x1 se_x2 。然后,分别对分块结果 s1 s2 进行修正,计算公式如slides的代码所示。修正后的结果使用 torch.cat() 函数合并,得到完整的Softmax结果。







    请到「今天看啥」查看全文