真实面经题目 · 原创解析

如何手写 Multi-Head Self-Attention,Q/K/V 投影、分头、mask 和输出拼接如何实现?

这题考 Transformer 注意力层的可实现细节。好的回答不能只背公式,而要讲清输入输出形状、Q/K/V 一次投影或三次投影、head 维度拆分、scaled dot-product attention、padding/causal mask 广播、softmax/dropout、head 合并、输出投影以及常见数值和 shape bug。

60 秒回答模板

我会按张量形状来写。假设输入 x 形状是 [B, T, C],head 数是 H,每个 head 的维度 D=C/H,要求 C 能被 H 整除。第一步用线性层得到 Q、K、V,可以分别用三个 Linear,也可以一次 Linear 到 3C 后再 split。第二步把 Q/K/V 从 [B, T, C] reshape 成 [B, T, H, D],再 transpose 成 [B, H, T, D],这样每个 head 都可以独立算 attention。第三步计算 scores = Q @ K.transpose(-2, -1) / sqrt(D),形状是 [B, H, T, T]。如果是 Decoder 自回归,要加 causal mask,禁止当前位置看未来 token;如果 batch 里有 padding,还要加 padding mask,禁止关注 pad 位置。mask 通常在 softmax 前把非法位置加一个很大的负数。第四步对最后一维做 softmax 得到 attention weights,再乘 V,得到 [B, H, T, D]。第五步把它 transpose 回 [B, T, H, D],contiguous 后 view 成 [B, T, C],最后过一个输出投影 W_O,得到和输入同形状的输出,方便残差连接。实现时要注意 mask 的维度广播、fp16 下负无穷的数值稳定、dropout 位置、训练和推理模式,以及输入输出 shape 的单元测试。 面试里我会直接给出一个最小 PyTorch 风格实现:用一个 qkv 线性层生成三份张量,reshape/transpose 成 [B,H,T,D],在 scores 上叠加 causal mask 和 key padding mask,再 softmax、乘 V、合并 head 和 out_proj。代码里我会明确 padding mask 从 [B,T] 扩到 [B,1,1,T],causal mask 从 [T,T] 扩到 [1,1,T,T],这样面试官能看到我不只是会背公式,而是真的能处理广播和边界。

考点 形状主线
难度 真实面经题
回答目标 让候选人能按张量形状手写一个可用的 Multi-Head Self-Attention,并解释每一步的数学意义、mask 处理、数值稳定和常见实现陷阱。

深入解析

01

先固定接口和形状

常见 batch-first 实现里,输入 x 是 [B, T, C],输出也保持 [B, T, C]。H 是 head 数,D 是每头维度,满足 C = H * D。这个约束很重要,因为后续 split heads 和 concat heads 都依赖它。若是 cross-attention,Q 来自目标序列,K/V 来自源序列;本题是 self-attention,所以 Q/K/V 都来自同一个 x。

02

Q/K/V 投影可以合并实现

概念上有三组权重 W_Q、W_K、W_V,分别把 x 投成 Q、K、V。工程上常用一个 Linear(C, 3C) 一次得到 qkv,再沿最后一维切成三份,减少 kernel 调用和代码复杂度。无论三次投影还是一次投影,最终 Q/K/V 的逻辑形状都应是 [B, T, C]。

03

分头是 reshape 加 transpose

把 [B, T, C] 改成 [B, T, H, D],再转成 [B, H, T, D],这样矩阵乘法可以把 batch 和 head 都视作并行维度。这里容易出错的是 transpose 后内存不连续,后面如果用 view 合并,需要先 contiguous,或者使用 reshape 但仍要理解实际内存布局。

04

缩放点积注意力是核心计算

每个 head 计算 QK^T,得到 [B, H, T, T] 的注意力打分,再除以 sqrt(D) 控制 logits 尺度。然后对 key 序列维做 softmax,让每个 query 位置得到一组对所有 key 位置的权重。最后权重乘 V,得到每个位置聚合后的 value 表示。

05

mask 必须在 softmax 前处理

causal mask 用来阻止自回归模型看未来,通常形状可广播到 [1, 1, T, T]。padding mask 用来阻止关注填充 token,通常从 [B, T] 扩展到 [B, 1, 1, T]。二者可以合并后在 scores 上对非法位置加大负数。不能在 softmax 后简单把权重置零而不重新归一化,否则概率分布会错。

