正文
,这里的
是合并了所有head对应的param weight后的表达。
我们现在的总目标是节省K cache,当你再次端详上面这幅图时,一个idea在你的头脑中出现:
-
当前我要存的K cache是4个k_head(图中深绿色框),
但如果我能从这4个k_head中抽取出1份共有的信息,然后在做attn计算时,每个head都用这1份共有的信息做计算,那么我也只需存这1份共有信息作为K cache了
。这样我就把K cache从原来num_heads = 4变成num_heads = 1,这不就能节省K cache了吗?
-
但是等等,
现在共有的k_head信息是抽取出来了,那么相异的k_head信息呢?(简单来说,就是由
不同head部分学习到的相异信息)
。我们当然是希望k_head间相异的信息也能保留下来,那么该把它们保留至哪里呢?当你回顾attn_weights的计算公式时,一个想法在你脑中闪现:
q部分不是也有heads吗!我可以把每个k_head独有的信息转移到对应的q_head上吗!写成公式解释就是
:
-
原来
,括号表示运算顺序,即先各自算2个括号内的,再做 * 计算
-
-
也就是说,这里我们通过矩阵乘法的交换律,巧妙地把1个token上k_heads独有的信息转移到了对应的q_head上来,这样1个token上k_heads间共享的相同信息就能被我们当作K cache存储下来。
(在这里,你可以
抽象地
把