专栏名称: Datawhale
一个专注于AI领域的开源组织,汇聚了众多顶尖院校和知名企业的优秀学习者,聚集了一群有开源精神和探索精神的团队成员。愿景-for the learner,和学习者一起成长。
目录
相关文章推荐
51好读  ›  专栏  ›  Datawhale

手撕大模型Attention:MLA、MHA、MQA与GQA(含实现代码)

Datawhale  · 公众号  ·  · 2025-05-20 17:18

正文

请到「今天看啥」查看全文


通过以下线性变换生成 Query、共享的 Key 和 Value:

其中 是每个头独立的查询参数,而 是所有头共享的键和值参数( )。与 MHA 不同,MQA 的 Key 和 Value 投影参数在头间共享,参数量减少 倍。

第二阶段:多头注意力计算

Query 被拆分为 个头(形状调整为 ),而 Key 和 Value 通过 unsqueeze expand 扩展为 ,实现所有头共享相同的 Key/Value。注意力权重计算为:


其中 是第 个头的 Query,而 为共享的全局矩阵。此步骤保留了多头的并行性,但减少了 Key/Value 的冗余计算。

第三阶段: 输出投影与优化

将所有头的输出拼接并通过输出投影层:

由于 Key/Value 共享,MQA 的总参数量为 ,远小于 MHA 的 (当 时)。这种设计在保持序列建模能力的同时,降低了显存占用和计算延迟,适合大规模模型部署。

代码实现如下:

import torchimport torch.nn as nnfrom thop import profile 
class MultiQueryAttention(nn.Module):    def __init__(self, hidden_size, num_heads, dropout=0.0):        """        Multi-Query Attention 的实现。        Args:            hidden_size (int): 输入特征的维度,也即 hidden_state 的最后一维。            num_heads (int): 注意力头的数量。            dropout (float): dropout 的概率,默认为 0.0。        """        super(MultiQueryAttention, self).__init__()
        assert hidden_size % num_heads == 0"hidden_size 必须能被 num_heads 整除"
        self.hidden_size = hidden_size        self.num_heads = num_heads        self.head_dim = hidden_size // num_heads  # 每个头的维度
        # 定义线性变换层,用于生成 Q, K, V        self.query = nn.Linear(hidden_size, hidden_size)  # 每个头独立的 Query        self.key = nn.Linear(hidden_size, self.head_dim)  # 所有头共享的 Key        self.value = nn.Linear(hidden_size, self.head_dim)  # 所有头共享的 Value
        self.dropout = nn.Dropout(dropout)        self.out_projection = nn.Linear(hidden_size, hidden_size)
    def forward(self, hidden_state, attention_mask=None):        """        前向传播函数。        Args:            hidden_state (torch.Tensor): 输入的 hidden_state,形状为 [batch_size, seq_len, hidden_size]。            attention_mask (torch.Tensor, optional): 注意力掩码,用于屏蔽某些位置,形状为 [batch_size, seq_len]。默认为 None。        Returns:            torch.Tensor: 注意力输出,形状为 [batch_size, seq_len, hidden_size]。        """        batch_size, seq_len, _ = hidden_state.size()
        # 1. 通过线性层得到 Q, K, V        query = self.query(hidden_state)  # [batch_size, seq_len, hidden_size]        key = self.key(hidden_state)      # [batch_size, seq_len, head_dim]        value = self.value(hidden_state)  # [batch_size, seq_len, head_dim]
        # 2. 将 Q 拆分为多头        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(12)  # [batch_size, num_heads, seq_len, head_dim]
        # 3. 扩展 K 和 V 到 num_heads 维度(所有头共享相同的 K/V)        key = key.unsqueeze(1).expand(-1, self.num_heads, -1, -1)  # [batch_size, num_heads, seq_len, head_dim]        value = value.unsqueeze(1).expand(-1, self.num_heads, -1, -1)  # [batch_size, num_heads, seq_len, head_dim]
        # 4. 计算注意力权重        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)  # [batch_size, num_heads, seq_len, seq_len]        # 应用 attention mask        if attention_mask is not None:            attention_weights = attention_weights.masked_fill(attention_mask[:, NoneNone, :] == 0float('-inf'))
        attention_weights = torch.softmax(attention_weights, dim=-1)  # [batch_size, num_heads, seq_len, seq_len]        attention_weights = self.dropout(attention_weights)
        # 5. 计算上下文向量        context = torch.matmul(attention_weights, value)  # [batch_size, num_heads, seq_len, head_dim]
        # 6. 将多头合并        context = context.transpose(12).contiguous().view(batch_size, seq_len, self.hidden_size)  # [batch_size, seq_len, hidden_size]
        # 7. 通过输出线性层        output = self.out_projection(context)  # [batch_size, seq_len, hidden_size]
        return output
if __name__ == '__main__':    # 示例    batch_size = 2    seq_len = 10    hidden_size = 256    num_heads = 8
    # 创建一个 MQA 实例    mqa = MultiQueryAttention(hidden_size, num_heads)
    # 创建一个随机的 hidden_state    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    # 创建一个 attention mask (可选)    attention_mask = torch.ones(batch_size, seq_len)    attention_mask[:, 5:] = 0  # 屏蔽掉每个 batch 中 seq_len 的后 5 个位置
    # 通过 MQA 层    output = mqa(hidden_state, attention_mask)
    # 打印输出形状    print("输出形状:", output.shape)  # torch.Size([2, 10, 256])


3

分组查询注意力机制(Grouped Query Attention,GQA)

Grouped Query Attention (GQA) 是对多头注意力(MHA)和多查询注意力(MQA)的折中优化方案。其核心思想是将查询头(Query Heads)划分为多个组(Group),每组内的查询头共享一组键(Key)和值(Value),从而在保留多头并行性的同时减少参数量和计算复杂度。GQA 在参数效率与模型性能之间取得了平衡,适用于大规模模型的高效部署。

第一阶段:分组线性变换

输入序列







请到「今天看啥」查看全文