06

softmax、dropout 和数值稳定

softmax 应该沿最后一维,也就是 key 长度维计算。训练时可对 attention weights 做 dropout,也可在输出投影后做 dropout。fp16/bf16 下如果直接使用真正的负无穷,有时会产生 NaN,实际实现常用 dtype 可表达范围内的大负数,并确保整行都被 mask 的情况被上游避免或特殊处理。

07

合并 head 后还要输出投影

attention 输出形状是 [B, H, T, D],先 transpose 回 [B, T, H, D],再合并成 [B, T, C]。concat 本身只是把多个 head 的结果并排放在通道维,真正跨 head 的线性融合由 W_O 完成。输出投影还能保证维度回到 C,便于残差连接和后续 FFN。

08

验证要覆盖 shape、mask 和梯度

手写实现至少要测输出 shape 是否等于输入 shape,mask 后未来 token 是否无法影响当前输出,padding token 是否不被关注,H=1 与多头分支是否都能跑通,训练模式 dropout 是否生效,反向传播是否有梯度。复杂度上,标准全量 self-attention 的时间和注意力矩阵显存都是 O(B * H * T^2)。

python

PyTorch 风格 Multi-Head Self-Attention 最小实现

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None, causal=False):
        # x: [B, T, C]
        # key_padding_mask: [B, T], True means this key position is padding.
        bsz, seq_len, d_model = x.shape

        qkv = self.qkv_proj(x)  # [B, T, 3C]
        q, k, v = qkv.chunk(3, dim=-1)

        def split_heads(tensor):
            # [B, T, C] -> [B, H, T, D]
            return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(q)
        k = split_heads(k)
        v = split_heads(v)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # scores: [B, H, T, T]

        if causal:
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool),
                diagonal=1,
            )
            scores = scores.masked_fill(causal_mask[None, None, :, :], torch.finfo(scores.dtype).min)

        if key_padding_mask is not None:
            # [B, T] -> [B, 1, 1, T], broadcast over heads and query positions.
            scores = scores.masked_fill(key_padding_mask[:, None, None, :], torch.finfo(scores.dtype).min)

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        context = torch.matmul(attn, v)  # [B, H, T, D]
        context = context.transpose(1, 2).contiguous().view(bsz, seq_len, d_model)
        return self.out_proj(context)  # [B, T, C]

易错点

  • 只写 attention 公式,不说明 [B,T,C] 到 [B,H,T,D] 的 reshape 和 transpose。
  • 把 head 维度和序列维度混淆,导致 softmax 沿错维度计算。
  • 忘记除以 sqrt(D),训练时注意力 logits 过大、softmax 过尖。
  • 在 softmax 后才处理 mask,或者 padding mask 广播维度写错。
  • 合并 head 时忘记 contiguous 或错误使用 view,造成张量布局问题。
  • concat 后漏掉输出投影 W_O,导致跨 head 信息没有可学习融合或维度不匹配。
  • 把 self-attention 和 cross-attention 混在一起,Q/K/V 来源说不清。
  • 没有测试 causal 信息泄漏,只看代码能跑通就认为实现正确。

面试官追问

为什么 QK^T 要除以 sqrt(D)?

D 越大,随机向量点积的方差越大,softmax 输入容易变得很大,从而让分布过尖、梯度变小。除以 sqrt(D) 可以把打分尺度拉回更稳定的范围。

causal mask 和 padding mask 有什么区别?

causal mask 是序列位置约束,防止当前位置看未来;padding mask 是样本长度约束,防止真实 token 关注 pad token。Decoder 训练经常两者都需要。

为什么 transpose 后常常要 contiguous?

transpose 改变的是张量 stride,内存不一定连续。后续如果用 view 合并维度,可能报错或得到错误布局,因此通常先 contiguous,或者用能处理非连续张量的 reshape 并理解其代价。

多头注意力的参数量会随着 head 数线性增加吗?

在 C 固定且 D=C/H 的常见设置下,QKV 和输出投影总参数量仍大致是 C 到 C 的几组线性变换,和 head 数不直接线性增加。head 数主要影响每头维度、中间 attention 形状和并行实现效率。

如何确认 causal mask 真的生效?

可以构造两个输入,它们在当前位置以前完全相同、未来 token 不同;如果当前及以前位置的输出不变,说明未来信息没有泄漏。也可以检查 attention weights 的上三角是否为零。