STAGE 2 图像分割

UNet 图像分割

U 型编码-解码结构 + 跳跃连接 — 医学图像分割的里程碑

📊 Excel 手推 ⏱️ 20 分钟 🎯 语义分割
1
📋 本章要点
  • UNet 采用编码器-解码器结构 + 跳跃连接
  • 跳跃连接将编码器的高分辨率特征传递给解码器
  • 特别适合像素级预测任务:语义分割、医学图像
  • 成为后续扩散模型(Stable Diffusion)的核心骨架

什么是 UNet?

UNet 是一种专门用于图像分割的深度学习网络。它的核心思想是 U 型编码-解码结构 + 跳跃连接(Skip Connection)。 不同于图像分类(给整张图一个标签),图像分割需要给每个像素分配一个类别标签。

类比:想象你在看一张照片。先把它缩小成缩略图了解全局(编码器),然后再逐像素放大精确定位每个物体的边界(解码器)。 跳跃连接就像在放大的过程中,随时回头看原始照片的细节,确保不丢失边缘信息。

诞生背景

UNet 由 Olaf Ronneberger 等人于 2015 年在论文 "U-Net: Convolutional Networks for Biomedical Image Segmentation" 中提出, 最初是为医学图像分割设计的。医学图像的特点是训练数据少、需要精确的像素级分割,UNet 的跳跃连接正好解决了下采样过程中细节丢失的问题。

🏥

医学影像

器官、肿瘤、细胞分割

🛰️

遥感图像

建筑物、道路、土地分割

🚗

自动驾驶

行人、车辆、车道分割

UNet 的整体结构

输入图像 编码器(下采样) 瓶颈层 解码器(上采样) 分割输出
跳跃连接:编码器的每一层特征直接拼接到解码器对应层
2

📊 编码器(下采样)— Excel 手推

编码器通过卷积 + ReLU + 最大池化逐步缩小特征图尺寸,同时增加通道数以提取更高层特征。 点击单元格可以编辑输入值!

操作 输入尺寸
(C×H×W)
卷积核 输出尺寸
(C×H×W)
说明
输入 -- 1 -- 1×64×64 灰度医学图像
Conv Block 1 Conv3×3 + ReLU ×2 1×64×64 3×3, pad=1 64×64×64 提取低级特征(边缘)
Pool 1 MaxPool 2×2 64×64×64 stride=2 64×32×32 尺寸减半
Conv Block 2 Conv3×3 + ReLU ×2 64×32×32 3×3, pad=1 128×32×32 通道翻倍
Pool 2 MaxPool 2×2 128×32×32 stride=2 128×16×16 尺寸再减半
Conv Block 3 Conv3×3 + ReLU ×2 128×16×16 3×3, pad=1 256×16×16 高级特征
Pool 3 MaxPool 2×2 256×16×16 stride=2 256×8×8 进入瓶颈

观察:每次池化后特征图尺寸减半(64→32→16→8),通道数翻倍(64→128→256)。 这就是"下采样"的过程 — 用空间分辨率换取特征深度。

3

📊 瓶颈层 — Excel 手推

瓶颈层是 UNet 的最底层,特征图尺寸最小(8×8),但通道数最多。这里是网络"理解"图像全局语义的地方。

操作 输入尺寸
(C×H×W)
输出尺寸
(C×H×W)
说明
Conv Block 4 Conv3×3 + ReLU ×2 256×8×8 512×8×8 最高层语义特征

瓶颈层的作用

特征压缩

空间尺寸最小(8×8),但每个位置编码了丰富的语义信息,相当于对图像的"高度概括"。

过渡桥梁

连接编码器和解码器,将"压缩"后的全局特征传递给解码器进行逐步"解压"和细节恢复。

4

📊 解码器(上采样)+ 跳跃连接 — Excel 手推

解码器通过转置卷积(上采样)+ 拼接跳跃连接逐步恢复特征图的空间尺寸。 跳跃连接将编码器对应层的特征拼接过来,补充下采样中丢失的细节信息。 点击单元格可以编辑输入值!

