真实面经题目 · 原创解析
给定输入 shape 为 (b, s, h),如何估算 Multi-Head Attention 的计算量?
这题考给定输入 shape 为 (b, s, h) 时如何估算 MHA 计算量,回答重点是 QKV 投影、注意力矩阵、加权求和和输出投影,其中长序列瓶颈来自 O(b s^2 h)。
真实面经题目 · 原创解析
这题考给定输入 shape 为 (b, s, h) 时如何估算 MHA 计算量,回答重点是 QKV 投影、注意力矩阵、加权求和和输出投影,其中长序列瓶颈来自 O(b s^2 h)。
假设输入 X 的 shape 是 (b, s, h),h 是 hidden size,多头数是 n,单头维度 d=h/n。MHA 的主要计算可以分四块。第一块是 Q、K、V 三个线性投影,每个大约是 b*s*h*h,三者合计 3bsh^2。第二块是每个头计算 QK^T,单头是 b*s*s*d,n 个头合计 b*s^2*h。第三块是 attention 权重乘 V,同样约 b*s^2*h。第四块是输出投影,约 b*s*h*h。因此总量量级可以写成 4bsh^2 + 2bs^2h,忽略常数后是 O(bsh^2 + bs^2h)。当序列长度 s 很长时,s^2h 项会成为瓶颈;当 hidden size 很大而序列不长时,线性投影的 sh^2 也很重。回答时要说明是否按 multiply-add 算 1 次还是 2 FLOPs,常数会变,但复杂度判断不变。
输入是 b 个样本,每个样本 s 个 token,每个 token 隐藏维度 h。多头注意力把 h 拆成 n 个 head,每个 head 维度 d=h/n。估算 FLOPs 时通常先按矩阵乘法量级算,再说明是否把乘加记为 1 或 2。
每个 token 要从 h 维映射到 Q、K、V 的 h 维表示。一个线性投影的矩阵乘法量级是 b*s*h*h,三个投影就是 3bsh^2。如果实现把 QKV 合成一个大矩阵,本质计算量仍然同阶。
每个 head 上,Q 的形状是 (b, s, d),K 转置后是 (b, d, s),得到 (b, s, s) 的注意力分数,计算量约 b*s*s*d。n 个 head 加起来就是 b*s^2*h。这个 s*s 矩阵是长序列注意力的核心来源。
softmax 后的注意力权重形状是 (b, n, s, s),V 是 (b, n, s, d),相乘得到每个 token 的上下文向量。计算量和 QK 类似,也是 b*s^2*h。softmax、scale 和 mask 本身是 O(b*n*s^2) 级别,通常不是 FLOPs 主导项,但会影响 IO 和 kernel 实现。
多头结果 concat 回 h 维后,还要经过输出线性层,计算量约 b*s*h*h。把 QKV 和输出投影合起来是 4bsh^2,把注意力两次矩阵乘合起来是 2bs^2h。
总量级是 O(bsh^2 + bs^2h)。如果 s 远大或持续增长,bs^2h 会迅速超过线性投影,显存上还会遇到 attention score 和 KV cache 压力。面试里要把计算瓶颈和内存瓶颈都点出来。
如果 h 固定、每头维度 d=h/n,总 FLOPs 量级基本不变;head 数主要影响并行方式、常数和实现效率。
因为每个 token 都要和所有 token 计算相关性,注意力分数矩阵是 s*s,QK 和权重乘 V 都带有这个二次项。
可以提到 softmax、scale、mask 是 O(b*n*s^2) 级别。和 bs^2h 或 bsh^2 相比,单头维度 d 较大时通常不是主导,但实现上会影响 IO。
单步 decode 的 attention-only 部分会和历史 s 个 KV 做注意力,约 O(bsh)。如果把新 token 的 QKV 投影和输出投影也算入完整 MHA,还要加 O(bh^2) 项;生成整段序列时 attention 部分仍会二次累积。