# 多头注意力
classMultiHeadAttention(nn.Module):
def__init__(self, head_size, num_head):
super().__init__()
self.sa_head=nn.ModuleList([Head(head_size) for_inrange(num_head)])
self.dropout=nn.Dropout(dropout)
self.proj=nn.Linear(embed_size, embed_size)
defforward(self, x):
x=torch.cat([head(x) forheadinself.sa_head], dim=-1)
x=self.dropout(self.proj(x))
returnx
在注意力头的实现中,系统通过线性变换生成Key、Value和Query,它们的形状均为(B, T, C),其中C为头的大小。
Query与Key进行点积运算,生成形状为(B, T, T)的权重矩阵,表示各token之间的相关性。权重矩阵经过掩码处理转换为下三角矩阵,确保在点积运算过程中每个token仅考虑前面的token(即从1到n),这强制实现了因果关系,使得自回归模型中的token仅使用历史信息预测下一个token。
下图展示了自回归生成在注意力机制中的实现过程:
在自回归生成的每个步骤中,系统均需重新计算已经计算过的Key和Value。例如,在第2步中,K1与第1步生成的K1相同。由于在推理阶段模型参数已固定,相同输入将产生相同输出,因此将这些Key和Value存储在缓存中并在后续步骤中复用是更高效的方法。
下图直观展示了KV缓存的工作机制:
实现KV缓存的主要区别在于:
-
推理时每次仅传入一个新token,而非增量传递所有token;
-
由于Key和Value已缓存,无需重复计算历史token的表示;
-
无需对权重进行掩码处理,因为每次只处理单个Query token,权重矩阵(QK^T)的维度为(B, 1, T)而非(B, T, T)。
缓存机制实现
KV缓存的实现基于形状为(B, T, C)的零张量初始化,其中T为最大处理的token数量(即block_size):
classHead(nn.Module):
def__init__(self, head_size):
super().__init__()
self.head_size=head_size
self.key=nn.Linear(embed_size, head_size, bias=False)
self.query=nn.Linear(embed_size, head_size, bias=False)
self.value=nn.Linear(embed_size, head_size, bias=False)
self.dropout=nn.Dropout(dropout)
self.k_cache=None
self.v_cache=None
self.cache_index=0
defforward(self, x):
B, T, C=x.shape
k=self.key(x)
q=self.query(x)
v=self.value(x)
ifself.k_cacheisNoneorself.v_cacheisNone:
self.k_cache=torch.zeros(B, block_size, self.head_size, device=x.device)
self.v_cache=torch.zeros(B, block_size, self.head_size, device=x.device)
self.cache_index=0
returnout
自回归模型在训练时使用固定的上下文长度,即当前token预测下一个token时可回溯的最大token数量。在本实现中,这个上下文长度由
block_size
参数确定,表示缓存的最大token数量,通过缓存索引进行跟踪:
defforward(self, x):
B, T, C=x.shape
k=self.key(x)
q=self.query(x)
v=self.value(x)
ifself.k_cacheisNoneorself.v_cacheisNone:
self.k_cache=torch.zeros(B, block_size, self.head_size, device=x.device)
self.v_cache=torch.zeros(B, block_size, self.head_size, device=x.device)
self.cache_index=0
ifself.cache_index+T<=block_size:
self.k_cache[:, self.cache_index:self.cache_index+T, :] =k
self.v_cache[:, self.cache_index:self.cache_index+T, :] =v
self.cache_index=min(self.cache_index+T, block_size)
[email protected]_cache.transpose(2, 1)/self.head_size**0.5
wei=F.softmax(wei, dim=2)
wei=self.dropout(wei)
[email protected]_cache
returnout