真实面经题目 · 原创解析
多头注意力相比单头注意力有什么优势,各个 head 的输出如何拼接并通过输出投影融合?
这题考多头注意力的表示机制和实现细节,重点是说明多个 head 在不同子空间独立做注意力,输出先按特征维拼接,再由输出投影学习跨 head 融合。
真实面经题目 · 原创解析
这题考多头注意力的表示机制和实现细节,重点是说明多个 head 在不同子空间独立做注意力,输出先按特征维拼接,再由输出投影学习跨 head 融合。
多头注意力可以理解为把同一个 token 表示投影到多个较低维的查询、键、值子空间里,让每个 head 独立计算一次 scaled dot-product attention。相比单头注意力,它的优势不是简单把参数变多,而是让模型能并行关注不同关系:有的 head 可能更偏局部上下文,有的偏长距离依赖,有的偏句法、位置或跨模态对齐;即使这些语义不是人工指定的,多头结构也提供了多个互不完全相同的表示通道,降低单一注意力分布的表达瓶颈。实现上,输入 X 先分别乘以每个 head 的 W_Q、W_K、W_V,得到 Q_h、K_h、V_h;每个 head 计算 softmax(Q_h K_h^T / sqrt(d_h)) V_h,输出形状通常是 batch、序列长度、每头维度 d_h。所有 head 算完后,不是在注意力之前融合,而是把各个 head 的输出沿最后的特征维 concat,得到 batch、序列长度、h 乘 d_h 的张量,通常这个维度等于 d_model。随后再乘一个输出投影 W_O,把拼接后的多路特征映射回 d_model。这个 W_O 很关键,它不是形式上的 reshape,而是可学习地混合不同 head 的信息,让后续残差连接和 FFN 看到统一维度的表示。需要补充的是,多头也有代价:注意力矩阵的显存和计算仍随序列长度平方增长,head 太多会让每个 head 维度过小,可能出现 head 冗余或退化。因此回答时最好同时讲清楚机制、融合位置、张量形状和 tradeoff。
单头注意力对每个 query 只产生一套注意力分布,所有信息都通过同一个相似度空间和同一组 value 加权结果表达。如果一个 token 同时需要关注局部搭配、远距离依赖、位置结构和语义对齐,单一分布容易把不同关系压在一起。多头结构通过多个独立投影矩阵,把注意力拆到多个子空间中并行建模,提升表示多样性。
输入表示 X 会被投影成多组 Q、K、V。第 h 个 head 用自己的 Q_h、K_h、V_h 计算 softmax(Q_h K_h^T / sqrt(d_h)) V_h,其中 d_h 通常是 d_model 除以 head 数。这样做的好处是总输出维度不变时,每个 head 的矩阵计算更小,但多个 head 可以学习不同的打分空间、对齐模式和 value 组合方式。
各个 head 在注意力阶段是并行独立的,融合点通常在每个 head 产生输出之后。若单个 head 输出形状是 [batch, seq_len, d_h],h 个 head 会沿最后一维拼接成 [batch, seq_len, h * d_h]。工程实现里常通过 reshape 和 transpose 把 head 维度展开回通道维,这一步本身只是张量重排,不负责学习融合。
拼接后的向量会乘以 W_O,映射回模型主干维度 d_model。W_O 的作用是把不同 head 的结果做可学习线性组合:它可以增强某些 head、抑制冗余 head,也可以把多个 head 捕获的关系组合成新的特征。残差连接要求注意力模块输出维度和输入维度一致,因此输出投影同时承担维度对齐和信息融合。
多头注意力通常在表达能力和稳定性上优于单头,但不是 head 越多越好。在固定 d_model 下,head 数增加会让每个 head 的 d_h 变小,单个 head 的表示容量下降;注意力权重和中间激活也会增加显存压力。长序列场景下,主要瓶颈仍然是注意力矩阵的 O(seq_len^2) 计算和存储,而不是输出投影本身。
实际分析时可以看训练损失、验证指标、注意力熵、head 间相似度、剪枝后性能变化、延迟和显存占用。如果多个 head 的注意力模式高度相似,或者剪掉某些 head 几乎不影响效果,就说明存在冗余。失败边界包括 head collapse、过多小维度 head 导致表达不足,以及在小数据或短序列任务上复杂度收益不明显。
不一定。在常见设置里,h 个 head 的总投影维度仍约等于 d_model,Q、K、V 和 W_O 的总参数量与一个同维度注意力层同阶。差异更多来自并行 head 的中间激活、注意力权重和实现开销。
点积维度越大,QK 的方差越大,softmax 容易过早饱和,梯度变小。除以 sqrt(d_h) 是为了稳定打分尺度,让注意力分布和训练梯度更可控。
不一定。多头提供结构上的多样性机会,但具体 head 是否对应句法、位置或长程依赖,需要通过可视化、剪枝、相似度和任务指标验证,不能强行赋予每个 head 固定含义。
通常在每个 multi-head attention 子层内部就完成拼接和输出投影,然后接残差、归一化和后续 FFN。不是等到所有编码层结束后再统一融合。
要结合 d_model、序列长度、数据规模、硬件并行效率和验证指标。常见原则是让每个 head 保留足够 d_h,同时用消融实验观察质量、延迟、显存和 head 冗余。