import torch
import torch.nn as nn
from thop import profile


class MultiQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout=0.0):
        """
        Implementation of Multi-Query Attention: every head has its own query
        projection, but all heads share a single key/value projection.

        Args:
            hidden_size (int): Dimension of the input features, i.e. the last
                dimension of hidden_state.
            num_heads (int): Number of attention heads.
            dropout (float): Dropout probability. Defaults to 0.0.
        """
        super(MultiQueryAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Per-head query projection; key/value project to a single shared head.
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, self.head_dim)
        self.value = nn.Linear(hidden_size, self.head_dim)
        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        """
        Forward pass.

        Args:
            hidden_state (torch.Tensor): Input tensor of shape
                [batch_size, seq_len, hidden_size].
            attention_mask (torch.Tensor, optional): Mask used to hide certain
                positions (1 = keep, 0 = mask), shape [batch_size, seq_len].
                Defaults to None.

        Returns:
            torch.Tensor: Attention output of shape [batch_size, seq_len, hidden_size].
        """
        batch_size, seq_len, _ = hidden_state.size()
        query = self.query(hidden_state)  # [batch_size, seq_len, hidden_size]
        key = self.key(hidden_state)      # [batch_size, seq_len, head_dim]
        value = self.value(hidden_state)  # [batch_size, seq_len, head_dim]
        # Split queries into heads: [batch_size, num_heads, seq_len, head_dim]
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Broadcast the single shared key/value across all heads:
        # [batch_size, num_heads, seq_len, head_dim]
        key = key.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        value = value.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        # Scaled dot-product scores: [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if attention_mask is not None:
            attention_weights = attention_weights.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))
        attention_weights = torch.softmax(attention_weights, dim=-1)
        attention_weights = self.dropout(attention_weights)
        # Weighted sum of values: [batch_size, num_heads, seq_len, head_dim]
        context = torch.matmul(attention_weights, value)
        # Merge heads back: [batch_size, seq_len, hidden_size]
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        output = self.out_projection(context)
        return output


if __name__ == '__main__':
    batch_size = 2
    seq_len = 10
    hidden_size = 256
    num_heads = 8

    mqa = MultiQueryAttention(hidden_size, num_heads)
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    # Mask out the last 5 positions of every sequence (1 = keep, 0 = mask).
    attention_mask = torch.ones(batch_size, seq_len)
    attention_mask[:, 5:] = 0
    output = mqa(hidden_state, attention_mask)
    print("Output shape:", output.shape)