STAGE 3 多模态对比学习

CLIP 多模态对比学习

用对比学习连接图像与文本 — 零样本分类的魔法

📊 Excel 手推 ⏱️ 30 分钟 🎯 对比学习
1
📋 本章要点
  • CLIP 通过对比学习将图像和文本映射到同一语义空间
  • 4 亿图文对的预训练使模型具备强大的零样本分类能力
  • 对比损失:拉近匹配对、推远不匹配对
  • 成为多模态 AI(DALL-E、Stable Diffusion)的基础组件

什么是 CLIP?

CLIP(Contrastive Language-Image Pre-training)的核心思想:用对比学习连接图像和文本。它同时理解图片和文字,然后学会判断哪张图和哪段文字是"一对"。

OpenAI 在 2021 年发布了 CLIP,使用 4 亿个从互联网收集的图文对进行训练。不需要人工标注,直接用网页上的图片和配文作为训练数据。

🏠 生活类比

想象你教一个小孩认识世界:

🖼️ → 📝
看图说话
给 AI 一张猫的照片,它能说出"这是一只猫"
📝 → 🖼️
看字找图
给 AI 一段文字"一只橘猫在沙发上睡觉",它能从一堆图里找到匹配的

CLIP 不只是分类——它理解图像和文本之间的语义关系,所以能做到"零样本":不需要任何特定类别的训练数据,就能识别新类别。

2

双编码器架构 — 图文理解

CLIP 的核心是两个独立的编码器,分别处理图像和文本,然后把它们映射到同一个向量空间。

图像输入
🖼️ 图像
图像编码器
(ViT / ResNet)
图像特征向量
f_img ∈ R^d
对比学习
计算余弦相似度
最大化匹配对的相似度
文本输入
📝 文本
文本编码器
(Transformer)
文本特征向量
f_txt ∈ R^d
🖼️ 图像编码器

将图像转换为一个特征向量。可以用 ViT(Vision Transformer)或 ResNet。输出是一个固定长度的向量,比如 512 维。

📝 文本编码器

将文本转换为同样维度的特征向量。使用 Transformer 架构,通过 [CLS] token 的输出作为整个句子的表示。

🔗 共享向量空间

两个编码器的输出被映射到同一个向量空间,使得匹配的图文对在空间中距离很近,不匹配的则距离远。

📐 L2 归一化

输出向量会做 L2 归一化(除以自身的模),使得余弦相似度等价于点积,简化计算。

3

对比学习 — 余弦相似度手推

CLIP 的核心操作是计算图文对之间的余弦相似度。用 3 对图文来演示。

📐 余弦相似度公式

cos(a, b) = a · b / (‖a‖ × ‖b‖)

衡量两个向量的方向相似程度,值在 -1 到 1 之间,越大越相似

📋 输入:3 对图文的特征向量(简化为 2 维)

图文对 图像特征 (f_img) 文本特征 (f_txt)
对 1(猫图 + "一只猫") [0.8, 0.6] [0.7, 0.7]
对 2(狗图 + "一只狗") [-0.6, 0.8] [-0.5, 0.9]
对 3(车图 + "一辆汽车") [0.9, -0.4] [0.8, -0.6]

💡 点击单元格可以编辑特征值,观察相似度矩阵如何变化!

📊 相似度矩阵(余弦相似度 × 100)

每一行是一个图像,每一列是一个文本。对角线上是正样本对(匹配的图文),其他位置是负样本

文本 1
"一只猫"
文本 2
"一只狗"
文本 3
"一辆汽车"
图像 1(猫图) 99.0 21.7 55.5
图像 2(狗图) 21.7 99.4 -78.1
图像 3(车图) 55.5 -78.1 99.4

训练目标:

  • 对角线(正样本对)的相似度要尽量
  • 非对角线(负样本对)的相似度要尽量
  • • 通过不断调整两个编码器的参数,让匹配的图文对越来越近,不匹配的越来越远
4

InfoNCE 损失 — Excel 手推

CLIP 使用 InfoNCE 损失来训练。对于每个图像,它要从所有文本中找到正确的配对。

📐 InfoNCE 损失公式

