STAGE 2 生成模型

GAN 生成对抗网络

造假者 vs 鉴定师 — AI 生成内容的原理

📊 Excel 手推 ⏱️ 25 分钟 🎯 生成模型
1
📋 本章要点
  • GAN 由生成器和判别器对抗训练组成
  • 生成器学习从噪声映射到真实数据分布
  • 训练不稳定性:模式崩溃、训练振荡是主要挑战
  • WGAN 用 Wasserstein 距离替代 JS 散度,改善训练稳定性

什么是 GAN?

GAN(Generative Adversarial Network)的核心思想极其精巧:两个网络相互竞争,共同进步。 就像造假者和鉴定师的博弈 — 造假者不断提高伪造技术,鉴定师不断提升鉴别能力,最终造假者能以假乱真。

类比:想象一个艺术造假者(Generator 生成器)和一个艺术鉴定师(Discriminator 判别器)。 造假者学习画出逼真的画作,鉴定师学习分辨真假。两人都在不断提升,最终造假者的画作连鉴定师也分不出来。

GAN 的博弈论视角

🎨

Generator (G)

输入随机噪声 z,生成假数据 G(z)

目标:欺骗 D,让 D 把假数据判为真

🔍

Discriminator (D)

输入数据,判断是真数据还是假数据

目标:正确区分真假数据

Minimax 博弈目标函数:

minG maxD V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))]

D 想最大化 V(正确分类),G 想最小化 V(欺骗 D)

训练流程概览

噪声 z
随机采样
G(z)
生成假数据
D
判断真假
Loss
计算损失
更新
交替更新 G 和 D
2

📊 对抗训练手推

我们用简化的一维数据来手推 GAN 的 3 轮对抗训练。真实数据分布为 x=5 附近的高斯分布。 点击单元格可以编辑输入值!

Round 1:训练判别器 D

步骤 计算 结果
真实数据 xreal 5.0 来自真实分布
噪声 z 0.5 随机噪声
G 生成 xfake z × 2 = 1.0 1.00
D(xreal) σ(0.5 × x_real - 1) 0.881
D(G(z)) σ(0.5 × x_fake - 1) 0.378
D 的损失 LD -[log(D(x_real)) + log(1-D(G(z)))] 0.564
D 准确率 D(x_real)>0.5 且 D(G(z))<0.5? 正确

Round 2:训练生成器 G

步骤 计算 结果
噪声 z 1.5 新的随机噪声
G 改进后 xfake z × 2 + 1 = 4.0 4.00
D(G(z)) σ(0.5 × x_fake - 1) 0.668
G 的损失 LG -log(D(G(z))) 0.403
D 是否被骗? D(G(z)) > 0.5? 被骗了!

Round 3:G 进一步改进

步骤 计算 结果
噪声 z 2.0 新的随机噪声
G 更强后 xfake z × 2 + 2 = 6.0 6.00
D(G(z)) σ(0.5 × x_fake - 1) 0.881
G 的损失 LG -log(D(G(z))) 0.127
G 的生成质量 x_fake 接近 x_real=5? 越来越接近!

观察:Round 1 中 D 正确区分了真假数据;Round 2 中 G 改进后开始能骗过 D; Round 3 中 G 生成的数据更接近真实分布,D 几乎分不出来。这就是 GAN 的对抗训练过程! 简化参数:G(z) = w×z + b(学习 w, b),D(x) = σ(0.5x - 1)。

3

🏗️ 经典架构演进

从 2014 年 Ian Goodfellow 提出原始 GAN 开始,GAN 的架构不断进化。

2014

Original GAN

Ian Goodfellow 在论文 "Generative Adversarial Networks" 中首次提出 GAN 的概念。 使用全连接层构建 G 和 D。

结构流程

噪声 z
~N(0,1)
G: FC
→ ReLU
→ FC → Sigmoid
假数据
G(z)
D: FC
→ ReLU
→ FC → Sigmoid
真/假
概率

Minimax 损失函数

minG maxD Ex~pdata[log D(x)] + Ez~pz[log(1 - D(G(z)))]
2015

DCGAN

DCGAN(Deep Convolutional GAN)用卷积层替代全连接层,大幅提升了图像生成质量。 提出了一系列架构设计准则,成为后续 GAN 研究的基础。

结构流程(G)

z
100d
FC
→ Reshape
DeConv
BN→ReLU
DeConv
BN→ReLU
DeConv
→ Tanh
64×64 图

