这里主要使用pytorch实现基本的无条件去噪扩散模型,理论上面的推导这里不重点介绍。
原文理论参考:
前向过程和后向过程
扩散过程包括正向过程和反向过程。前向过程是基于噪声调度的预定马尔可夫链。噪声表是一组方差 ,它们控制构成马尔可夫链的条件正态分布。前向过程是按照预定好的noise scheduler 对干净图像()加入噪声,迭代生成一系列的噪声版本。
上面的公式是前向过程的数学表示,但直观上我们可以将其理解为一个序列,在该序列中我们逐渐将数据示例 X 映射到纯高斯噪声。在中间时间步长 t 处,我们得到了 X 的噪声版本,在最终时间步长 T 处,我们得到了近似受标准正态分布支配的纯噪声。当我们构建扩散模型时,我们需要选择噪声表。例如,在 DDPM 中,我们的噪声表具有从 1e-4 到 0.02 线性增加方差的 1000 个时间步长。同样重要的是要注意,我们的前向过程是静态的,这意味着我们选择noise scheduler作为扩散模型的超参数,并且我们不训练前向过程,因为它已经明确定义。
关于前向过程,一个关键代码实操细节是,因为分布是正态分布,所以我们可以在数学上推导一个称为“扩散核”的分布,它是给定初始数据点的前向过程中任何中间值的分布。这使我们能够绕过在前向过程中迭代添加 t-1 级噪声的所有中间步骤,以获得具有 t 时间处噪声的图像,这将在稍后训练模型时派上用场。这在数学上表示为:
其中, 是从时间点0到时间t的 的累积乘积,对照上面的公式,很快可以理解 的意义所在。
反向过程是扩散模型的关键。反向过程本质上是通过从纯噪声图像中逐渐去除大量噪声来生成新图像来逆推正向过程。我们从纯噪声数据开始,对于每个时间步 t,我们减去理论上该时间步的前向过程添加的噪声量。我们不断消除噪音,直到最终得到类似于原始数据分布的东西。大部分工作是训练一个模型来仔细近似前向过程,以便估计可以生成新样本的反向过程。
算法和训练目标
前向过程算法实现
为了训练这样的模型来估计反向扩散过程,需要遵循下面定义的图像中的算法:
- 从训练数据集中随机采样数据点
- 在噪声(方差)表上选择一个随机时间步长t
- 将该时间步长t对应的噪声添加到数据中,通过“扩散内核”模拟前向扩散过程
- 将加噪数据投入到模型中,模型预测出此时添加的噪声
- 计算预测噪声和实际噪声之间的均方误差,并通过该目标函数优化模型参数
- 重复
从数学上讲,算法中的确切公式一开始可能看起来有点奇怪,如果没有看到完整的推导,但直观上它是基于噪声调度的 alpha 值的扩散内核的重新参数化,它只是计算了预测噪声和添加到图像中的实际噪声的平方差。
如果我们的模型可以根据前向过程的特定时间步成功预测噪声,我们可以迭代地从时间步 T 处的噪声开始,并根据每个时间步逐渐消除噪声,直到恢复类似于生成样本的数据,使其符合原始数据分布。
采样算法(逆向过程)实现
1. 从标准正态分布生成随机噪声
对于从最后一个时间步开始并向后移动的每个时间步:
2. 通过估计逆向过程分布来更新 Z,其中平均值由上一步中的 Z 参数化,方差由我们的模型在该时间步估计的噪声参数化
3. 添加少量噪声以保持稳定性(解释如下)
4. 重复直到到达时间步 0,此时恢复最初的图像
采样和生成图像的算法在数学上可能看起来很复杂,但它直观地可以归结为一个迭代过程,我们从纯噪声开始,估计理论上在时间步 t 添加的噪声,然后将其减去。我们这样做直到得到生成的样本。应该注意的唯一小细节是在减去估计噪声后,我们添加少量噪声以保持过程稳定。例如,在迭代过程开始时一次性估计并减去噪声总量会导致生成样本非常不连贯,因此在实践中,经验表明,添加一点噪声并迭代每个时间步可以生成更好的数据样本。最后可以迭代得到时间步为1时的噪声数据,并且拿训练好的去噪模型预测噪声,将两者加权对减,得到最后的干净图像
基本去噪模型框架-UNET
DDPM 论文的作者使用最初为医学图像分割设计的 UNET 架构来构建模型来预测扩散反向过程的噪声。这里使用的UNET模型适用于 32x32 图像,非常适合 MNIST 等数据集,但该模型可以缩放以处理更高分辨率的数据。 UNET 有很多变体,但这里将构建的模型架构的概述如下图所示。
DDPM 的 UNET 与经典的 UNET 类似,因为它同时包含下采样流和上采样流,从而减轻了网络的计算负担,同时还具有两个流之间的跳跃连接,以合并来自浅层和浅层的信息。模型的深层特征。
DDPM UNET 和经典 UNET 之间的主要区别在于,DDPM UNET 的特点是关注 16x16 维层以及每个残差块中的正弦transformer嵌入。正弦transformer嵌入背后的含义是告诉模型我们尝试预测噪声的时间步长。这有助于模型通过加入噪声的时间位置的位置信息来预测每个时间步的噪声。例如,如果我们有一个噪声时间表,那么模型了解它需要预测噪声的加噪时间位置信息,可以帮助模型预测相应时间步长的噪声。对于那些还不熟悉 Transformer 架构的人来说,可以在这里找到有关注意力和嵌入的更多一般信息 :Attention is All You Need https://arxiv.org/abs/1706.03762
在模型的实现中,我们首先导入必要的库函数并编码我们的正弦函数完成对加噪时间步长的嵌入表示。直观上,正弦嵌入是不同的正弦和余弦频率,可以直接添加到我们的输入中,为模型提供额外的位置/顺序理解。从下图中可以看出,每个正弦波都是独一无二的,这将使模型了解其在噪声表中的位置。
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange #pip install einops
from typing import List
import random
import math
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from timm.utils import ModelEmaV3 #pip install timm
from tqdm import tqdm #pip install tqdm
import matplotlib.pyplot as plt #pip install matplotlib
import torch.optim as optim
import numpy as npclass SinusoidalEmbeddings(nn.Module):def __init__(self, time_steps:int, embed_dim: int):super().__init__()position = torch.arange(time_steps).unsqueeze(1).float()div = torch.exp(torch.arange(0, embed_dim, 2).float() * -(math.log(10000.0) / embed_dim))embeddings = torch.zeros(time_steps, embed_dim, requires_grad=False)embeddings[:, 0::2] = torch.sin(position * div)embeddings[:, 1::2] = torch.cos(position * div)self.embeddings = embeddingsdef forward(self, x, t):embeds = self.embeddings[t].to(x.device)return embeds[:, :, None, None]
定义UNET残差层
# Residual Blocks
class ResBlock(nn.Module):def __init__(self, C: int, num_groups: int, dropout_prob: float):super().__init__()self.relu = nn.ReLU(inplace=True)self.gnorm1 = nn.GroupNorm(num_groups=num_groups, num_channels=C)self.gnorm2 = nn.GroupNorm(num_groups=num_groups, num_channels=C)self.conv1 = nn.Conv2d(C, C, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(C, C, kernel_size=3, padding=1)self.dropout = nn.Dropout(p=dropout_prob, inplace=True)def forward(self, x, embeddings):x = x + embeddings[:, :x.shape[1], :, :]r = self.conv1(self.relu(self.gnorm1(x)))r = self.dropout(r)r = self.conv2(self.relu(self.gnorm2(r)))return r + x
在 DDPM 中,作者在 UNET 的每层(分辨率比例)使用 2 个残差块,对于 16x16 维度层,这里在两个残差块之间引入了经典的 Transformer 注意机制。我们现在将为 UNET 实现注意力机制:
注意力机制
class Attention(nn.Module):def __init__(self, C: int, num_heads:int , dropout_prob: float):super().__init__()self.proj1 = nn.Linear(C, C*3)self.proj2 = nn.Linear(C, C)self.num_heads = num_headsself.dropout_prob = dropout_probdef forward(self, x):h, w = x.shape[2:]x = rearrange(x, 'b c h w -> b (h w) c')x = self.proj1(x)x = rearrange(x, 'b L (C H K) -> K b H L C', K=3, H=self.num_heads)q,k,v = x[0], x[1], x[2]x = F.scaled_dot_product_attention(q,k,v, is_causal=False, dropout_p=self.dropout_prob)x = rearrange(x, 'b H (h w) C -> b h w (C H)', h=h, w=w)x = self.proj2(x)return rearrange(x, 'b h w C -> b C h w')
注意力的实现是非常直接的。我们重塑数据,将 h*w 维度组合成“序列”维度,就像 Transformer 模型的经典输入一样,而通道维度变成嵌入特征维度。在此实现中,我们利用 torch.nn.function.scaled_dot_product_attention,因为该实现包含 flash 注意力,这是注意力的优化版本,在数学上仍然相当于经典的transformer注意力。有关 Flash Attention 的更多信息可以参考这些论文:
Flash Attention https://arxiv.org/abs/2205.14135
Flash Attention https://arxiv.org/abs/2205.14135
最后,到这里,我们就可以定义一个完整的UNET层了:
class UnetLayer(nn.Module):def __init__(self, upscale: bool, attention: bool, num_groups: int, dropout_prob: float,num_heads: int,C: int):super().__init__()self.ResBlock1 = ResBlock(C=C, num_groups=num_groups, dropout_prob=dropout_prob)self.ResBlock2 = ResBlock(C=C, num_groups=num_groups, dropout_prob=dropout_prob)if upscale:self.conv = nn.ConvTranspose2d(C, C//2, kernel_size=4, stride=2, padding=1)else:self.conv = nn.Conv2d(C, C*2, kernel_size=3, stride=2, padding=1)if attention:self.attention_layer = Attention(C, num_heads=num_heads, dropout_prob=dropout_prob)def forward(self, x, embeddings):x = self.ResBlock1(x, embeddings)if hasattr(self, 'attention_layer'):x = self.attention_layer(x)x = self.ResBlock2(x, embeddings)return self.conv(x), x
如前所述,DDPM 中的每一层都有 2 个残差块,并且可能包含一个注意力机制,并且我们另外将嵌入传递到每个残差块中。此外,我们返回下采样或上采样值以及我们将存储并用于残差串联跳跃连接的先前值。
UNET模型
class UNET(nn.Module):def __init__(self,Channels: List = [64, 128, 256, 512, 512, 384],Attentions: List = [False, True, False, False, False, True],Upscales: List = [False, False, False, True, True, True],num_groups: int = 32,dropout_prob: float = 0.1,num_heads: int = 8,input_channels: int = 1,output_channels: int = 1,time_steps: int = 1000):super().__init__()self.num_layers = len(Channels)self.shallow_conv = nn.Conv2d(input_channels, Channels[0], kernel_size=3, padding=1)out_channels = (Channels[-1]//2)+Channels[0]self.late_conv = nn.Conv2d(out_channels, out_channels//2, kernel_size=3, padding=1)self.output_conv = nn.Conv2d(out_channels//2, output_channels, kernel_size=1)self.relu = nn.ReLU(inplace=True)self.embeddings = SinusoidalEmbeddings(time_steps=time_steps, embed_dim=max(Channels))for i in range(self.num_layers):layer = UnetLayer(upscale=Upscales[i],attention=Attentions[i],num_groups=num_groups,dropout_prob=dropout_prob,C=Channels[i],num_heads=num_heads)setattr(self, f'Layer{i+1}', layer)def forward(self, x, t):x = self.shallow_conv(x)residuals = []for i in range(self.num_layers//2):layer = getattr(self, f'Layer{i+1}')embeddings = self.embeddings(x, t)x, r = layer(x, embeddings)residuals.append(r)for i in range(self.num_layers//2, self.num_layers):layer = getattr(self, f'Layer{i+1}')x = torch.concat((layer(x, embeddings)[0], residuals[self.num_layers-i-1]), dim=1)return self.output_conv(self.relu(self.late_conv(x)))
定义 noise scheduler
class DDPM_Scheduler(nn.Module):def __init__(self, num_time_steps: int=1000):super().__init__()self.beta = torch.linspace(1e-4, 0.02, num_time_steps, requires_grad=False)alpha = 1 - self.betaself.alpha = torch.cumprod(alpha, dim=0).requires_grad_(False)def forward(self, t):return self.beta[t], self.alpha[t]
返回 beta(方差)值和 alpha 值,因为训练和采样的公式都基于它们的数学推导来使用。
def set_seed(seed: int = 42):torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falsenp.random.seed(seed)random.seed(seed)
另外定义一个训练种子。这意味着,如果想重现特定的训练实例,可以使用一组种子,这样每次使用相同的种子时,随机权重和优化器初始化都是相同的。
模型训练和图像生成
使用MNIST数据来对模型进行训练。
def train(batch_size: int=64,num_time_steps: int=1000,num_epochs: int=15,seed: int=-1,ema_decay: float=0.9999, lr=2e-5,checkpoint_path: str=None):set_seed(random.randint(0, 2**32-1)) if seed == -1 else set_seed(seed)train_dataset = datasets.MNIST(root='./data', train=True, download=False,transform=transforms.ToTensor())train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)scheduler = DDPM_Scheduler(num_time_steps=num_time_steps)model = UNET().cuda()optimizer = optim.Adam(model.parameters(), lr=lr)ema = ModelEmaV3(model, decay=ema_decay)if checkpoint_path is not None:checkpoint = torch.load(checkpoint_path)model.load_state_dict(checkpoint['weights'])ema.load_state_dict(checkpoint['ema'])optimizer.load_state_dict(checkpoint['optimizer'])criterion = nn.MSELoss(reduction='mean')for i in range(num_epochs):total_loss = 0for bidx, (x,_) in enumerate(tqdm(train_loader, desc=f"Epoch {i+1}/{num_epochs}")):x = x.cuda()x = F.pad(x, (2,2,2,2))t = torch.randint(0,num_time_steps,(batch_size,))e = torch.randn_like(x, requires_grad=False)a = scheduler.alpha[t].view(batch_size,1,1,1).cuda()x = (torch.sqrt(a)*x) + (torch.sqrt(1-a)*e)output = model(x, t)optimizer.zero_grad()loss = criterion(output, e)total_loss += loss.item()loss.backward()optimizer.step()ema.update(model)print(f'Epoch {i+1} | Loss {total_loss / (60000/batch_size):.5f}')checkpoint = {'weights': model.state_dict(),'optimizer': optimizer.state_dict(),'ema': ema.state_dict()}torch.save(checkpoint, 'checkpoints/ddpm_checkpoint')
为了进行推理,直观上,我们只是逆转了前向过程。从纯噪声开始,现在训练的模型可以预测每个时间步的估计噪声,然后可以迭代生成全新的样本。噪声的每个不同起点,都可以生成不同的独特样本,该样本与原始数据分布相似但独特。本文并未推导出推论公式,但开头链接的参考文献可以帮助指导想要更深入理解的读者。
def display_reverse(images: List):fig, axes = plt.subplots(1, 10, figsize=(10,1))for i, ax in enumerate(axes.flat):x = images[i].squeeze(0)x = rearrange(x, 'c h w -> h w c')x = x.numpy()ax.imshow(x)ax.axis('off')plt.show()def inference(checkpoint_path: str=None,num_time_steps: int=1000,ema_decay: float=0.9999, ):checkpoint = torch.load(checkpoint_path)model = UNET().cuda()model.load_state_dict(checkpoint['weights'])ema = ModelEmaV3(model, decay=ema_decay)ema.load_state_dict(checkpoint['ema'])scheduler = DDPM_Scheduler(num_time_steps=num_time_steps)times = [0,15,50,100,200,300,400,550,700,999]images = []with torch.no_grad():model = ema.module.eval()for i in range(10):z = torch.randn(1, 1, 32, 32)for t in reversed(range(1, num_time_steps)):t = [t]temp = (scheduler.beta[t]/( (torch.sqrt(1-scheduler.alpha[t]))*(torch.sqrt(1-scheduler.beta[t])) ))z = (1/(torch.sqrt(1-scheduler.beta[t])))*z - (temp*model(z.cuda(),t).cpu())if t[0] in times:images.append(z)e = torch.randn(1, 1, 32, 32)z = z + (e*torch.sqrt(scheduler.beta[t]))temp = scheduler.beta[0]/( (torch.sqrt(1-scheduler.alpha[0]))*(torch.sqrt(1-scheduler.beta[0])) )x = (1/(torch.sqrt(1-scheduler.beta[0])))*z - (temp*model(z.cuda(),[0]).cpu())images.append(x)x = rearrange(x.squeeze(0), 'c h w -> h w c').detach()x = x.numpy()plt.imshow(x)plt.show()display_reverse(images)images = []
def main():train(checkpoint_path='checkpoints/ddpm_checkpoint', lr=2e-5, num_epochs=75)inference('checkpoints/ddpm_checkpoint')if __name__ == '__main__':main()
使用上面列出的模型训练,进行 75 个 epoch 训练后,可以得到以下结果:
参考文献
- DDPM https://arxiv.org/abs/2006.11239
- Attention is All You Need https://arxiv.org/abs/1706.03762
- Flash Attention 2 https://arxiv.org/abs/2307.08691