基本了解
UNet是一种经典的卷积神经网络架构,解决了传统方法在数据量不足时面临的挑战。最初由医学图像分割任务提出,后被广泛应用于扩散模型(如DDPM、DDIM、Stable Diffusion)中作为噪声预测的核心网络。
核心结构包括一个收缩路径(downsampling path)和一个对称的扩展路径(upsampling path)。收缩路径通过多次下采样操作捕获上下文信息,而扩展路径则通过上采样操作结合底层特征和高层特征,实现精确的像素级分割。这种U形结构设计使其能够高效地利用有限的标注样本,并在现代GPU上快速执行。
以下是UNet的详细构成及其在扩散模型中的改进设计:
一、 UNet基础架构
1. 整体结构
UNet呈对称的U型结构,包含:
-
编码器(下采样路径):逐步提取高层语义特征,降低分辨率。
-
解码器(上采样路径):逐步恢复空间分辨率,结合编码器特征进行精确定位。
-
跳跃连接(Skip Connections):连接编码器与解码器对应层,保留细节信息。
2. 核心组件
层级 | 操作 | 作用 |
---|---|---|
编码器层 | 卷积 → 激活(如ReLU) → 池化/步长卷积 | 提取特征,压缩空间维度 |
解码器层 | 反卷积/插值上采样 → 跳跃连接 → 卷积 | 恢复分辨率,融合低级与高级特征 |
瓶颈层 | 深层卷积操作 | 捕获全局上下文信息 |
二、 扩散模型中UNet的改进
在扩散模型(如DDPM、DDIM)中,UNet经过以下关键改进:
1. 时间步嵌入(Timestep Embedding)
-
作用:将扩散过程的时间步信息注入网络,指导噪声预测。
-
实现:
-
时间步编码为向量(通过正弦位置编码或MLP)。
-
通过相加或拼接融入各层特征图。
-
2. 残差块(Residual Blocks)
-
结构:每个块包含多个卷积层 + 归一化层(如GroupNorm) + 激活函数。
-
改进:
-
引入残差连接,缓解梯度消失。
-
集成时间步嵌入(通过相加或自适应归一化)。
-
3. 注意力机制(Attention Layers)
-
位置:通常在瓶颈层或解码器中插入。
-
类型:
-
自注意力(Self-Attention):捕捉长程依赖。
-
交叉注意力(Cross-Attention):用于多模态模型(如Stable Diffusion中结合文本提示)。
-
4. 分组归一化(Group Normalization)
-
替代方案:相比BatchNorm,更适合小批量训练,提升稳定性。
5. 多尺度特征融合
-
跳跃连接增强:通过通道注意力(如Squeeze-and-Excitation)动态加权特征。
三、 典型扩散模型UNet结构示例
以Stable Diffusion为例,其UNet结构参数如下:
- 输入:带噪声的潜空间特征图(64×64×4)
- 编码器:4个下采样层,每层包含2个残差块
- 瓶颈层:2个残差块 + 自注意力层
- 解码器:4个上采样层,每层包含2个残差块
- 跳跃连接:逐层传递编码器特征至解码器
- 总参数量:约860M(Stable Diffusion 1.4版本)
四、 关键设计思想
-
分辨率保持:通过跳跃连接保留低级细节,避免上采样时的模糊问题。
-
动态条件注入:时间步嵌入和文本条件(如CLIP embedding)通过自适应归一化(AdaGN)融入网络:
-
# 示例:自适应归一化(AdaGN) def adaptive_group_norm(x, timestep_emb, scale_shift=True):scale, shift = timestep_emb.chunk(2, dim=1)x = GroupNorm(x)if scale_shift:x = x * (1 + scale) + shiftreturn x
-
-
轻量化设计
-
使用深度可分离卷积(Depthwise Separable Conv)减少计算量。
-
通道数动态调整(如Stable Diffusion中通道数从128到1024递增)。
-
五、 与传统UNet的区别
特性 | 传统UNet(医学分割) | 扩散模型UNet |
---|---|---|
输入/输出 | 原始图像 → 分割掩码 | 噪声图像 + 时间步 → 噪声残差 |
归一化方式 | BatchNorm | GroupNorm |
条件注入 | 无 | 时间步嵌入 + 文本/图像条件 |
注意力机制 | 无 | 自注意力/交叉注意力 |
参数量级 | 较小(几M~几十M) | 较大(几百M~上B) |
六、 代码框架示例(PyTorch风格)
class DiffusionUNet(nn.Module):def __init__(self):super().__init__()# 编码器self.encoder = nn.ModuleList([DownBlock(3, 64), # 下采样块DownBlock(64, 128),DownBlock(128, 256)])# 瓶颈层(含注意力)self.bottleneck = nn.Sequential(ResBlock(256, 512),SelfAttention(512),ResBlock(512, 512))# 解码器self.decoder = nn.ModuleList([UpBlock(512, 256), # 上采样块(含跳跃连接)UpBlock(256, 128),UpBlock(128, 64)])# 时间步嵌入self.time_embed = nn.Sequential(nn.Linear(128, 256),nn.SiLU(),nn.Linear(256, 256))def forward(self, x, t):t_emb = self.time_embed(t) # 时间步嵌入skips = []for down in self.encoder:x = down(x, t_emb) # 注入时间条件skips.append(x)x = self.bottleneck(x)for up in self.decoder:x = up(x, skips.pop(), t_emb)return x
DDPM中的应用
在扩散模型(如DDPM)中,改进的UNet结构通过以下方式整合时间嵌入(time embedding),实现对噪声的精准预测:
1. UNet的基础结构
编码器-解码器架构:UNet由对称的下采样(编码器)和上采样(解码器)路径组成,通过跳跃连接保留多尺度特征。
残差块(ResBlock):每个下采样和上采样阶段包含多个残差块,用于特征提取。
2. Time Embedding的作用
时间步编码:扩散过程中的时间步 t 被编码为高维向量(如通过正弦函数或MLP),表示当前加噪阶段。
动态调节网络:Time embedding作为条件信号,影响每一层的计算,使网络适应不同时间步的噪声分布。
3. Time Embedding的注入方式
特征图加法:Time embedding通过线性层投影后,直接添加到残差块的特征图中
自适应归一化(AdaGN):
4. 输入与输出的设计
输入:加噪图像在时间步t时由原始图像逐步添加高斯噪声得到。
输出:预测的噪声ϵθ (xt,t),目标是最小化与真实噪声的误差
5. 关键理解点
条件化每一层:每个残差块均接收相同的time embedding,确保网络在不同时间步采用不同的特征变换策略。
时间感知的噪声预测:通过时间嵌入,模型能区分早期(大尺度噪声)和晚期(细节噪声)的去噪需求,提升生成质量。
6. 扩散模型中的应用
前向过程:逐步为图像添加噪声
反向过程:UNet预测噪声,通过迭代去噪重建