关键设计准则

  • • 用 BatchNorm 替代 Dropout/MaxPool
  • • G 使用 ReLU/Tanh,D 使用 LeakyReLU
  • • 用 步长卷积 替代池化层
  • • 生成器最后一层用 Tanh,其余用 ReLU
2014

Conditional GAN (cGAN)

Conditional GAN 在 G 和 D 中都加入条件信息 y(如类别标签), 从而实现可控生成 — 指定生成什么类别的数据。

结构流程

噪声 z + 标签 y G(z, y) 假数据
标签 y + 数据 x D(x, y) 真/假

应用场景

  • 图像翻译:pix2pix — 语义图→真实照片
  • 文本到图像:根据文字描述生成图片
  • 类别控制:指定生成某类数字/物体
  • 图像修复:填补图像缺失区域
minG maxD E[log D(x|y)] + E[log(1 - D(G(z|y)|y))]

损失函数中加入了条件 y,G 和 D 都以 y 为条件

4

🎬 GAN 对抗训练动画

点击"播放"按钮,观察 G 和 D 如何交替训练。每一轮:先训练 D,再训练 G。

噪声 z
🎨
Generator G loss: --
🔍
Discriminator D acc: --
真实数据 x ~ N(5, 1)
点击播放开始训练
当前轮次
0 / 6
D 准确率
--
G 生成值
--
5

🎮 互动实验:噪声向量与生成分布

拖动滑块调整生成器 G 的参数,观察生成数据分布如何逼近真实数据分布。

1.0
0.0
1.0

观察:

  • 真实数据分布在 x=5 附近(蓝色)
  • 调整 w 和 b 让生成分布(粉色)逼近真实分布
  • 理想状态:w≈1, b≈5,生成数据集中在 5 附近
  • 噪声标准差控制生成数据的分散程度
6

💻 PyTorch 代码

simple_gan.py
import torch
import torch.nn as nn

# 生成器:噪声 → 假数据
class Generator(nn.Module):
    def __init__(self, noise_dim=100, output_dim=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
        )

    def forward(self, z):
        return self.net(z)

# 判别器:数据 → 真/假概率
class Discriminator(nn.Module):
    def __init__(self, input_dim=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)
train_gan.py
G = Generator(noise_dim=100)
D = Discriminator(input_dim=1)
opt_G = torch.optim.Adam(G.parameters(), lr=0.0002)
opt_D = torch.optim.Adam(D.parameters(), lr=0.0002)

for epoch in range(1000):
    # ===== 训练判别器 D =====
    real_data = torch.randn(32, 1) + 5  # 真实数据 ~ N(5, 1)
    z = torch.randn(32, 100)               # 噪声
    fake_data = G(z).detach()                 # G 生成假数据

    d_real = D(real_data)                     # D 对真实数据的判断
    d_fake = D(fake_data)                     # D 对假数据的判断
    loss_D = -(torch.log(d_real + 1e-8).mean() +
               torch.log(1 - d_fake + 1e-8).mean()) # D 的损失

    opt_D.zero_grad()
    loss_D.backward()
    opt_D.step()

    # ===== 训练生成器 G =====
    z = torch.randn(32, 100)
    fake_data = G(z)
    d_fake = D(fake_data)
    loss_G = -torch.log(d_fake + 1e-8).mean()  # G 的损失(想骗过 D)

    opt_G.zero_grad()
    loss_G.backward()
    opt_G.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}: D_loss={loss_D:.4f}, G_loss={loss_G:.4f}")
7

🧠 小测验

🔗 相关章节推荐
  • VAEVAE 是另一种生成模型,对比学习
  • CNNDCGAN 用卷积层构建生成器和判别器
  • Autoencoder理解生成模型的基础

1. GAN 的训练目标是什么?

2. Mode Collapse 是什么?

3. DCGAN 相比原始 GAN 的改进是什么?

8

📝 总结

🎨

对抗训练

G 和 D 的博弈:G 学习生成逼真数据,D 学习区分真假。两者交替训练,最终达到纳什均衡。

🏗️

架构演进

原始 GAN → DCGAN(卷积化)→ Conditional GAN(可控生成)。核心趋势:更强的生成能力、更可控的输出。

⚠️

训练挑战

Mode Collapse(模式崩塌)、训练不稳定、G 和 D 的平衡难以维持。后续研究提出了 WGAN 等改进方案。

🎯

核心思想

GAN 的核心是"对抗产生进步" — 通过竞争博弈,让生成模型学会数据的真实分布。

Easy Deep Learning · 用 Excel 手推理解每一个公式