操作 输入尺寸
(C×H×W)
跳跃连接 拼接后尺寸 输出尺寸
(C×H×W)
说明
Up Conv 1 ConvTranspose 2×2 512×8×8 cat(enc3: 256×16×16) 768×16×16 256×16×16 恢复到 16×16
Conv Block Up1 Conv3×3 + ReLU ×2 768×16×16 -- -- 256×16×16 融合跳跃特征
Up Conv 2 ConvTranspose 2×2 256×16×16 cat(enc2: 128×32×32) 384×32×32 128×32×32 恢复到 32×32
Conv Block Up2 Conv3×3 + ReLU ×2 384×32×32 -- -- 128×32×32 融合跳跃特征
Up Conv 3 ConvTranspose 2×2 128×32×32 cat(enc1: 64×64×64) 192×64×64 64×64×64 恢复到原始尺寸
Conv Block Up3 Conv3×3 + ReLU ×2 192×64×64 -- -- 64×64×64 最终特征融合
输出层 Conv 1×1 64×64×64 -- -- n_classes×64×64 像素级分类

跳跃连接的魔力:编码器 Pool1 的 64×64×64 特征直接拼接到 Up3 的 64×64×64, 补充了下采样中丢失的边缘、纹理等细节。这就是 UNet 能精确分割边界的秘密!

5

🎬 UNet 结构动画

点击"播放"按钮,观察数据在 UNet 的 U 型结构中如何流动:先沿编码器下采样,经过瓶颈层,再沿解码器上采样,同时跳跃连接传递特征。

输入
1×64×64
Enc1
64×64
Enc2
128×32
Enc3
256×16
瓶颈
512×8
跳3: Enc3→Up1 (cat)
跳2: Enc2→Up2 (cat)
跳1: Enc1→Up3 (cat)
Up1
256×16
Up2
128×32
Up3
64×64
输出
n×64×64
点击播放观察数据流动
当前阶段
--
特征图尺寸
--
6

💻 PyTorch 代码

unet_modules.py
import torch
import torch.nn as nn
import torch.nn.functional as F


# 基础卷积块:两次 3×3 卷积 + BN + ReLU
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


# 下采样块:MaxPool + DoubleConv
class Down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_ch, out_ch),
        )

    def forward(self, x):
        return self.maxpool_conv(x)


# 上采样块:转置卷积 + 跳跃连接拼接 + DoubleConv
class Up(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        # 转置卷积:通道减半,尺寸翻倍
        self.up = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_ch, out_ch)  # 拼接后通道数 = in_ch

    def forward(self, x, skip):
        # 上采样
        x = self.up(x)
        # 跳跃连接:拼接编码器特征
        x = torch.cat([skip, x], dim=1)
        return self.conv(x)
unet.py
class UNet(nn.Module):
    def __init__(self, n_channels=1, n_classes=2):
        super().__init__()

        # 编码器(下采样路径)
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)

        # 瓶颈层
        self.bottleneck = Down(512, 1024)

        # 解码器(上采样路径)
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)

        # 输出层:1×1 卷积映射到类别数
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        # 编码器:逐层下采样,保存每层特征
        x1 = self.inc(x)        # 64×64×64
        x2 = self.down1(x1)     # 128×32×32
        x3 = self.down2(x2)     # 256×16×16
        x4 = self.down3(x3)     # 512×8×8

        # 瓶颈层
        x5 = self.bottleneck(x4) # 1024×4×4

        # 解码器:逐层上采样 + 跳跃连接
        x = self.up1(x5, x4)    # cat(x4) → 512×8×8
        x = self.up2(x, x3)     # cat(x3) → 256×16×16
        x = self.up3(x, x2)     # cat(x2) → 128×32×32
        x = self.up4(x, x1)     # cat(x1) → 64×64×64

        # 输出:像素级分类
        return self.outc(x)       # n_classes×64×64
7

🧠 小测验

🔗 相关章节推荐
  • CNNUNet 基于 CNN 的编码器-解码器结构
  • AutoencoderUNet 的跳跃连接思想来源
  • GANPix2Pix 用 UNet 做图像翻译

1. UNet 最初是为解决什么问题而设计的?

2. UNet 中跳跃连接(Skip Connection)的主要作用是什么?

3. UNet 解码器中常用的上采样方法是什么?

+

📝 总结

🔽

编码器(下采样)

通过卷积+池化逐步缩小特征图,提取从低级到高级的层次化特征。空间尺寸减半,通道数翻倍。

🔼

解码器(上采样)

通过转置卷积逐步恢复空间分辨率,将高层语义特征"解压"回像素级别的预测。

🔗

跳跃连接

编码器的细节特征直接拼接到解码器,补充下采样丢失的边缘和纹理信息,实现精确的像素级分割。

🎯

核心思想

"先压缩理解全局,再解压恢复细节" — U 型结构 + 跳跃连接,让网络同时拥有全局视野和局部精度。

Easy Deep Learning · 用 Excel 手推理解每一个公式