专栏名称: GiantPandaLLM
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaLLM  ·  Meta Shuffling的MoE ... ·  3 天前  
GiantPandaLLM  ·  [vLLM实践][算子] ... ·  6 天前  
GiantPandaLLM  ·  MetaShuffling:Meta的Fus ... ·  4 天前  
51好读  ›  专栏  ›  GiantPandaLLM

SGLang MLA 实现解析

GiantPandaLLM  · 公众号  · 3D  · 2025-03-09 23:04

正文

请到「今天看啥」查看全文


, False )
scaling_factor = rope_scaling[ "factor" ]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale

# 初始化RadixAttention,用于高效的注意力计算
# TODO, support head_size 192
self.attn = RadixAttention(
self.num_local_heads,
256 , # 固定的内部维度,用于计算效率
self.scaling,
num_kv_heads=self.num_local_heads,
layer_id=layer_id,
)

def forward (
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
)
-> torch.Tensor:

"""
注意力层的前向传播函数。

参数:
positions: 位置编码张量,用于RoPE计算
hidden_states: 输入隐藏状态
forward_batch: 前向计算批次信息

返回:
output: 注意力层的输出
"""

# 计算查询向量Q
if self.q_lora_rank is not None :
# 使用两阶段投影计算Q
# 第一阶段:hidden_states -> q_lora_rank
q = self.q_a_proj(hidden_states)[ 0 ]
# 对第一阶段输出进行归一化
q = self.q_a_layernorm(q)
# 第二阶段:q_lora_rank -> num_heads * qk_head_dim,并重塑为多头形式
q = self.q_b_proj(q)[ 0 ].view( -1 , self.num_local_heads, self.qk_head_dim)
else :
# 直接投影计算Q,并重塑为多头形式
q = self.q_proj(hidden_states)[ 0 ].view(
-1 , self.num_local_heads, self.qk_head_dim
)

# 将Q分为不使用位置编码的部分和使用位置编码的部分
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim= -1 )

# 计算KV的第一阶段投影
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[ 0 ]
# 分离KV的第一阶段输出和用于RoPE的部分
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim= -1 )
# 为后续处理增加维度
latent_cache = latent_cache.unsqueeze( 1 )
# 对KV的第一阶段输出进行归一化
kv_a = self.kv_a_layernorm(kv_a.contiguous())
# 计算KV的第二阶段投影
kv = self.kv_b_proj(kv_a)[ 0 ]
# 重塑为多头形式
kv = kv.view( -1 , self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
# 分离K的不使用位置编码部分和V
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim= -1 )
# 获取K的使用位置编码部分
k_pe = latent_cache[:, :, self.kv_lora_rank :]

# 应用RoPE到Q和K的位置编码部分
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
# 将处理后的位置编码部分放回Q
q[..., self.qk_nope_head_dim :] = q_pe

# 构建完整的K,包括不使用位置编码的部分和使用位置编码的部分
k = torch.empty_like(q)
k[..., : self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim :] = k_pe

# 将Q、K、V填充到固定维度256(RadixAttention的内部维度),并重塑为适合注意力计算的形式
q = torch.nn.functional.pad(q, [ 0 , 256 - self.qk_head_dim], value= 0 ).view(
-1 , self.num_local_heads * 256
)
k = torch.nn.functional.pad(k, [ 0 , 256 - self.qk_head_dim], value= 0 ).view(
-1 , self.num_local_heads * 256
)
v = torch.nn.functional.pad(v, [ 0 , 256 - self.v_head_dim], value= 0 ).view(
-1 , self.num_local_heads * 256
)

# 执行注意力计算
attn_output = self.attn(q, k, v, forward_batch)

# 重塑注意力输出,并只保留有效的V维度部分
attn_output = attn_output.view( -1 , self.num_local_heads, 256 )[
..., : self.v_head_dim
].reshape( -1 , self.num_local_heads * self.v_head_dim)

# 通过输出投影将注意力输出投影回原始隐藏层维度
output, _ = self.o_proj(attn_output)

return output

对于 DeepseekV2Attention 类来说,和 DeepSeek V2/V3 的 HuggingFace 提供的 MLA 实现一样,这里的使用的KV Cache实际上是解压缩之后的MHA KV Cache的格式,不是缓存的Latent,并没有实现MLA的缓存节省效果。

