60 秒回答模板

Cross-Attention 和 Self-Attention 的主要差别是 Q/K/V 的来源。Self-Attention 里 Q、K、V 都来自同一个序列;Cross-Attention 里 Q 来自当前要更新的目标序列,比如 decoder token 或文本 token,K 和 V 来自外部条件序列,比如 encoder 输出、图像 token 或检索证据。实现时输入可以设成 x: [B, Tq, Dq],context: [B, Tk, Dc]。先分别线性投影得到 q、k、v,并统一到 num_heads * head_dim;再 reshape 成 [B, H, T, Dh];然后计算 scores = q @ k^T / sqrt(Dh),形状是 [B, H, Tq, Tk];如果有 padding mask 或因果约束就在 scores 上加负无穷;softmax 后乘 v 得到 [B, H, Tq, Dh];最后合并多头并过输出投影得到 [B, Tq, D]. 面试写代码时要把维度写在注释里,重点说明 Q 的长度决定输出长度,K/V 的长度决定可被关注的条件范围。

考点 Q 决定输出
难度 真实面经题
回答目标 讲清机制、训练与评估取舍

深入解析

01

先区分 Q/K/V 来源

Cross-Attention 的目标是让一个序列读取另一个序列的信息。目标序列产生 Q,条件序列产生 K 和 V。输出长度等于 Q 的长度,因为被更新的是目标序列;K/V 的长度只是提供可访问的记忆槽。

02

投影到统一多头空间

即使目标序列和条件序列原始维度不同,也可以通过 w_q、w_k、w_v 投影到相同的多头注意力维度。通常 embed_dim 要能被 num_heads 整除,head_dim = embed_dim / num_heads。

03

多头 reshape 决定矩阵乘法

投影后把 q reshape 成 [B, H, Tq, Dh],k 和 v reshape 成 [B, H, Tk, Dh]。这样 q @ k.transpose(-2, -1) 的结果就是每个 head 内 Tq 对 Tk 的注意力分数。

04

mask 要作用在 scores 上

如果 context 里有 padding token,要在 softmax 前把对应 key 位置 mask 掉;如果是 decoder 场景,还可能有因果 mask。mask 的形状要能 broadcast 到 [B, H, Tq, Tk],否则很容易维度错或泄漏无效 token。

05

输出合并回目标序列

softmax 后的权重乘 v,得到每个 query 位置聚合出的条件信息。再把多头从 [B, H, Tq, Dh] 转回 [B, Tq, H * Dh],经过输出投影得到 [B, Tq, embed_dim]。如果要和原始 x 做 residual,embed_dim 必须等于 q_dim,或额外投影回 q_dim。

06

面试实现要短而清楚

代码不必写完整 Transformer block,但要包含线性投影、多头拆分、scaled dot-product、mask、dropout 可选、合并输出。最好在注释中标出形状,证明自己不是只背 API。

python

最小 Multi-Head Cross-Attention

import math
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, q_dim, ctx_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.h = num_heads
        self.dh = embed_dim // num_heads
        self.wq = nn.Linear(q_dim, embed_dim)
        self.wk = nn.Linear(ctx_dim, embed_dim)
        self.wv = nn.Linear(ctx_dim, embed_dim)
        self.wo = nn.Linear(embed_dim, embed_dim)

    def split(self, x):
        b, t, _ = x.shape
        return x.view(b, t, self.h, self.dh).transpose(1, 2)

    def forward(self, x, context, key_mask=None):
        # x: [B, Tq, Dq], context: [B, Tk, Dc]
        q = self.split(self.wq(x))          # [B, H, Tq, Dh]
        k = self.split(self.wk(context))    # [B, H, Tk, Dh]
        v = self.split(self.wv(context))    # [B, H, Tk, Dh]

        scores = q @ k.transpose(-2, -1) / math.sqrt(self.dh)
        if key_mask is not None:            # key_mask: [B, Tk], True means keep
            scores = scores.masked_fill(~key_mask[:, None, None, :], -1e9)

        attn = scores.softmax(dim=-1)
        out = attn @ v                      # [B, H, Tq, Dh]
        out = out.transpose(1, 2).contiguous()
        out = out.view(x.size(0), x.size(1), self.h * self.dh)
        return self.wo(out)
  • Q 的长度 Tq 决定输出长度;本实现返回 [B, Tq, embed_dim],做 residual 时需让 embed_dim == q_dim 或再投影回 q_dim。
  • mask 应在 softmax 前作用到 scores 的 key 维度。

易错点

  • 把 Cross-Attention 写成 Self-Attention,Q/K/V 都来自同一个输入。
  • 认为输出长度等于 K/V 长度,没理解被更新的是 query 序列。
  • q @ k 时忘记转置 key 的最后两个维度。
  • softmax 维度写错,没有在 key 维度上归一化。
  • mask 加在 softmax 后,或者 mask 形状无法 broadcast 到 score 张量。
  • 多头合并时维度顺序错,导致输出 token 和 head 维混在一起。

面试官追问

Cross-Attention 的输出长度为什么不是 Tk?

因为每个 query 位置从 context 中读取信息并更新自己,所以输出对应 query 序列,长度是 Tq。Tk 只是 key/value 的候选记忆数量。

如果 Q 和 K/V 的原始维度不同怎么办?

用不同的线性层分别投影到共同的 attention embed_dim,再拆成相同 head_dim 的多头即可。

Cross-Attention 在多模态里怎么用?

可以让文本 token 作为 Q,图像 token 作为 K/V,让文本读取视觉信息;也可以让 learnable query 读取视觉编码器输出,形成压缩后的视觉摘要。

常见实现 bug 有哪些?

常见是把 Tq 和 Tk 混了、k 没有 transpose、mask 维度不能 broadcast、softmax 维度写错,或合并多头前没有 contiguous。