60 秒回答模板

假设输入 X 的 shape 是 (b, s, h),h 是 hidden size,多头数是 n,单头维度 d=h/n。MHA 的主要计算可以分四块。第一块是 Q、K、V 三个线性投影,每个大约是 b*s*h*h,三者合计 3bsh^2。第二块是每个头计算 QK^T,单头是 b*s*s*d,n 个头合计 b*s^2*h。第三块是 attention 权重乘 V,同样约 b*s^2*h。第四块是输出投影,约 b*s*h*h。因此总量量级可以写成 4bsh^2 + 2bs^2h,忽略常数后是 O(bsh^2 + bs^2h)。当序列长度 s 很长时,s^2h 项会成为瓶颈;当 hidden size 很大而序列不长时,线性投影的 sh^2 也很重。回答时要说明是否按 multiply-add 算 1 次还是 2 FLOPs,常数会变,但复杂度判断不变。

考点 QKV 投影
难度 真实面经题
回答目标 讲清复杂度估算和长序列瓶颈

深入解析

01

先约定 shape 和多头关系

输入是 b 个样本,每个样本 s 个 token,每个 token 隐藏维度 h。多头注意力把 h 拆成 n 个 head,每个 head 维度 d=h/n。估算 FLOPs 时通常先按矩阵乘法量级算,再说明是否把乘加记为 1 或 2。

02

QKV 投影是 3bsh^2

每个 token 要从 h 维映射到 Q、K、V 的 h 维表示。一个线性投影的矩阵乘法量级是 b*s*h*h,三个投影就是 3bsh^2。如果实现把 QKV 合成一个大矩阵,本质计算量仍然同阶。

03

QK 矩阵带来 bs^2h

每个 head 上,Q 的形状是 (b, s, d),K 转置后是 (b, d, s),得到 (b, s, s) 的注意力分数,计算量约 b*s*s*d。n 个 head 加起来就是 b*s^2*h。这个 s*s 矩阵是长序列注意力的核心来源。

04

权重乘 V 也是 bs^2h

softmax 后的注意力权重形状是 (b, n, s, s),V 是 (b, n, s, d),相乘得到每个 token 的上下文向量。计算量和 QK 类似,也是 b*s^2*h。softmax、scale 和 mask 本身是 O(b*n*s^2) 级别,通常不是 FLOPs 主导项,但会影响 IO 和 kernel 实现。

05

输出投影补上 bsh^2

多头结果 concat 回 h 维后,还要经过输出线性层,计算量约 b*s*h*h。把 QKV 和输出投影合起来是 4bsh^2,把注意力两次矩阵乘合起来是 2bs^2h。

06

长序列瓶颈来自二次项

总量级是 O(bsh^2 + bs^2h)。如果 s 远大或持续增长,bs^2h 会迅速超过线性投影,显存上还会遇到 attention score 和 KV cache 压力。面试里要把计算瓶颈和内存瓶颈都点出来。

易错点

  • 只写 O(s^2),没有带上 batch、hidden size 和线性投影项。
  • 把 head 数 n 直接乘到总复杂度上,忘记单头维度 d=h/n。
  • 漏掉 V 加权求和或输出投影,只算了 QK^T。
  • 没有说明 multiply-add 计数口径,导致常数和别人不一致时无法解释。
  • 把训练全序列注意力和推理单步 decode 混为一谈。
  • 只讲 FLOPs,不提长序列下 attention score、KV cache 和 IO 也会成为瓶颈。

面试官追问

head 数变多会不会改变总 FLOPs?

如果 h 固定、每头维度 d=h/n,总 FLOPs 量级基本不变;head 数主要影响并行方式、常数和实现效率。

为什么长序列注意力是 O(s^2)?

因为每个 token 都要和所有 token 计算相关性,注意力分数矩阵是 s*s,QK 和权重乘 V 都带有这个二次项。

FLOPs 估算里要不要算 softmax?

可以提到 softmax、scale、mask 是 O(b*n*s^2) 级别。和 bs^2h 或 bsh^2 相比,单头维度 d 较大时通常不是主导,但实现上会影响 IO。

推理 decode 阶段复杂度还一样吗?

单步 decode 的 attention-only 部分会和历史 s 个 KV 做注意力,约 O(bsh)。如果把新 token 的 QKV 投影和输出投影也算入完整 MHA,还要加 O(bh^2) 项;生成整段序列时 attention 部分仍会二次累积。