L = -log( exp(sim(x,y+)/τ) / Σ exp(sim(x,yi)/τ) )
sim(x, y+)
正样本对的相似度
Σ exp(sim/τ)
所有对的相似度指数之和
τ (温度)
控制分布的锐度

📋 手推:以图像 1(猫图)为例

图像 1 与三个文本的余弦相似度:[0.990, 0.217, 0.555],温度 τ = 0.07

Step 4a:相似度除以温度 τ
文本 1 (正样本) 文本 2 (负样本) 文本 3 (负样本)
cos_sim 0.990 0.217 0.555
cos_sim / τ 14.143 3.100 7.929
Step 4b:计算 exp(sim/τ)
文本 1 (正样本) 文本 2 (负样本) 文本 3 (负样本) 求和
exp(sim/τ) 1389013.6 22.2 2775.8 1391811.6
Step 4c:计算概率和损失
正样本概率 公式 损失值 解读
0.9980 exp(正样本) / 求和 0.0020 损失很小 → 模型已经能正确匹配

损失 = -log(概率):概率越接近 1,损失越接近 0。训练目标就是让每个图像的正样本概率尽量大。

🌡️ 温度参数 τ 的作用

❄️

τ → 0

分布极其尖锐,只看最相似的对

⚖️

τ = 0.07

CLIP 论文中学习到的值,效果最好

🔥

τ → ∞

分布趋近均匀,模型无法区分正负样本

CLIP 中 τ 是一个可学习的参数(初始化为 0.07),通过 log 缩放后由训练自动调整。实际实现中通常将 τ 限制在合理范围内。

5

🎮 互动实验 — 零样本分类

模拟 CLIP 的零样本分类过程:调整图像特征向量,观察它与不同文本描述的匹配分数如何变化。

0.50
0.50

💡 文本候选描述:

  • "一只猫" 特征: [0.70, 0.70]
  • "一只狗" 特征: [-0.50, 0.90]
  • "一辆汽车" 特征: [0.80, -0.60]
  • "一朵花" 特征: [-0.30, -0.80]

🎯 试试看:

  • • 把滑块调到 [0.7, 0.7] → 匹配 "一只猫"
  • • 把滑块调到 [0.8, -0.6] → 匹配 "一辆汽车"
  • • 把滑块调到 [-0.3, -0.8] → 匹配 "一朵花"
最佳匹配
一只猫
6

代码实现 (PyTorch)

用 PyTorch 实现一个简化的 CLIP:两个编码器 + 对比损失。

clip_simple.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCLIP(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()
        # 图像编码器(简化版:用 MLP 模拟)
        self.image_encoder = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Linear(512, embed_dim),
        )
        # 文本编码器(简化版:用 MLP 模拟)
        self.text_encoder = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Linear(512, embed_dim),
        )
        # 可学习的温度参数(log 缩放)
        self.log_temperature = nn.Parameter(torch.log(torch.tensor(0.07)))

    def forward(self, images, texts):
        # 编码
        img_features = self.image_encoder(images)    # (N, embed_dim)
        txt_features = self.text_encoder(texts)      # (N, embed_dim)

        # L2 归一化
        img_features = F.normalize(img_features, dim=-1)
        txt_features = F.normalize(txt_features, dim=-1)

        # 计算余弦相似度矩阵 (N × N)
        temperature = torch.exp(self.log_temperature)
        logits = torch.matmul(img_features, txt_features.T) / temperature

        # InfoNCE 损失(图像→文本 和 文本→图像 的对称损失)
        labels = torch.arange(logits.size(0), device=logits.device)
        loss_i2t = F.cross_entropy(logits, labels)     # 图像找文本
        loss_t2i = F.cross_entropy(logits.T, labels)   # 文本找图像
        loss = (loss_i2t + loss_t2i) / 2
        return loss

# 训练循环
model = SimpleCLIP()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(100):
    images = torch.randn(32, 768)  # 32 张图的特征
    texts  = torch.randn(32, 768)  # 32 段文本的特征
    loss = model(images, texts)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
7

小测验

🔗 相关章节推荐

1. CLIP 的训练数据是什么?

2. InfoNCE 损失的训练目标是什么?

3. CLIP 为什么能做到"零样本分类"?

Easy Deep Learning · 用 Excel 手推理解每一个公式