真实面经题目 · 原创解析
如何手写 Multi-Head Self-Attention,Q/K/V 投影、分头、mask 和输出拼接如何实现?
这题考 Transformer 注意力层的可实现细节。好的回答不能只背公式,而要讲清输入输出形状、Q/K/V 一次投影或三次投影、head 维度拆分、scaled dot-product attention、padding/causal mask 广播、softmax/dropout、head 合并、输出投影以及常见数值和 shape bug。
真实面经题目 · 原创解析
这题考 Transformer 注意力层的可实现细节。好的回答不能只背公式,而要讲清输入输出形状、Q/K/V 一次投影或三次投影、head 维度拆分、scaled dot-product attention、padding/causal mask 广播、softmax/dropout、head 合并、输出投影以及常见数值和 shape bug。
我会按张量形状来写。假设输入 x 形状是 [B, T, C],head 数是 H,每个 head 的维度 D=C/H,要求 C 能被 H 整除。第一步用线性层得到 Q、K、V,可以分别用三个 Linear,也可以一次 Linear 到 3C 后再 split。第二步把 Q/K/V 从 [B, T, C] reshape 成 [B, T, H, D],再 transpose 成 [B, H, T, D],这样每个 head 都可以独立算 attention。第三步计算 scores = Q @ K.transpose(-2, -1) / sqrt(D),形状是 [B, H, T, T]。如果是 Decoder 自回归,要加 causal mask,禁止当前位置看未来 token;如果 batch 里有 padding,还要加 padding mask,禁止关注 pad 位置。mask 通常在 softmax 前把非法位置加一个很大的负数。第四步对最后一维做 softmax 得到 attention weights,再乘 V,得到 [B, H, T, D]。第五步把它 transpose 回 [B, T, H, D],contiguous 后 view 成 [B, T, C],最后过一个输出投影 W_O,得到和输入同形状的输出,方便残差连接。实现时要注意 mask 的维度广播、fp16 下负无穷的数值稳定、dropout 位置、训练和推理模式,以及输入输出 shape 的单元测试。 面试里我会直接给出一个最小 PyTorch 风格实现:用一个 qkv 线性层生成三份张量,reshape/transpose 成 [B,H,T,D],在 scores 上叠加 causal mask 和 key padding mask,再 softmax、乘 V、合并 head 和 out_proj。代码里我会明确 padding mask 从 [B,T] 扩到 [B,1,1,T],causal mask 从 [T,T] 扩到 [1,1,T,T],这样面试官能看到我不只是会背公式,而是真的能处理广播和边界。
常见 batch-first 实现里,输入 x 是 [B, T, C],输出也保持 [B, T, C]。H 是 head 数,D 是每头维度,满足 C = H * D。这个约束很重要,因为后续 split heads 和 concat heads 都依赖它。若是 cross-attention,Q 来自目标序列,K/V 来自源序列;本题是 self-attention,所以 Q/K/V 都来自同一个 x。
概念上有三组权重 W_Q、W_K、W_V,分别把 x 投成 Q、K、V。工程上常用一个 Linear(C, 3C) 一次得到 qkv,再沿最后一维切成三份,减少 kernel 调用和代码复杂度。无论三次投影还是一次投影,最终 Q/K/V 的逻辑形状都应是 [B, T, C]。
把 [B, T, C] 改成 [B, T, H, D],再转成 [B, H, T, D],这样矩阵乘法可以把 batch 和 head 都视作并行维度。这里容易出错的是 transpose 后内存不连续,后面如果用 view 合并,需要先 contiguous,或者使用 reshape 但仍要理解实际内存布局。
每个 head 计算 QK^T,得到 [B, H, T, T] 的注意力打分,再除以 sqrt(D) 控制 logits 尺度。然后对 key 序列维做 softmax,让每个 query 位置得到一组对所有 key 位置的权重。最后权重乘 V,得到每个位置聚合后的 value 表示。
causal mask 用来阻止自回归模型看未来,通常形状可广播到 [1, 1, T, T]。padding mask 用来阻止关注填充 token,通常从 [B, T] 扩展到 [B, 1, 1, T]。二者可以合并后在 scores 上对非法位置加大负数。不能在 softmax 后简单把权重置零而不重新归一化,否则概率分布会错。
softmax 应该沿最后一维,也就是 key 长度维计算。训练时可对 attention weights 做 dropout,也可在输出投影后做 dropout。fp16/bf16 下如果直接使用真正的负无穷,有时会产生 NaN,实际实现常用 dtype 可表达范围内的大负数,并确保整行都被 mask 的情况被上游避免或特殊处理。
attention 输出形状是 [B, H, T, D],先 transpose 回 [B, T, H, D],再合并成 [B, T, C]。concat 本身只是把多个 head 的结果并排放在通道维,真正跨 head 的线性融合由 W_O 完成。输出投影还能保证维度回到 C,便于残差连接和后续 FFN。
手写实现至少要测输出 shape 是否等于输入 shape,mask 后未来 token 是否无法影响当前输出,padding token 是否不被关注,H=1 与多头分支是否都能跑通,训练模式 dropout 是否生效,反向传播是否有梯度。复杂度上,标准全量 self-attention 的时间和注意力矩阵显存都是 O(B * H * T^2)。
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadSelfAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
super().__init__()
assert d_model % num_heads == 0
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
self.out_proj = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, key_padding_mask=None, causal=False):
# x: [B, T, C]
# key_padding_mask: [B, T], True means this key position is padding.
bsz, seq_len, d_model = x.shape
qkv = self.qkv_proj(x) # [B, T, 3C]
q, k, v = qkv.chunk(3, dim=-1)
def split_heads(tensor):
# [B, T, C] -> [B, H, T, D]
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
q = split_heads(q)
k = split_heads(k)
v = split_heads(v)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
# scores: [B, H, T, T]
if causal:
causal_mask = torch.triu(
torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool),
diagonal=1,
)
scores = scores.masked_fill(causal_mask[None, None, :, :], torch.finfo(scores.dtype).min)
if key_padding_mask is not None:
# [B, T] -> [B, 1, 1, T], broadcast over heads and query positions.
scores = scores.masked_fill(key_padding_mask[:, None, None, :], torch.finfo(scores.dtype).min)
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
context = torch.matmul(attn, v) # [B, H, T, D]
context = context.transpose(1, 2).contiguous().view(bsz, seq_len, d_model)
return self.out_proj(context) # [B, T, C]
D 越大,随机向量点积的方差越大,softmax 输入容易变得很大,从而让分布过尖、梯度变小。除以 sqrt(D) 可以把打分尺度拉回更稳定的范围。
causal mask 是序列位置约束,防止当前位置看未来;padding mask 是样本长度约束,防止真实 token 关注 pad token。Decoder 训练经常两者都需要。
transpose 改变的是张量 stride,内存不一定连续。后续如果用 view 合并维度,可能报错或得到错误布局,因此通常先 contiguous,或者用能处理非连续张量的 reshape 并理解其代价。
在 C 固定且 D=C/H 的常见设置下,QKV 和输出投影总参数量仍大致是 C 到 C 的几组线性变换,和 head 数不直接线性增加。head 数主要影响每头维度、中间 attention 形状和并行实现效率。
可以构造两个输入,它们在当前位置以前完全相同、未来 token 不同;如果当前及以前位置的输出不变,说明未来信息没有泄漏。也可以检查 attention weights 的上三角是否为零。