0x3. DeepseekV2AttentionMLA 详解

由于这里的代码比较长,这里就只从流程出发,尽量少展示代码。先把DeepSeek MLA的公式截图到这里:

0x3.1 权重介绍

首先汇总一下init中的各个权重介绍,其实和 DeepseekV2Attention 上面的权重基本一致,只不过它对 self.kv_b_proj 做了一个拆分。

具体来说, DeepseekV2AttentionMLA 初始化部分包含:

# 使用两阶段投影:先将hidden_size投影到q_lora_rank,再投影到最终维度
# 第一阶段投影:hidden_size -> q_lora_rank,对应paper公式中的W^DQ
self.q_a_proj = ReplicatedLinear(
    self.hidden_size,
    self.q_lora_rank,
    bias=False,
    quant_config=quant_config,
)
# q_b_proj 大小为 [q_lora_rank, num_heads * q_head_dim] = 
# [q_lora_rank, num_attention_heads * (qk_nope_head_dim + qk_rope_head_dim)]
# 对应上述公式中的W^UQ和W^QR合并后的大矩阵,仅仅只是内存放在一起
self.q_b_proj = ColumnParallelLinear(
    q_lora_rank,
    self.num_heads * self.qk_head_dim,
    bias=False,
    quant_config=quant_config,
)
# KV的第一阶段投影:hidden_size -> kv_lora_rank + qk_rope_head_dim
# 与Q向量类似,KV向量的生成也是先投影到一个低维的 compressed_kv 向量(对应c_t^{KV}=w^{DKV}h_t)
# 再升维展开。具体的代码涉及 kv_a_proj_with_mqa 和 kv_b_proj 两个参数矩阵。
# 其中 kv_a_proj_with_mqa 大小为 [hidden_size, kv_lora_rank + qk_rope_head_dim]
self.kv_a_proj_with_mqa = ReplicatedLinear(
    self.hidden_size,
    self.kv_lora_rank + self.qk_rope_head_dim,
    bias=False,
    quant_config=quant_config,
    FIXME: quick fix for skip quantization
    prefix=f"self_attn.kv_a_proj_with_mqa",
)
# KV的第二阶段投影:kv_lora_rank -> num_heads * (qk_nope_head_dim + v_head_dim)
# kv_b_proj 大小为 [kv_lora_rank, num_heads * (q_head_dim - qk_rope_head_dim + v_head_dim)] 
# 对应 paper 公式中的W^{UK}和W^{UV}。
# 由于 W^{UK} 只涉及 non rope 的部分所以维度中把 qk_rope_head_dim 去掉了,就是上面的-号。
self.kv_b_proj = ColumnParallelLinear(
    self.kv_lora_rank,
    self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
    bias=False,
    quant_config=quant_config,
)
# 输出投影:将注意力的输出投影回原始隐藏层维度
self.o_proj = RowParallelLinear(
    self.num_heads * self.v_head_dim,
    self.hidden_size,
    bias=False,
    quant_config=quant_config,
)

接着,初始化过程中还有两个 self.w_kc,self.w_vc ,它们分别对应了将 self.kv_b_proj 拆分后的 。拆分的代码如下:

w = self_attn.kv_b_proj.weight
w_kc, w_vc = w.unflatten(
                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
self_attn.w_kc = w_kc.transpose(12).contiguous().transpose(12)
self_attn.w_vc = w_vc.contiguous().transpose(12)

我们来分析一下这里的shape变化,先确定一下DeepSeek R1下相关的超参数: self.qk_nope_head_dim = 128 self.v_head_dim = 128 self.kv_lora_rank = 512 self.num_heads = 128 ,w 的形状为 [32768, 512] ,即 [128*(128+128), 512]

w.unflatten(0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)) 这一步将 w 的第一个维度 32768 重新组织为两个维度 [-1, 256] ,其中 256 = 128 + 128。这里的 -1 会自动计算为 32768 / 256 = 128 ,所以 unflatten 后的形状为 [128, 256, 512] .split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) 这一步沿着第二个维度(索引为1)将张量分割成两部分: w_kc 的形状为







请到「今天看啥」查看全文