正文
MHA/GQA/MQA
首先简单介绍一下MQA和GQA。标准的多头注意力就是
MHA(Multi Head Attention)
,在MHA中,KV Heads的数量和Query Heads的数量相同,每个Query Head持有一个独立的KV Head,在Attention中,对单独的KV Head做计算。但是,当模型层数加深和Heads数变多后,QKV Attention的计算和IO都会快速增加。为了缓解这种情况,有学者提出了MQA和GQA。
MQA (Multi Queries Attention): MQA比较极端,只保留一个KV Head,多个Query Heads共享相同的KV Head
。这相当于不同Head的Attention差异,全部都放在了Query上,需要模型仅从不同的Query Heads上就能够关注到输入hidden states不同方面的信息。这样做的好处是,极大地降低了KV Cache的需求,但是会导致模型效果有所下降。
MQA
0x03 层内KV Cache共享: GQA简析
GQA (Group Queries Attention): GQA与MQA不同,而是采取了折中的做法。GQA把Query Heads进行分组,每组Query Heads对应一个KV Head。
比如,把8个Query Heads分成4组,每个Grouped Query Head包含2个Query Heads,一个Grouped Query Head对应一个KV Head,此时总共有4个KV Heads。GQA可以在减少计算量和KV Cache同时确保模型效果不受到大的影响。
GQA
在目前大部分主流训推框架或算法,都已经支持MQA/GQA,比如FlashAttention中,也支持MQA和GQA。对于MQA和GQA的情形,FlashAttention采用Indexing的方式,而不是直接复制多份KV Head的内容到显存然后再进行计算。Indexing,即通过传入KV/KV Head索引到Kernel中,然后计算内存地址,直接从内存中读取KV。
GQA/MQA in FlashAttention V2
歪个楼,FlashAttention V1/V2/V3系列原理&图解,推荐阅读我的另一篇文章:
DefTruth:[Attention优化][2w字] 原理&图解: 从Online-Softmax到FlashAttention V1/V2/V3
https://zhuanlan.zhihu.com/p/668888063
GQA最大的作用是节省显存,同时由于LLM推理的一大瓶颈就是memory bound,大模型推理的性能受限于显存带宽。而GPU算力增长是快于显存以及显存带宽的。KV Cache的减少不单可以节省显存,还可以节省需要加载显存所需要的IO时间。我们可以来看一下KV Cache占用的显存占用量,下图来自PageaAttention论文:
KV Cache占用计算方式
KV Cache显存占用的计算方式如下:
1 token KV Cache = 2[K,V] x hidden_size x layers x 2[bytes per FP16] = 4 x H x N bytes
比如对于LLaMA 13B fp16模型,1个token所需要的KV Cache为:4 x 5120 x 40 = 819200 bytes,即 800KB。那么对于L=seq_len为2048 tokens的请求,需要的KV Cache数量为: 4 x 2048 x 5120 x 40 = 2048 x 800KB = 1.6GB。对于长度为L的请求,需要的KV Cache数量为:
KV Cache = 4 x L x H x N bytes # MHA
上述是在MHA下的KV Cache计算公式,最后,再考虑batch_size,那么公式为:
KV Cache = 4 x B x L x H x N bytes # MHA
如果是GQA,假设Q的组数为G,则GQA下需要的KV Cache为:
KV Cache = 4 x B x L x H x N / G bytes # GQA
最后我们用一个表格来直观感受一下,假设以下为某72B模型的配置:
B
L
H
N
G
KV Cache
带宽(Gb/s)
IO(ms)