真实面经题目 · 原创解析
CLIP 的图文对比学习流程如何用伪代码表示?
这题考 CLIP 图文对比学习的训练流程,回答重点是 batch 内配对、图像/文本归一化向量、相似度矩阵、温度系数和对称交叉熵损失。
CLIP 的伪代码可以按一个 batch 的图文对来写。输入是 B 对匹配的 image_i 和 text_i。先用 image_encoder 得到 image embedding,用 text_encoder 得到 text embedding;再对两个 embedding 做 L2 normalize,让点积等价于余弦相似度。然后计算 logits = image_emb @ text_emb.T / temperature,得到 [B, B] 的相似度矩阵,其中对角线是正样本,非对角线是 batch 内负样本。训练损失是双向的:image-to-text 方向要求每张图匹配自己的文本,text-to-image 方向要求每段文本匹配自己的图片,所以 loss = (CE(logits, labels) + CE(logits.T, labels)) / 2。温度系数控制 softmax 分布的尖锐程度,通常作为可学习或可调参数。面试里要强调 CLIP 不需要手工枚举负样本,同一个 batch 的其他图文天然成为负样本,但 batch 质量、配对噪声和温度会强烈影响训练效果。
每个训练 batch 包含 B 对图片和文本,默认第 i 张图片与第 i 段文本匹配。其他 B-1 段文本对这张图片就是负样本,反过来也一样。这个 batch 内负样本机制是 CLIP 训练效率的关键。
图像进入 image encoder,文本进入 text encoder,得到两个向量矩阵。两塔不需要在编码阶段交互,因此图像和文本可以独立编码,也方便后续检索系统离线建库。
对 embedding 做 L2 normalize 后,矩阵乘法 image_emb @ text_emb.T 就得到所有图文两两之间的余弦相似度。这个 [B, B] 矩阵的第 i 行表示第 i 张图和所有文本的匹配分数。
image-to-text 方向对每一行做交叉熵,目标标签是对角线;text-to-image 方向对转置矩阵做交叉熵。两者平均后,模型同时学会以图搜文和以文搜图。
温度越小,softmax 越尖锐,模型更强调区分最相近的负样本;温度太小可能训练不稳定,太大则区分信号变弱。实际中常用 logit_scale 或 temperature 做可学习/可调控制。
面试写完核心损失后,可以补充大 batch、数据去噪、图文配对质量、分布式 gather、hard negative 和避免 false negative。这些比只写几行矩阵乘法更能体现工程理解。
import torch
import torch.nn.functional as F
def clip_loss(images, texts, image_encoder, text_encoder, logit_scale):
img = F.normalize(image_encoder(images), dim=-1) # [B, D]
txt = F.normalize(text_encoder(texts), dim=-1) # [B, D]
logits = logit_scale.exp() * (img @ txt.t()) # [B, B]
labels = torch.arange(logits.size(0), device=logits.device)
loss_i2t = F.cross_entropy(logits, labels)
loss_t2i = F.cross_entropy(logits.t(), labels)
return (loss_i2t + loss_t2i) / 2 因为 batch 中第 i 张图和第 i 段文本是一对正样本,相似度矩阵的对角线位置就是正确匹配,所以标签是 0 到 B-1。
归一化后点积主要反映方向相似度,也就是余弦相似度,避免向量范数主导匹配分数,使图文 embedding 更适合检索。
batch 越大,batch 内负样本越多,对比信号更强,但显存和通信成本更高,分布式训练还需要聚合不同卡上的 embedding。
batch 里的非配对样本未必真的不相关,例如两张相似商品图对应相似描述。把它们当强负样本可能伤害训练,需要通过数据清洗或软标签缓解。