真实面经题目 · 原创解析
如何用代码实现 Multi-Head Cross-Attention,Q/K/V 的输入维度如何对齐?
这题考手写 Multi-Head Cross-Attention 的维度理解和实现顺序,回答重点是 Q 来自目标序列,K/V 来自条件序列,以及多头拆分、mask 和输出合并。
真实面经题目 · 原创解析
这题考手写 Multi-Head Cross-Attention 的维度理解和实现顺序,回答重点是 Q 来自目标序列,K/V 来自条件序列,以及多头拆分、mask 和输出合并。
Cross-Attention 和 Self-Attention 的主要差别是 Q/K/V 的来源。Self-Attention 里 Q、K、V 都来自同一个序列;Cross-Attention 里 Q 来自当前要更新的目标序列,比如 decoder token 或文本 token,K 和 V 来自外部条件序列,比如 encoder 输出、图像 token 或检索证据。实现时输入可以设成 x: [B, Tq, Dq],context: [B, Tk, Dc]。先分别线性投影得到 q、k、v,并统一到 num_heads * head_dim;再 reshape 成 [B, H, T, Dh];然后计算 scores = q @ k^T / sqrt(Dh),形状是 [B, H, Tq, Tk];如果有 padding mask 或因果约束就在 scores 上加负无穷;softmax 后乘 v 得到 [B, H, Tq, Dh];最后合并多头并过输出投影得到 [B, Tq, D]. 面试写代码时要把维度写在注释里,重点说明 Q 的长度决定输出长度,K/V 的长度决定可被关注的条件范围。
Cross-Attention 的目标是让一个序列读取另一个序列的信息。目标序列产生 Q,条件序列产生 K 和 V。输出长度等于 Q 的长度,因为被更新的是目标序列;K/V 的长度只是提供可访问的记忆槽。
即使目标序列和条件序列原始维度不同,也可以通过 w_q、w_k、w_v 投影到相同的多头注意力维度。通常 embed_dim 要能被 num_heads 整除,head_dim = embed_dim / num_heads。
投影后把 q reshape 成 [B, H, Tq, Dh],k 和 v reshape 成 [B, H, Tk, Dh]。这样 q @ k.transpose(-2, -1) 的结果就是每个 head 内 Tq 对 Tk 的注意力分数。
如果 context 里有 padding token,要在 softmax 前把对应 key 位置 mask 掉;如果是 decoder 场景,还可能有因果 mask。mask 的形状要能 broadcast 到 [B, H, Tq, Tk],否则很容易维度错或泄漏无效 token。
softmax 后的权重乘 v,得到每个 query 位置聚合出的条件信息。再把多头从 [B, H, Tq, Dh] 转回 [B, Tq, H * Dh],经过输出投影得到 [B, Tq, embed_dim]。如果要和原始 x 做 residual,embed_dim 必须等于 q_dim,或额外投影回 q_dim。
代码不必写完整 Transformer block,但要包含线性投影、多头拆分、scaled dot-product、mask、dropout 可选、合并输出。最好在注释中标出形状,证明自己不是只背 API。
import math
import torch
import torch.nn as nn
class CrossAttention(nn.Module):
def __init__(self, q_dim, ctx_dim, embed_dim, num_heads):
super().__init__()
assert embed_dim % num_heads == 0
self.h = num_heads
self.dh = embed_dim // num_heads
self.wq = nn.Linear(q_dim, embed_dim)
self.wk = nn.Linear(ctx_dim, embed_dim)
self.wv = nn.Linear(ctx_dim, embed_dim)
self.wo = nn.Linear(embed_dim, embed_dim)
def split(self, x):
b, t, _ = x.shape
return x.view(b, t, self.h, self.dh).transpose(1, 2)
def forward(self, x, context, key_mask=None):
# x: [B, Tq, Dq], context: [B, Tk, Dc]
q = self.split(self.wq(x)) # [B, H, Tq, Dh]
k = self.split(self.wk(context)) # [B, H, Tk, Dh]
v = self.split(self.wv(context)) # [B, H, Tk, Dh]
scores = q @ k.transpose(-2, -1) / math.sqrt(self.dh)
if key_mask is not None: # key_mask: [B, Tk], True means keep
scores = scores.masked_fill(~key_mask[:, None, None, :], -1e9)
attn = scores.softmax(dim=-1)
out = attn @ v # [B, H, Tq, Dh]
out = out.transpose(1, 2).contiguous()
out = out.view(x.size(0), x.size(1), self.h * self.dh)
return self.wo(out) 因为每个 query 位置从 context 中读取信息并更新自己,所以输出对应 query 序列,长度是 Tq。Tk 只是 key/value 的候选记忆数量。
用不同的线性层分别投影到共同的 attention embed_dim,再拆成相同 head_dim 的多头即可。
可以让文本 token 作为 Q,图像 token 作为 K/V,让文本读取视觉信息;也可以让 learnable query 读取视觉编码器输出,形成压缩后的视觉摘要。
常见是把 Tq 和 Tk 混了、k 没有 transpose、mask 维度不能 broadcast、softmax 维度写错,或合并多头前没有 contiguous。