01
60 秒回答模板
GQA 的核心是 Q 的 head 数多于 K/V 的 head 数,多个 query heads 共享同一组 K/V heads。假设 q_heads=32,kv_heads=8,则每 4 个 Q heads 使用同一个 K/V head。实现时输入经过 q_proj、k_proj、v_proj 后 reshape 成 q: [B, qh, Tq, hd],k/v: [B, kvh, Tk, hd]。如果后续 attention kernel 需要 head 数一致,可以把 k/v 在 head 维 repeat_interleave group_size 次,变成 [B, qh, Tk, hd];更高效的实现也可以在 kernel 内按 q_head // group_size 映射到 kv_head,避免真实复制。
之后流程和标准 attention 接近:scores = q @ k.transpose(-2,-1) / sqrt(head_dim),加 causal mask、padding mask 或 attention bias,再 softmax 和 dropout,最后 weights @ v 得到 [B, qh, Tq, hd],transpose/contiguous 后合并为 [B, Tq, qh*hd],接 o_proj。需要注意 q_heads 必须能被 kv_heads 整除,mask 形状要能 broadcast 到 [B, qh, Tq, Tk],KV cache 解码时新 token 只追加 kv_heads 份 K/V,因此 GQA 能比 MHA 降低 cache 显存和带宽,但表达能力通常强于 MQA。
考点 GQA 是多个 Q heads 共享较少的 K/V heads
难度 真实面经题
回答目标 让候选人能写出正确的 GQA PyTorch 骨架,并解释 head 分组、mask、KV cache 和 MHA/MQA 的工程取舍。
02
深入解析
01 形状约定
先定义 B、Tq、Tk、q_heads、kv_heads 和 head_dim。Q reshape 成多头,K/V 只保留较少 head。q_heads 必须是 kv_heads 的整数倍,否则无法均匀分组。
02 分组映射
group_size = q_heads / kv_heads。第 i 个 query head 使用第 floor(i / group_size) 个 KV head。这个映射可以通过 repeat_interleave 实现,也可以在高性能 kernel 中隐式索引。
03 attention 计算
把 K 转置后与 Q 做矩阵乘,除以 sqrt(head_dim),再叠加 causal mask、padding mask 或业务 attention bias。mask 维度要能 broadcast 到所有 Q heads。
04 softmax 输出
softmax 沿 key 序列维度归一化,得到每个 query token 对历史 token 的权重,再乘 V 得到每个 Q head 的上下文表示。最后合并 head 并过输出投影。
05 KV cache
自回归解码时,GQA 每层只缓存 kv_heads 份 K/V,比 MHA 的 q_heads 份更省显存和带宽。长上下文和高并发服务中,这个收益很明显。
06 工程细节
实现要检查 dtype、device、contiguous、mask 正负无穷、dropout 只在训练启用,以及 repeat 是否造成额外显存。生产中更常用 fused attention 或 paged KV cache kernel。
python
PyTorch GQA 核心实现
import math
import torch
def grouped_query_attention(q, k, v, o_proj, attn_mask=None):
# q: [B, q_heads, Tq, D]
# k/v: [B, kv_heads, Tk, D]
bsz, q_heads, tq, head_dim = q.shape
_, kv_heads, tk, _ = k.shape
assert q_heads % kv_heads == 0
group_size = q_heads // kv_heads
k = k.repeat_interleave(group_size, dim=1)
v = v.repeat_interleave(group_size, dim=1)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
if attn_mask is not None:
scores = scores + attn_mask
weights = torch.softmax(scores, dim=-1)
out = torch.matmul(weights, v)
out = out.transpose(1, 2).contiguous().view(bsz, tq, q_heads * head_dim)
return o_proj(out)
- 这份代码展示形状关系,生产实现通常避免 repeat_interleave 的真实复制,并使用支持 GQA 的 fused attention。
- attn_mask 应能 broadcast 到 [B, q_heads, Tq, Tk],并在 softmax 前加入。
03
易错点
- 把 Q head 也减少成 KV head 数,改变了模型 hidden 结构。
- 忘记 q_heads 必须能整除 kv_heads。
- repeat 的维度写错,把序列维或 batch 维复制了。
- softmax 维度错,用在 head 维而不是 key 序列维。
- mask 形状不能 broadcast,导致不同 batch 或 head 泄漏无效 token。
- 只写矩阵乘,不讲 KV cache 和 GQA 的推理收益。
04
面试官追问
GQA 和 MQA、MHA 的区别是什么?
MHA 每个 Q head 都有自己的 K/V head;MQA 所有 Q heads 共享一组 K/V;GQA 介于两者之间,多个 Q heads 一组共享 K/V,通常在质量和推理效率之间折中。
repeat K/V 会不会浪费显存?
教学实现会真实复制,易懂但浪费。高性能实现会在 attention kernel 里用 head 映射读取 kv_heads,或使用支持 GQA 的 fused attention,避免物理复制。
KV cache 中为什么 GQA 更省?
因为每个 token 每层只需要存 kv_heads 份 K/V,而不是 q_heads 份。长上下文下 cache 大小与层数、序列长度、head_dim 和 kv_heads 成正比。
mask 加在哪里?
加在 softmax 前的 score 上。causal mask 屏蔽未来 token,padding mask 屏蔽无效 token,通常用一个很大的负数,让 softmax 后权重接近 0。