真实面经题目 · 原创解析
mask attention是如何实现的?
Mask attention 的核心是在计算 attention 权重前,对不允许关注的位置加上一个极小值,使这些位置经过 softmax 后权重接近 0。它常用于因果语言建模、padding 屏蔽、局部注意力和结构化可见性约束。
Mask attention 是在 scaled dot-product attention 中加入 mask。原始注意力分数为 QK^T / sqrt(dk),然后对需要屏蔽的位置加上一个很大的负数,例如负无穷的近似值,再做 softmax。这样被 mask 的位置 softmax 后概率接近 0,后续与 V 加权求和时几乎不贡献信息。常见 mask 有 padding mask 和 causal mask。padding mask 防止模型关注补齐 token;causal mask 防止当前位置看到未来 token,是自回归语言模型生成能力的基础。
Mask 通常作用在 softmax 之前的 attention logits 上,而不是 softmax 之后随便置零。标准流程是先计算 Q 与 K 的点积分数,再缩放,然后叠加 mask,最后做 softmax 得到权重。这样可以保证被屏蔽位置在归一化时也不占概率质量。
实现上通常把允许关注的位置加 0,把不允许关注的位置加一个极大的负数。softmax 会把极大负数对应的指数值压到接近 0,因此这些位置不会参与有效加权。这个方式比直接改 V 更自然,因为它约束的是 token 之间的可见关系。
因果 mask 用于自回归模型,要求第 t 个位置只能关注自己和之前的位置,不能关注未来位置。它通常是一个上三角屏蔽矩阵。训练时虽然整段序列并行输入,但 causal mask 保证每个位置的预测只依赖过去信息,从而与逐 token 生成过程一致。
padding mask 用于处理变长序列批处理。为了把不同长度的样本拼成同一张量,短序列会补 padding token。模型不应把 padding 当作有效语义,因此需要在 attention 中屏蔽这些位置,避免无意义 token 干扰上下文表示。
实际实现中 mask 需要能广播到 attention logits 的形状,常见 logits 维度是 batch、head、query length、key length。padding mask 通常按 key 位置扩展,causal mask 按 query-key 位置形成矩阵,多头注意力里同一 mask 可以被多个 head 共享。
实现时要注意 dtype 和极小值选择。float32 中可以用很大的负数近似负无穷;低精度训练中如果数值过小可能产生溢出或 NaN。工程里常使用框架提供的 masked_fill、attention_mask 或 fused attention 接口来处理这些细节。
因为 softmax 会对所有位置归一化。如果先做 softmax 再置零,被屏蔽位置已经影响了概率分配;加在 softmax 前可以让这些位置不参与有效归一化。
可以。自回归模型在变长批处理时通常同时需要二者:causal mask 控制不能看未来,padding mask 控制不能看补齐位置。
数学上使用负无穷时是 0;工程中常用极大负数近似,softmax 后通常接近 0。在数值实现稳定的情况下,这些位置对输出没有实际贡献。
直接改 V 不能阻止这些位置参与 softmax 归一化,仍会改变其他位置的权重比例。mask logits 才是在注意力分布层面禁止可见性。