UNet 图像分割
U 型编码-解码结构 + 跳跃连接 — 医学图像分割的里程碑
- UNet 采用编码器-解码器结构 + 跳跃连接
- 跳跃连接将编码器的高分辨率特征传递给解码器
- 特别适合像素级预测任务:语义分割、医学图像
- 成为后续扩散模型(Stable Diffusion)的核心骨架
什么是 UNet?
UNet 是一种专门用于图像分割的深度学习网络。它的核心思想是 U 型编码-解码结构 + 跳跃连接(Skip Connection)。 不同于图像分类(给整张图一个标签),图像分割需要给每个像素分配一个类别标签。
类比:想象你在看一张照片。先把它缩小成缩略图了解全局(编码器),然后再逐像素放大精确定位每个物体的边界(解码器)。 跳跃连接就像在放大的过程中,随时回头看原始照片的细节,确保不丢失边缘信息。
诞生背景
UNet 由 Olaf Ronneberger 等人于 2015 年在论文 "U-Net: Convolutional Networks for Biomedical Image Segmentation" 中提出, 最初是为医学图像分割设计的。医学图像的特点是训练数据少、需要精确的像素级分割,UNet 的跳跃连接正好解决了下采样过程中细节丢失的问题。
医学影像
器官、肿瘤、细胞分割
遥感图像
建筑物、道路、土地分割
自动驾驶
行人、车辆、车道分割
UNet 的整体结构
📊 编码器(下采样)— Excel 手推
编码器通过卷积 + ReLU + 最大池化逐步缩小特征图尺寸,同时增加通道数以提取更高层特征。 点击单元格可以编辑输入值!
| 层 | 操作 | 输入尺寸 (C×H×W) |
卷积核 | 输出尺寸 (C×H×W) |
说明 |
|---|---|---|---|---|---|
| 输入 | -- | 1 | -- | 1×64×64 | 灰度医学图像 |
| Conv Block 1 | Conv3×3 + ReLU ×2 | 1×64×64 | 3×3, pad=1 | 64×64×64 | 提取低级特征(边缘) |
| Pool 1 | MaxPool 2×2 | 64×64×64 | stride=2 | 64×32×32 | 尺寸减半 |
| Conv Block 2 | Conv3×3 + ReLU ×2 | 64×32×32 | 3×3, pad=1 | 128×32×32 | 通道翻倍 |
| Pool 2 | MaxPool 2×2 | 128×32×32 | stride=2 | 128×16×16 | 尺寸再减半 |
| Conv Block 3 | Conv3×3 + ReLU ×2 | 128×16×16 | 3×3, pad=1 | 256×16×16 | 高级特征 |
| Pool 3 | MaxPool 2×2 | 256×16×16 | stride=2 | 256×8×8 | 进入瓶颈 |
观察:每次池化后特征图尺寸减半(64→32→16→8),通道数翻倍(64→128→256)。 这就是"下采样"的过程 — 用空间分辨率换取特征深度。
📊 瓶颈层 — Excel 手推
瓶颈层是 UNet 的最底层,特征图尺寸最小(8×8),但通道数最多。这里是网络"理解"图像全局语义的地方。
| 层 | 操作 | 输入尺寸 (C×H×W) |
输出尺寸 (C×H×W) |
说明 |
|---|---|---|---|---|
| Conv Block 4 | Conv3×3 + ReLU ×2 | 256×8×8 | 512×8×8 | 最高层语义特征 |
瓶颈层的作用
特征压缩
空间尺寸最小(8×8),但每个位置编码了丰富的语义信息,相当于对图像的"高度概括"。
过渡桥梁
连接编码器和解码器,将"压缩"后的全局特征传递给解码器进行逐步"解压"和细节恢复。
📊 解码器(上采样)+ 跳跃连接 — Excel 手推
解码器通过转置卷积(上采样)+ 拼接跳跃连接逐步恢复特征图的空间尺寸。 跳跃连接将编码器对应层的特征拼接过来,补充下采样中丢失的细节信息。 点击单元格可以编辑输入值!
| 层 | 操作 | 输入尺寸 (C×H×W) |
跳跃连接 | 拼接后尺寸 | 输出尺寸 (C×H×W) |
说明 |
|---|---|---|---|---|---|---|
| Up Conv 1 | ConvTranspose 2×2 | 512×8×8 | cat(enc3: 256×16×16) | 768×16×16 | 256×16×16 | 恢复到 16×16 |
| Conv Block Up1 | Conv3×3 + ReLU ×2 | 768×16×16 | -- | -- | 256×16×16 | 融合跳跃特征 |
| Up Conv 2 | ConvTranspose 2×2 | 256×16×16 | cat(enc2: 128×32×32) | 384×32×32 | 128×32×32 | 恢复到 32×32 |
| Conv Block Up2 | Conv3×3 + ReLU ×2 | 384×32×32 | -- | -- | 128×32×32 | 融合跳跃特征 |
| Up Conv 3 | ConvTranspose 2×2 | 128×32×32 | cat(enc1: 64×64×64) | 192×64×64 | 64×64×64 | 恢复到原始尺寸 |
| Conv Block Up3 | Conv3×3 + ReLU ×2 | 192×64×64 | -- | -- | 64×64×64 | 最终特征融合 |
| 输出层 | Conv 1×1 | 64×64×64 | -- | -- | n_classes×64×64 | 像素级分类 |
跳跃连接的魔力:编码器 Pool1 的 64×64×64 特征直接拼接到 Up3 的 64×64×64, 补充了下采样中丢失的边缘、纹理等细节。这就是 UNet 能精确分割边界的秘密!
🎬 UNet 结构动画
点击"播放"按钮,观察数据在 UNet 的 U 型结构中如何流动:先沿编码器下采样,经过瓶颈层,再沿解码器上采样,同时跳跃连接传递特征。
1×64×64
64×64
128×32
256×16
512×8
256×16
128×32
64×64
n×64×64
💻 PyTorch 代码
import torch
import torch.nn as nn
import torch.nn.functional as F
# 基础卷积块:两次 3×3 卷积 + BN + ReLU
class DoubleConv(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
)
def forward(self, x):
return self.double_conv(x)
# 下采样块:MaxPool + DoubleConv
class Down(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_ch, out_ch),
)
def forward(self, x):
return self.maxpool_conv(x)
# 上采样块:转置卷积 + 跳跃连接拼接 + DoubleConv
class Up(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
# 转置卷积:通道减半,尺寸翻倍
self.up = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_ch, out_ch) # 拼接后通道数 = in_ch
def forward(self, x, skip):
# 上采样
x = self.up(x)
# 跳跃连接:拼接编码器特征
x = torch.cat([skip, x], dim=1)
return self.conv(x)
class UNet(nn.Module):
def __init__(self, n_channels=1, n_classes=2):
super().__init__()
# 编码器(下采样路径)
self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
# 瓶颈层
self.bottleneck = Down(512, 1024)
# 解码器(上采样路径)
self.up1 = Up(1024, 512)
self.up2 = Up(512, 256)
self.up3 = Up(256, 128)
self.up4 = Up(128, 64)
# 输出层:1×1 卷积映射到类别数
self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
def forward(self, x):
# 编码器:逐层下采样,保存每层特征
x1 = self.inc(x) # 64×64×64
x2 = self.down1(x1) # 128×32×32
x3 = self.down2(x2) # 256×16×16
x4 = self.down3(x3) # 512×8×8
# 瓶颈层
x5 = self.bottleneck(x4) # 1024×4×4
# 解码器:逐层上采样 + 跳跃连接
x = self.up1(x5, x4) # cat(x4) → 512×8×8
x = self.up2(x, x3) # cat(x3) → 256×16×16
x = self.up3(x, x2) # cat(x2) → 128×32×32
x = self.up4(x, x1) # cat(x1) → 64×64×64
# 输出:像素级分类
return self.outc(x) # n_classes×64×64
🧠 小测验
- CNN — UNet 基于 CNN 的编码器-解码器结构
- Autoencoder — UNet 的跳跃连接思想来源
- GAN — Pix2Pix 用 UNet 做图像翻译
1. UNet 最初是为解决什么问题而设计的?
2. UNet 中跳跃连接(Skip Connection)的主要作用是什么?
3. UNet 解码器中常用的上采样方法是什么?
📝 总结
编码器(下采样)
通过卷积+池化逐步缩小特征图,提取从低级到高级的层次化特征。空间尺寸减半,通道数翻倍。
解码器(上采样)
通过转置卷积逐步恢复空间分辨率,将高层语义特征"解压"回像素级别的预测。
跳跃连接
编码器的细节特征直接拼接到解码器,补充下采样丢失的边缘和纹理信息,实现精确的像素级分割。
核心思想
"先压缩理解全局,再解压恢复细节" — U 型结构 + 跳跃连接,让网络同时拥有全局视野和局部精度。