真实面经题目 · 原创解析

如何用 PyTorch 实现 Grouped Query Attention?当 Q heads 多于 KV heads 时,K/V heads 应如何分组、repeat 或 broadcast,并完成 attention score、mask、softmax 和输出投影?

这道题考察候选人是否理解 Grouped Query Attention 的 head 形状和实现细节。回答要讲清 Q heads 与 KV heads 的分组关系,如何 repeat K/V,如何计算 mask、softmax 和输出投影。

出现于:腾讯 · 算法

60 秒回答模板

GQA 的核心是 Q 的 head 数多于 K/V 的 head 数,多个 query heads 共享同一组 K/V heads。假设 q_heads=32,kv_heads=8,则每 4 个 Q heads 使用同一个 K/V head。实现时输入经过 q_proj、k_proj、v_proj 后 reshape 成 q: [B, qh, Tq, hd],k/v: [B, kvh, Tk, hd]。如果后续 attention kernel 需要 head 数一致,可以把 k/v 在 head 维 repeat_interleave group_size 次,变成 [B, qh, Tk, hd];更高效的实现也可以在 kernel 内按 q_head // group_size 映射到 kv_head,避免真实复制。 之后流程和标准 attention 接近:scores = q @ k.transpose(-2,-1) / sqrt(head_dim),加 causal mask、padding mask 或 attention bias,再 softmax 和 dropout,最后 weights @ v 得到 [B, qh, Tq, hd],transpose/contiguous 后合并为 [B, Tq, qh*hd],接 o_proj。需要注意 q_heads 必须能被 kv_heads 整除,mask 形状要能 broadcast 到 [B, qh, Tq, Tk],KV cache 解码时新 token 只追加 kv_heads 份 K/V,因此 GQA 能比 MHA 降低 cache 显存和带宽,但表达能力通常强于 MQA。

考点 GQA 是多个 Q heads 共享较少的 K/V heads
难度 真实面经题
回答目标 让候选人能写出正确的 GQA PyTorch 骨架,并解释 head 分组、mask、KV cache 和 MHA/MQA 的工程取舍。

深入解析

01

形状约定

先定义 B、Tq、Tk、q_heads、kv_heads 和 head_dim。Q reshape 成多头,K/V 只保留较少 head。q_heads 必须是 kv_heads 的整数倍,否则无法均匀分组。

02

分组映射

group_size = q_heads / kv_heads。第 i 个 query head 使用第 floor(i / group_size) 个 KV head。这个映射可以通过 repeat_interleave 实现,也可以在高性能 kernel 中隐式索引。

03

attention 计算

把 K 转置后与 Q 做矩阵乘,除以 sqrt(head_dim),再叠加 causal mask、padding mask 或业务 attention bias。mask 维度要能 broadcast 到所有 Q heads。

04

softmax 输出

softmax 沿 key 序列维度归一化,得到每个 query token 对历史 token 的权重,再乘 V 得到每个 Q head 的上下文表示。最后合并 head 并过输出投影。

05

KV cache

自回归解码时,GQA 每层只缓存 kv_heads 份 K/V,比 MHA 的 q_heads 份更省显存和带宽。长上下文和高并发服务中,这个收益很明显。

06

工程细节

实现要检查 dtype、device、contiguous、mask 正负无穷、dropout 只在训练启用,以及 repeat 是否造成额外显存。生产中更常用 fused attention 或 paged KV cache kernel。

python

PyTorch GQA 核心实现

import math
import torch

def grouped_query_attention(q, k, v, o_proj, attn_mask=None):
    # q: [B, q_heads, Tq, D]
    # k/v: [B, kv_heads, Tk, D]
    bsz, q_heads, tq, head_dim = q.shape
    _, kv_heads, tk, _ = k.shape
    assert q_heads % kv_heads == 0

    group_size = q_heads // kv_heads
    k = k.repeat_interleave(group_size, dim=1)
    v = v.repeat_interleave(group_size, dim=1)

    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
    if attn_mask is not None:
        scores = scores + attn_mask

    weights = torch.softmax(scores, dim=-1)
    out = torch.matmul(weights, v)
    out = out.transpose(1, 2).contiguous().view(bsz, tq, q_heads * head_dim)
    return o_proj(out)
  • 这份代码展示形状关系,生产实现通常避免 repeat_interleave 的真实复制,并使用支持 GQA 的 fused attention。
  • attn_mask 应能 broadcast 到 [B, q_heads, Tq, Tk],并在 softmax 前加入。

易错点

  • 把 Q head 也减少成 KV head 数,改变了模型 hidden 结构。
  • 忘记 q_heads 必须能整除 kv_heads。
  • repeat 的维度写错,把序列维或 batch 维复制了。
  • softmax 维度错,用在 head 维而不是 key 序列维。
  • mask 形状不能 broadcast,导致不同 batch 或 head 泄漏无效 token。
  • 只写矩阵乘,不讲 KV cache 和 GQA 的推理收益。

面试官追问

GQA 和 MQA、MHA 的区别是什么?

MHA 每个 Q head 都有自己的 K/V head;MQA 所有 Q heads 共享一组 K/V;GQA 介于两者之间,多个 Q heads 一组共享 K/V,通常在质量和推理效率之间折中。

repeat K/V 会不会浪费显存?

教学实现会真实复制,易懂但浪费。高性能实现会在 attention kernel 里用 head 映射读取 kv_heads,或使用支持 GQA 的 fused attention,避免物理复制。

KV cache 中为什么 GQA 更省?

因为每个 token 每层只需要存 kv_heads 份 K/V,而不是 q_heads 份。长上下文下 cache 大小与层数、序列长度、head_dim 和 kv_heads 成正比。

mask 加在哪里?

加在 softmax 前的 score 上。causal mask 屏蔽未来 token,padding mask 屏蔽无效 token,通常用一个很大的负数,让 softmax 后权重接近 0。