CLIP 多模态对比学习
用对比学习连接图像与文本 — 零样本分类的魔法
- CLIP 通过对比学习将图像和文本映射到同一语义空间
- 4 亿图文对的预训练使模型具备强大的零样本分类能力
- 对比损失:拉近匹配对、推远不匹配对
- 成为多模态 AI(DALL-E、Stable Diffusion)的基础组件
什么是 CLIP?
CLIP(Contrastive Language-Image Pre-training)的核心思想:用对比学习连接图像和文本。它同时理解图片和文字,然后学会判断哪张图和哪段文字是"一对"。
OpenAI 在 2021 年发布了 CLIP,使用 4 亿个从互联网收集的图文对进行训练。不需要人工标注,直接用网页上的图片和配文作为训练数据。
🏠 生活类比
想象你教一个小孩认识世界:
CLIP 不只是分类——它理解图像和文本之间的语义关系,所以能做到"零样本":不需要任何特定类别的训练数据,就能识别新类别。
双编码器架构 — 图文理解
CLIP 的核心是两个独立的编码器,分别处理图像和文本,然后把它们映射到同一个向量空间。
(ViT / ResNet)
f_img ∈ R^d
最大化匹配对的相似度
(Transformer)
f_txt ∈ R^d
将图像转换为一个特征向量。可以用 ViT(Vision Transformer)或 ResNet。输出是一个固定长度的向量,比如 512 维。
将文本转换为同样维度的特征向量。使用 Transformer 架构,通过 [CLS] token 的输出作为整个句子的表示。
两个编码器的输出被映射到同一个向量空间,使得匹配的图文对在空间中距离很近,不匹配的则距离远。
输出向量会做 L2 归一化(除以自身的模),使得余弦相似度等价于点积,简化计算。
对比学习 — 余弦相似度手推
CLIP 的核心操作是计算图文对之间的余弦相似度。用 3 对图文来演示。
📐 余弦相似度公式
衡量两个向量的方向相似程度,值在 -1 到 1 之间,越大越相似
📋 输入:3 对图文的特征向量(简化为 2 维)
| 图文对 | 图像特征 (f_img) | 文本特征 (f_txt) |
|---|---|---|
| 对 1(猫图 + "一只猫") | [0.8, 0.6] | [0.7, 0.7] |
| 对 2(狗图 + "一只狗") | [-0.6, 0.8] | [-0.5, 0.9] |
| 对 3(车图 + "一辆汽车") | [0.9, -0.4] | [0.8, -0.6] |
💡 点击单元格可以编辑特征值,观察相似度矩阵如何变化!
📊 相似度矩阵(余弦相似度 × 100)
每一行是一个图像,每一列是一个文本。对角线上是正样本对(匹配的图文),其他位置是负样本。
| 文本 1 "一只猫" |
文本 2 "一只狗" |
文本 3 "一辆汽车" |
|
|---|---|---|---|
| 图像 1(猫图) | 99.0 | 21.7 | 55.5 |
| 图像 2(狗图) | 21.7 | 99.4 | -78.1 |
| 图像 3(车图) | 55.5 | -78.1 | 99.4 |
训练目标:
- • 对角线(正样本对)的相似度要尽量高
- • 非对角线(负样本对)的相似度要尽量低
- • 通过不断调整两个编码器的参数,让匹配的图文对越来越近,不匹配的越来越远
InfoNCE 损失 — Excel 手推
CLIP 使用 InfoNCE 损失来训练。对于每个图像,它要从所有文本中找到正确的配对。
📐 InfoNCE 损失公式
📋 手推:以图像 1(猫图)为例
图像 1 与三个文本的余弦相似度:[0.990, 0.217, 0.555],温度 τ = 0.07
| Step 4a:相似度除以温度 τ | ||||
| 文本 1 (正样本) | 文本 2 (负样本) | 文本 3 (负样本) | ||
|---|---|---|---|---|
| cos_sim | 0.990 | 0.217 | 0.555 | |
| cos_sim / τ | 14.143 | 3.100 | 7.929 | |
| Step 4b:计算 exp(sim/τ) | ||||
| 文本 1 (正样本) | 文本 2 (负样本) | 文本 3 (负样本) | 求和 | |
|---|---|---|---|---|
| exp(sim/τ) | 1389013.6 | 22.2 | 2775.8 | 1391811.6 |
| Step 4c:计算概率和损失 | |||
| 正样本概率 | 公式 | 损失值 | 解读 |
|---|---|---|---|
| 0.9980 | exp(正样本) / 求和 | 0.0020 | 损失很小 → 模型已经能正确匹配 |
损失 = -log(概率):概率越接近 1,损失越接近 0。训练目标就是让每个图像的正样本概率尽量大。
🌡️ 温度参数 τ 的作用
τ → 0
分布极其尖锐,只看最相似的对
τ = 0.07
CLIP 论文中学习到的值,效果最好
τ → ∞
分布趋近均匀,模型无法区分正负样本
CLIP 中 τ 是一个可学习的参数(初始化为 0.07),通过 log 缩放后由训练自动调整。实际实现中通常将 τ 限制在合理范围内。
🎮 互动实验 — 零样本分类
模拟 CLIP 的零样本分类过程:调整图像特征向量,观察它与不同文本描述的匹配分数如何变化。
💡 文本候选描述:
- • "一只猫" 特征: [0.70, 0.70]
- • "一只狗" 特征: [-0.50, 0.90]
- • "一辆汽车" 特征: [0.80, -0.60]
- • "一朵花" 特征: [-0.30, -0.80]
🎯 试试看:
- • 把滑块调到 [0.7, 0.7] → 匹配 "一只猫"
- • 把滑块调到 [0.8, -0.6] → 匹配 "一辆汽车"
- • 把滑块调到 [-0.3, -0.8] → 匹配 "一朵花"
代码实现 (PyTorch)
用 PyTorch 实现一个简化的 CLIP:两个编码器 + 对比损失。
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleCLIP(nn.Module):
def __init__(self, embed_dim=512):
super().__init__()
# 图像编码器(简化版:用 MLP 模拟)
self.image_encoder = nn.Sequential(
nn.Linear(768, 512),
nn.ReLU(),
nn.Linear(512, embed_dim),
)
# 文本编码器(简化版:用 MLP 模拟)
self.text_encoder = nn.Sequential(
nn.Linear(768, 512),
nn.ReLU(),
nn.Linear(512, embed_dim),
)
# 可学习的温度参数(log 缩放)
self.log_temperature = nn.Parameter(torch.log(torch.tensor(0.07)))
def forward(self, images, texts):
# 编码
img_features = self.image_encoder(images) # (N, embed_dim)
txt_features = self.text_encoder(texts) # (N, embed_dim)
# L2 归一化
img_features = F.normalize(img_features, dim=-1)
txt_features = F.normalize(txt_features, dim=-1)
# 计算余弦相似度矩阵 (N × N)
temperature = torch.exp(self.log_temperature)
logits = torch.matmul(img_features, txt_features.T) / temperature
# InfoNCE 损失(图像→文本 和 文本→图像 的对称损失)
labels = torch.arange(logits.size(0), device=logits.device)
loss_i2t = F.cross_entropy(logits, labels) # 图像找文本
loss_t2i = F.cross_entropy(logits.T, labels) # 文本找图像
loss = (loss_i2t + loss_t2i) / 2
return loss
# 训练循环
model = SimpleCLIP()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(100):
images = torch.randn(32, 768) # 32 张图的特征
texts = torch.randn(32, 768) # 32 段文本的特征
loss = model(images, texts)
loss.backward()
optimizer.step()
optimizer.zero_grad()
小测验
- Transformer — CLIP 的双编码器都是 Transformer
- Loss & 反向传播 — 对比学习损失函数的设计
- GAN — 文本到图像生成的基础
1. CLIP 的训练数据是什么?
2. InfoNCE 损失的训练目标是什么?
3. CLIP 为什么能做到"零样本分类"?