专栏名称: 数据派THU
本订阅号是“THU数据派”的姊妹账号,致力于传播大数据价值、培养数据思维。
目录
相关文章推荐
软件定义世界(SDX)  ·  谷歌CEO劈柴震撼预言:2030年AI直逼超 ... ·  2 天前  
数局  ·  解数咨询:2024年保健品行业复盘 ·  2 天前  
51好读  ›  专栏  ›  数据派THU

加速LLM大模型推理,KV缓存技术详解与PyTorch实现

数据派THU  · 公众号  · 大数据  · 2025-05-26 17:00

正文

请到「今天看啥」查看全文


block_size ),C表示嵌入维度。

多头注意力机制的核心思想是将嵌入空间划分为多个头,每个头独立计算注意力权重。对于嵌入维度C=128且头数量为4的情况,每个头的维度为128/4=32。系统将分别计算这4个大小为32的注意力头,然后将结果拼接成形状为(B, T, 128)的输出张量。

KV缓存的必要性

为理解KV缓存的必要性,首先需要分析注意力机制的计算过程:

 # 多头注意力classMultiHeadAttention(nn.Module):    def__init__(self, head_size, num_head):        super().__init__()        self.sa_head=nn.ModuleList([Head(head_size) for_inrange(num_head)])        self.dropout=nn.Dropout(dropout)        self.proj=nn.Linear(embed_size, embed_size)    defforward(self, x):        x=torch.cat([head(x) forheadinself.sa_head], dim=-1)        x=self.dropout(self.proj(x))         returnx


在注意力头的实现中,系统通过线性变换生成Key、Value和Query,它们的形状均为(B, T, C),其中C为头的大小。

Query与Key进行点积运算,生成形状为(B, T, T)的权重矩阵,表示各token之间的相关性。权重矩阵经过掩码处理转换为下三角矩阵,确保在点积运算过程中每个token仅考虑前面的token(即从1到n),这强制实现了因果关系,使得自回归模型中的token仅使用历史信息预测下一个token。

下图展示了自回归生成在注意力机制中的实现过程:

Image

在自回归生成的每个步骤中,系统均需重新计算已经计算过的Key和Value。例如,在第2步中,K1与第1步生成的K1相同。由于在推理阶段模型参数已固定,相同输入将产生相同输出,因此将这些Key和Value存储在缓存中并在后续步骤中复用是更高效的方法。

下图直观展示了KV缓存的工作机制:

Image

实现KV缓存的主要区别在于:

  1. 推理时每次仅传入一个新token,而非增量传递所有token;

  2. 由于Key和Value已缓存,无需重复计算历史token的表示;

  3. 无需对权重进行掩码处理,因为每次只处理单个Query token,权重矩阵(QK^T)的维度为(B, 1, T)而非(B, T, T)。

缓存机制实现

KV缓存的实现基于形状为(B, T, C)的零张量初始化,其中T为最大处理的token数量(即block_size):

classHead(nn.Module):    def__init__(self, head_size):        super().__init__()        self.head_size=head_size        self.key=nn.Linear(embed_size, head_size, bias=False)        self.query=nn.Linear(embed_size, head_size, bias=False)        self.value=nn.Linear(embed_size, head_size, bias=False)        self.dropout=nn.Dropout(dropout)        self.k_cache=None        self.v_cache=None        self.cache_index=0    defforward(self, x):        B, T, C=x.shape# 形状: B, 1, C        k=self.key(x)        q=self.query(x)        v=self.value(x)        # 如果缓存为空则初始化        ifself.k_cacheisNoneorself.v_cacheisNone:            # 使用固定大小初始化缓存            self.k_cache=torch.zeros(B, block_size, self.head_size, device=x.device)            self.v_cache=torch.zeros(B, block_size, self.head_size, device=x.device)            self.cache_index=0         returnout


自回归模型在训练时使用固定的上下文长度,即当前token预测下一个token时可回溯的最大token数量。在本实现中,这个上下文长度由 block_size 参数确定,表示缓存的最大token数量,通过缓存索引进行跟踪:

defforward(self, x):        B, T, C=x.shape# B, 1, C        k=self.key(x)        q=self.query(x)        v=self.value(x)        # 如果缓存为空则初始化        ifself.k_cacheisNoneorself.v_cacheisNone:            # 使用固定大小初始化缓存            self.k_cache=torch.zeros(B, block_size, self.head_size, device=x.device)            self.v_cache=torch.zeros(B, block_size, self.head_size, device=x.device)            self.cache_index=0        # 原地更新缓存        ifself.cache_index+T<=block_size:            self.k_cache[:, self.cache_index:self.cache_index+T, :] =k            self.v_cache[:, self.cache_index:self.cache_index+T, :] =v        # 注意:鉴于我们一次只传递一个token,T将始终为1,因此上面的操作        # 等效于直接执行self.k_cache[:, self.cache_index, :] = k        # 更新缓存索引        self.cache_index=min(self.cache_index+T, block_size)        # 注意力点积        [email protected]_cache.transpose(21)/self.head_size**0.5        wei=F.softmax(wei, dim=2)    # (B, block_size, block_size)        wei=self.dropout(wei)        [email protected]_cache         returnout

从第一个token开始,系统将Key-Value对存入对应的缓存位置,并递增缓存索引直到达到设定的上限:






请到「今天看啥」查看全文