The theory behind diffusion models is fairly involved, and both the papers' formulas and the open-source implementations are hard to follow. Most existing tutorials focus on deriving the math. This article takes the opposite route: it explains diffusion models from the perspective of running code, using a compact implementation of about 300 lines.
This article reproduces three diffusion-model techniques:
1. DDPM (Denoising Diffusion Probabilistic Models; the SDE view, covering both training and inference)
2. DDIM (Denoising Diffusion Implicit Models; the ODE view, for accelerated inference)
3. Classifier-free guidance (label-conditioned generation)
1. Training: The Noising Process
1.1 Parameters
- timesteps is the number of noising steps, 1000 by default; each individual step is a timestep $t$.
More steps make the process more stable; in the MNIST experiments, 100 or more steps are enough for good results. During training, however, t has no temporal ordering: each sample draws a random integer t from [0, timesteps).
Note: this article uses timesteps = 300.
- $\alpha$ is a decreasing sequence of timesteps values, each slightly below 1; $\alpha_t$ is the t-th element of the sequence.
The base range of $\alpha$ in this article is [0.9997, 0.97]; after the 1000/timesteps rescaling in the schedule code below, it becomes [0.999, 0.9] for timesteps = 300.
- $\beta$ is an increasing sequence of values slightly above 0, with $\beta = 1 - \alpha$; likewise, $\beta_t$ is the t-th element of the sequence.
Its base range is [0.0003, 0.03], i.e. [0.001, 0.1] after the same rescaling.
- $\bar{\alpha}$ is the cumulative product of $\alpha$: a sequence decreasing from 1 toward 0 (never reaching either endpoint).
For example, with alphas = [0.9, 0.8, 0.7] (i.e. timesteps = 3), $\bar{\alpha}$ = [0.9, 0.9 * 0.8, 0.9 * 0.8 * 0.7] = [0.9, 0.72, 0.504].
$\bar{\alpha}_t$ is the sequence value at step t.
- $\sqrt{1-\bar{\alpha}}$ behaves the opposite way: a sequence increasing from 0 toward 1 (again never reaching either endpoint); a quick numerical check follows this list.
Do not confuse it with $\bar{\beta}$, the cumulative product of $\beta$, which is a decreasing sequence approaching 0 (not used in the code).
- $x_0$ is a dataset sample, i.e. a training image.
- $\epsilon$ is a high-dimensional tensor sampled from the standard normal distribution, $\epsilon \sim \mathcal{N}(0, 1)$, i.e. random noise.
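The monotonicity claims above can be verified directly (a minimal sketch using the toy alphas = [0.9, 0.8, 0.7] from the example; not part of the article's code):

```python
import torch

alphas = torch.tensor([0.9, 0.8, 0.7])             # decreasing, slightly below 1
betas = 1.0 - alphas                               # increasing, slightly above 0
alphas_cumprod = torch.cumprod(alphas, dim=0)      # tensor([0.9000, 0.7200, 0.5040]), decreasing 1 -> 0
sqrt_one_minus = torch.sqrt(1.0 - alphas_cumprod)  # tensor([0.3162, 0.5292, 0.7043]), increasing 0 -> 1
```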
1.2 The Noising Equation
Sequential noising is a Markov chain:
$$x_t = \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1 - \alpha_t}\,\epsilon$$
This equation says that going from $x_0$ to $x_t$ is a Gaussian noising process: the image is noised timesteps times, with $\sqrt{\bar{\alpha}_t}$ and $\sqrt{1-\bar{\alpha}_t}$ controlling the ratio of image signal to noise.
As t grows, $\bar{\alpha}_t$ shrinks, the share of image signal drops while the share of noise rises, and the image eventually becomes pure Gaussian noise.
Training this way, however, would require timesteps iterations per sample (similar to an RNN), which is far too slow. The chain can be collapsed; in probabilistic form:
$$q(x_t \mid x_0) \sim \mathcal{N}\!\left(\sqrt{\bar{\alpha}_t}\,x_0,\ (1 - \bar{\alpha}_t)\,\mathbf{I}\right)$$
So training needs only $x_0$, and every sample can use its own random t, giving the formula:
$$x_t = \sqrt{\bar{\alpha}_t}\cdot x_0 + \sqrt{1 - \bar{\alpha}_t}\cdot \epsilon$$
This formula describes the whole training procedure: draw batch_size random values of t, plug each into the formula to obtain the forward-process $x_t$, feed $x_t$ to the model, and take the MSE between $\epsilon$ and the model's output.
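To see why the one-step formula is valid: after t sequential steps the coefficient on $x_0$ is $\prod_i \sqrt{\alpha_i} = \sqrt{\bar{\alpha}_t}$, and the accumulated independent Gaussian noises merge into a single Gaussian of variance $1 - \bar{\alpha}_t$. A toy check of the signal coefficient (assumed values, not the article's code):

```python
import torch

alphas = torch.tensor([0.9, 0.8, 0.7])
alphas_cumprod = torch.cumprod(alphas, dim=0)

stepwise_coef = torch.sqrt(alphas).prod()      # coefficient on x_0 after 3 sequential steps
oneshot_coef = torch.sqrt(alphas_cumprod[-1])  # sqrt(alpha_bar_t) from the one-step formula
print(stepwise_coef, oneshot_coef)             # both ~0.7099
```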
1.3 The Noising Process in Code
1.3.1 Data Preprocessing
- Each training step feeds batch_size images to the model, and each image is assigned its own random integer t drawn from [0, timesteps).
Given timesteps and batch_size, t is computed as:

```python
t = torch.randint(0, timesteps, (batch_size,), device=device).long()
```
- From timesteps, compute $\beta$ and $\alpha$, and from them $\bar{\alpha}$ (alphas_cumprod), each with shape = [timesteps].
```python
def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = 0.0003 * scale  # if too small, denoising is insufficient
    beta_end = 0.03 * scale      # if too small, the samples degenerate into random streaks
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)

betas = linear_beta_schedule(timesteps)
# tensor([0.0010, 0.0013, 0.0017, ..., 0.0993, 0.0997, 0.1000], dtype=torch.float64)

alphas = 1. - betas
# tensor([0.9990, 0.9987, 0.9983, ..., 0.9007, 0.9003, 0.9000], dtype=torch.float64)

alphas_cumprod = torch.cumprod(alphas, dim=0)
# tensor([9.9900e-01, 9.9767e-01, 9.9601e-01, ..., 1.7171e-07, 1.5454e-07], dtype=torch.float64)
```
- From these, get the formula coefficients $\sqrt{\bar{\alpha}}$ (sqrt_alphas_cumprod) and $\sqrt{1-\bar{\alpha}}$ (sqrt_one_minus_alphas_cumprod):
```python
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)

sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
# tensor([0.0316, 0.0483, 0.0632, ..., 1.0000, 1.0000, 1.0000], dtype=torch.float64)
```
- Compute the step-t values $\sqrt{\bar{\alpha}_t}$ (sqrt_alphas_cumprod_t) and $\sqrt{1-\bar{\alpha}_t}$ (sqrt_one_minus_alphas_cumprod_t).
Because each image in a batch has its own random t, we must look up each t's entry in $\bar{\alpha}$, like a key-value index; the extract function below does this.
It takes each sample's t (as the key), fetches the corresponding value in $\sqrt{\bar{\alpha}}$, and reshapes the result:
```python
def _extract(self, a: torch.FloatTensor, t: torch.LongTensor, x_shape):
    # fetch the schedule values for the given timesteps t
    batch_size = t.shape[0]
    out = a.to(t.device).gather(0, t).float()
    out = out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
    return out
```
Here:
gather() performs the key-value lookup from t to the noise-scale value: sqrt_alphas_cumprod.gather(0, t).
The fetched values are reshaped to [batch_size, 1, 1, 1] so that they broadcast against the image tensor.
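A tiny illustration of the gather lookup (toy values, not the article's code):

```python
import torch

a = torch.tensor([0.9, 0.72, 0.504])  # e.g. alphas_cumprod for timesteps = 3
t = torch.tensor([2, 0])              # one timestep per sample in the batch
print(a.gather(0, t))                 # tensor([0.5040, 0.9000])
```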
Finally, q_sample looks up $\sqrt{\bar{\alpha}_t}$ (sqrt_alphas_cumprod_t) and $\sqrt{1-\bar{\alpha}_t}$ (sqrt_one_minus_alphas_cumprod_t) and mixes image and noise accordingly.
The function:
```python
def q_sample(self, x_start: torch.FloatTensor, t: torch.LongTensor, noise=None):
    # forward diffusion (using the nice property): q(x_t | x_0)
    sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
```
Example: with batch_size = 5, a random draw of t from [0, timesteps) might give [75, 112, 268, 207, 90]. After lookup and reshape:

```python
sqrt_alphas_cumprod_t = tensor([[[[0.5979]]], [[[0.3268]]], [[[0.0018]]], [[[0.0234]]], [[[0.4814]]]])
```

Each value corresponds to one sample's t, and the coefficient tensor's shape is [5, 1, 1, 1]. Suppose the images have x.shape = [5, 3, 32, 32], i.e. a batch of five 32x32 three-channel RGB images.
Multiplying the coefficient tensor with x then triggers broadcasting: every element (pixel) of a sample is scaled by that sample's coefficient, [1, 1, 1] broadcast to [3, 32, 32].
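A shape-level sketch of this broadcast (assumed toy shapes, not the article's code):

```python
import torch

coef = torch.rand(5, 1, 1, 1)  # one coefficient per sample, like sqrt_alphas_cumprod_t
x = torch.randn(5, 3, 32, 32)  # a batch of five 32x32 RGB images
out = coef * x                 # [5, 1, 1, 1] broadcasts to [5, 3, 32, 32]
print(out.shape)               # torch.Size([5, 3, 32, 32])
```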
1.3.2 Training
- Construct the noise $\epsilon$: a Gaussian tensor with the same shape as $x_0$

```python
noise = torch.randn_like(x_start)  # random noise ~ N(0, 1)
```
- Use q_sample to compute $x_t$ from $x_0$ for each of the batch_size samples,
via the formula from above:
$$x_t = \sqrt{\bar{\alpha}_t}\cdot x_0 + \sqrt{1 - \bar{\alpha}_t}\cdot \epsilon$$
- Take the MSE between the two noises:
feed $x_t$ and the corresponding t into the U-Net model to obtain the estimated noise $\epsilon_\theta$,
then compute the MSE between $\epsilon$ and $\epsilon_\theta$. The full code:
```python
def train_losses(self, model, x_start: torch.FloatTensor, t: torch.LongTensor):
    noise = torch.randn_like(x_start)                 # random noise ~ N(0, 1)
    x_noisy = self.q_sample(x_start, t, noise=noise)  # x_t ~ q(x_t | x_0)
    predicted_noise = model(x_noisy, t)               # predict the noise from the noisy image
    loss = F.mse_loss(noise, predicted_noise)
    return loss
```
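Putting the pieces together, one training step might look like this (a sketch; `dataloader`, `optimizer`, and the `diffusion` wrapper object holding the schedule are assumed names, not from the article's code):

```python
for x_0, _ in dataloader:                     # x_0: a batch of training images
    x_0 = x_0.to(device)
    t = torch.randint(0, timesteps, (x_0.shape[0],), device=device).long()
    loss = diffusion.train_losses(model, x_0, t)  # MSE between true and predicted noise
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```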
2. DDPM Inference: The Denoising Process
Denoising is the image-generation process: a Gaussian-noise tensor of the same shape is denoised into an image.
Unlike in training, t runs in reverse order here, and DDPM cannot skip steps: the loop must execute timesteps times, going from $x_T$ down to $x_0$.
2.1 Parameters
- $\bar{\alpha}_{t-1}$ (alphas_cumprod_prev): the cumulative $\alpha$ at timestep $t-1$, with a 1 padded at the front so its length matches that of $\bar{\alpha}_t$.
It is used to compute the mean ($\mu_{t-1}$) and variance ($\sigma_{t-1}^2$) of the posterior $q(x_{t-1} \mid x_t, x_0)$.
Implementation:

```python
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
```
- $\mu_{t-1}$ (posterior_mean)
is computed as:
$$\mu_{t-1} = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\, x_0 + \frac{(1 - \bar{\alpha}_{t-1})\sqrt{\alpha_t}}{1 - \bar{\alpha}_t}\, x_t$$
where $\frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}$ is posterior_mean_coef1 (the coefficient of $x_0$) and $\frac{(1 - \bar{\alpha}_{t-1})\sqrt{\alpha_t}}{1 - \bar{\alpha}_t}$ is posterior_mean_coef2 (the coefficient of $x_t$).
Here $x_0$ is the estimate computed at the current step t, not the dataset sample. The two coefficients:

```python
posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
```
- $\sigma_{t-1}^2$ (posterior_variance), verified numerically right after this list:
$$\sigma_{t-1}^2 = \frac{\beta_t (1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}$$

```python
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
```

Because $\sigma_{t-1}^2$ can be extremely small, storing it directly risks numerical underflow; the code therefore stores $\log\sigma_{t-1}^2$ instead:

```python
posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
```
- $x_0$ is the estimated image at step t; only the final step's denoised image is the actual output (img):
$$x_0 = \frac{x_t}{\sqrt{\bar{\alpha}_t}} - \frac{\sqrt{1 - \bar{\alpha}_t}}{\sqrt{\bar{\alpha}_t}}\cdot \epsilon_\theta$$
The code:

```python
pre_x_0 = _extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t \
        - _extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * model(x_t, t)  # pred_noise = model(x_t, t)
```
where $\epsilon_\theta$ is the model output, i.e. the noise predicted at step t. Substituting this $x_0$ into the $\mu_{t-1}$ formula above gives the full update from $x_t$ to $x_{t-1}$:
$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\cdot \epsilon_\theta(x_t, t)\right) + \sigma_t \cdot z$$
which simplifies to:
$$x_{t-1} = \mu_{t-1} + \sigma_{t-1}\cdot \epsilon$$
where $\epsilon$ is fresh random noise.
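On the toy schedule from Section 1.1, these posterior quantities come out as follows (a sketch with assumed values):

```python
import torch
import torch.nn.functional as F

alphas = torch.tensor([0.9, 0.8, 0.7])
betas = 1.0 - alphas
alphas_cumprod = torch.cumprod(alphas, dim=0)                       # tensor([0.9000, 0.7200, 0.5040])
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)  # tensor([1.0000, 0.9000, 0.7200])

posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
print(posterior_variance)  # tensor([0.0000, 0.0714, 0.1694]) -- small values, hence the log storage
```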
2.2 The Denoising Equation
In code, denoising $x_t$ to obtain $x_{t-1}$ uses the expression:
$$x_{t-1} = \mu_{t-1} + \text{mask}\cdot e^{\frac{1}{2}\log\sigma_{t-1}^2}\cdot \epsilon$$
where:
- For numerical stability, $e^{\frac{1}{2}\log\sigma_{t-1}^2}$ replaces $\sigma_{t-1}$ (the two are mathematically equal);
this avoids the underflow (NaN) that the tiny $\sigma_{t-1}$ could otherwise cause.
- mask is a mask tensor of shape [batch_size, 1, 1, 1]: when t = 0 all its elements are 0, otherwise they are all 1.
This ensures no noise is added at t = 0, i.e. $x_0 = \mu_0$.
The mask code:

```python
nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1))))  # no noise when t == 0
```
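A quick look at what the mask evaluates to (toy values, not the article's code):

```python
import torch

x_t = torch.randn(3, 1, 28, 28)
t = torch.tensor([0, 5, 299])
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1)))  # shape [3, 1, 1, 1]
print(nonzero_mask.flatten())  # tensor([0., 1., 1.]) -- zero only where t == 0
```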
2.3 The Denoising Process
- From the cumulative $\bar{\alpha}_t$ (padded to give $\bar{\alpha}_{t-1}$), compute the variance $\sigma_{t-1}^2$ and convert it to $\log\sigma_{t-1}^2$.
Assume the initial noise has x_t.shape = (batch_size, channels, image_size, image_size).
The code below returns the batch of posterior means and (log-)variances:
```python
def q_posterior_mean_variance(self, x_start: torch.FloatTensor, x_t: torch.FloatTensor, t: torch.LongTensor):
    # Compute the mean and variance of the diffusion posterior: q(x_{t-1} | x_t, x_0)
    posterior_mean = (self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
                      + self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t)
    posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
    posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
    return posterior_mean, posterior_variance, posterior_log_variance_clipped
```
- From the initial noise $x_t$ and $\bar{\alpha}_t$, together with the model output $\epsilon_\theta$, estimate $x_0$; the estimate must be clipped to the valid range [-1, 1].
This calls q_posterior_mean_variance from the previous step, because $x_0$ (x_start) must be computed before $\mu_{t-1}$ (model_mean):
```python
def p_mean_variance(self, model, x_t: torch.FloatTensor, t: torch.LongTensor):
    # compute x_0 from x_t and the predicted noise: the reverse of `q_sample`
    # (an estimate that still contains some residual noise)
    pre_x_0 = self._extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t \
            - self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * model(x_t, t)  # pred_noise = model(x_t, t)
    pre_x_0 = torch.clamp(pre_x_0, min=-1., max=1.)  # clip_denoised
    # compute the predicted mean and variance of p(x_{t-1} | x_t)
    model_mean, posterior_variance, posterior_log_variance = self.q_posterior_mean_variance(pre_x_0, x_t, t)
    return model_mean, posterior_variance, posterior_log_variance
```
- From $x_t$ and $x_0$, with the coefficients built from $\bar{\alpha}_t$, $\beta_t$, and $\bar{\alpha}_{t-1}$, compute the mean $\mu_{t-1}$.
This lives in the first function:

```python
posterior_mean = (self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
                  + self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t)
```
- Apply the update formula to compute $x_{t-1}$ at step t-1, and repeat timesteps times down to $x_0$.
A single step:

```python
def p_sample(self, model, x_t: torch.FloatTensor, t: torch.LongTensor):
    # denoise step: sample x_{t-1} from x_t and the predicted noise
    model_mean, _, model_log_variance = self.p_mean_variance(model, x_t, t)  # predict mean and (log-)variance
    noise = torch.randn_like(x_t)
    nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1))))  # no noise when t == 0
    pred_img = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise  # compute x_{t-1}
    return pred_img
```
Starting at step T, i.e. from the constructed Gaussian noise $x_T$, the loop runs timesteps iterations (note the reversed order, from T down to 0):

```python
def sample(self, model: nn.Module, image_size, batch_size=8, channels=3):
    # denoise: reverse diffusion
    shape = (batch_size, channels, image_size, image_size)
    device = next(model.parameters()).device
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)  # x_T ~ N(0, 1)
    imgs = []
    for i in tqdm(reversed(range(0, self.timesteps)), desc='sampling loop time step', total=self.timesteps):
        t = torch.full((batch_size,), i, device=device, dtype=torch.long)
        img = self.p_sample(model, img, t)
        imgs.append(img.cpu().numpy())
    return imgs
```
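Sampling usage then looks like this (a sketch; `diffusion` and a trained `model` are assumed):

```python
imgs = diffusion.sample(model, image_size=28, batch_size=8, channels=1)
final = imgs[-1]  # the last element is the fully denoised x_0, shape [8, 1, 28, 28]
```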
3. DDIM
In short, DDIM is DDPM with skipped steps: it simplifies the sampling formula and speeds up sampling.
To keep the code simple, the parameter that mixes DDPM and DDIM sampling in the paper (i.e. the optional DDPM stochastic term) is dropped, and only the DDIM sampling parameters are kept.
3.1 The Denoising Equation
DDPM needs a fresh random noise $\epsilon$ at every step, i.e. stochastic SDE sampling. The injected randomness increases diversity, but the large number of timesteps (typically 100-1000) makes sampling slow:
$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\cdot \epsilon_\theta(x_t, t)\right) + \sigma_t\cdot\epsilon = \sqrt{\bar{\alpha}_{t-1}}\cdot x_0 + \sqrt{1 - \bar{\alpha}_{t-1}}\cdot \frac{x_t - \sqrt{\bar{\alpha}_t}\cdot x_0}{\sqrt{1 - \bar{\alpha}_t}} + \sigma_t\cdot\epsilon$$
DDIM injects no noise at each step: it is deterministic ODE-style sampling that supports step skipping, so the number of steps can be chosen freely; typically 10-50 steps already produce usable results:
$$x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\cdot x_0 + \sqrt{1 - \bar{\alpha}_{t-1}}\cdot \frac{x_t - \sqrt{\bar{\alpha}_t}\cdot x_0}{\sqrt{1 - \bar{\alpha}_t}}$$
where $x_0$ is the estimate computed from the model's predicted $\epsilon_\theta$ at step t.
3.2 The Denoising Process
3.2.1 Sampling Parameters
- $x_T \sim \mathcal{N}(0, I)$, random noise:

```python
shape = (batch_size, channels, image_size, image_size)
x_T = torch.randn(shape, device=self.betas.device)  # start from pure noise
```
- The step interval c, used to pick equally spaced skip indices from the original DDPM timesteps:

```python
c = ddpm_timesteps // ddim_timesteps
```
- Using c, build the index sequence $T$ (ddim_timestep_seq) and the sequence $T_{pre}$ (ddim_timestep_prev_seq).
$T_{pre}$ is $T$ with its last element removed and a first element (value 0) prepended:

```python
ddim_timestep_seq = torch.tensor(list(range(0, self.timesteps, c))) + 1  # shift by one so the first step scales toward the data
ddim_timestep_prev_seq = torch.cat((torch.tensor([0]), ddim_timestep_seq[:-1]))  # previous sequence
```
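For example, with timesteps = 300 and ddim_timesteps = 10, the two sequences come out as (values follow directly from the code above):

```python
c = 300 // 10  # 30
ddim_timestep_seq = torch.tensor(list(range(0, 300, c))) + 1
# tensor([  1,  31,  61,  91, 121, 151, 181, 211, 241, 271])
ddim_timestep_prev_seq = torch.cat((torch.tensor([0]), ddim_timestep_seq[:-1]))
# tensor([  0,   1,  31,  61,  91, 121, 151, 181, 211, 241])
```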
3.2.2 A Single Denoising Step
- Use $T$ and $T_{pre}$ to index steps t and t-1 of the denoising loop:

```python
t = torch.full((batch_size,), ddim_timestep_seq[i], device=x_T.device, dtype=torch.long)
next_t = torch.full((batch_size,), ddim_timestep_prev_seq[i], device=x_T.device, dtype=torch.long)
```
- Look up $\bar{\alpha}_t$ and $\bar{\alpha}_{t-1}$ in $\bar{\alpha}$:

```python
alpha_cumprod_t = self._extract(self.alphas_cumprod, t, x_T.shape)  # 1. get current and previous alpha_cumprod
alpha_cumprod_t_prev = self._extract(self.alphas_cumprod, next_t, x_T.shape)
```
- Run the U-Net model to get the predicted noise $\epsilon_\theta$; the first iteration feeds in $T$ and $x_T$:

```python
pred_noise = model(x_t, t)
```
- From the step-t $\epsilon_\theta$, estimate $x_0$,
using the same formula as DDPM:
$$x_0 = \frac{x_t}{\sqrt{\bar{\alpha}_t}} - \frac{\sqrt{1 - \bar{\alpha}_t}}{\sqrt{\bar{\alpha}_t}}\cdot \epsilon_\theta$$
Code:

```python
pred_x0 = (xs[-1] - torch.sqrt(1 - alpha_cumprod_t) * pred_noise) / torch.sqrt(alpha_cumprod_t)
pred_x0 = torch.clamp(pred_x0, min=-1., max=1.)  # 3. get the predicted x_0
```
- From $x_0$, $x_t$, $\bar{\alpha}_t$, and $\bar{\alpha}_{t-1}$, compute the denoised value $x_{t-1}$:

```python
pred_dir_xt = torch.sqrt(1 - alpha_cumprod_t_prev) * pred_noise     # 5. compute the "direction pointing to x_t" of formula (12)
x_t_pre = torch.sqrt(alpha_cumprod_t_prev) * pred_x0 + pred_dir_xt  # 6. compute x_{t-1} of formula (12)
```
3.3 Code
Looping the single denoising step completes the DDIM denoising process.
The full denoising code:
```python
# DDIM Inference / Reverse
def ddim_sample(self, model, image_size, ddim_timesteps=100, batch_size=8, channels=3):
    shape = (batch_size, channels, image_size, image_size)
    x_T = torch.randn(shape, device=self.betas.device)  # start from pure noise
    xs = [x_T]
    # make the ddim timestep sequences
    c = self.timesteps // ddim_timesteps
    ddim_timestep_seq = torch.tensor(list(range(0, self.timesteps, c))) + 1  # shift by one so the first step scales toward the data
    ddim_timestep_prev_seq = torch.cat((torch.tensor([0]), ddim_timestep_seq[:-1]))  # previous sequence
    for i in tqdm(reversed(range(0, ddim_timesteps)), desc='ddim sampling loop time step', total=ddim_timesteps):
        t = torch.full((batch_size,), ddim_timestep_seq[i], device=x_T.device, dtype=torch.long)
        next_t = torch.full((batch_size,), ddim_timestep_prev_seq[i], device=x_T.device, dtype=torch.long)
        alpha_cumprod_t = self._extract(self.alphas_cumprod, t, x_T.shape)  # 1. get current and previous alpha_cumprod
        alpha_cumprod_t_prev = self._extract(self.alphas_cumprod, next_t, x_T.shape)
        pred_noise = model(xs[-1], t)  # 2. predict the noise using the model
        pred_x0 = (xs[-1] - torch.sqrt(1 - alpha_cumprod_t) * pred_noise) / torch.sqrt(alpha_cumprod_t)
        pred_x0 = torch.clamp(pred_x0, min=-1., max=1.)  # 3. get the predicted x_0
        pred_dir_xt = torch.sqrt(1 - alpha_cumprod_t_prev) * pred_noise     # 5. compute "direction pointing to x_t" of formula (12)
        x_t_pre = torch.sqrt(alpha_cumprod_t_prev) * pred_x0 + pred_dir_xt  # 6. compute x_{t-1} of formula (12)
        xs.append(x_t_pre)
        # step 4 is omitted -- the variance "sigma_t(η)" of formula (16):
        # σ_t = sqrt((1 − α_bar_{t−1})/(1 − α_bar_t)) * sqrt(1 − α_bar_t/α_bar_{t−1})
    return xs
```
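Usage mirrors the DDPM sample call (a sketch; `diffusion` and `model` are assumed):

```python
xs = diffusion.ddim_sample(model, image_size=28, ddim_timesteps=50, batch_size=8, channels=1)
final = xs[-1]  # fully denoised images after only 50 model evaluations (vs. 300 for DDPM)
```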
4. Model Architecture
Most open-source U-Nets are fairly complex, containing many attention and ResNet blocks.
This article designs a simplified U-Net that keeps only the necessary parts:
4.1 Downsampling Blocks (Downsample)
4 blocks: each block consists of 1 conv layer, and every two conv layers complete one downsampling.
There are two kinds of conv layers, one preserving the feature shape and one downsampling:
```python
class Downsample(nn.Module):
    def __init__(self, channels, num_groups=32):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)
        self.num_groups = num_groups
    def forward(self, x):
        x = self.conv(x)  # strided conv (GroupNorm + activation are applied by the caller)
        return x

down_block1 = nn.Conv2d(io_channels, model_channels, kernel_size=3, padding=1)  # down blocks
down_block2 = Downsample(model_channels)
```
4.2 Middle Block (Middle)
- block
Only one block is used here, consisting of 2 conv layers with a ResNet-style residual connection:

```python
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.shortcut = (nn.Conv2d(in_channels, out_channels, kernel_size=1)
                         if in_channels != out_channels else nn.Identity())
    def forward(self, x):
        h = F.relu(F.group_norm(self.conv1(x), num_groups=32))  # conv + GroupNorm + activation
        h = F.relu(F.group_norm(self.conv2(h), num_groups=32))
        return h + self.shortcut(x)  # residual connection

middle_block = ResidualBlock(model_channels*2, model_channels*2)  # middle block
```
- Injecting the timestep information
When the U-Net takes the input $x$, the value of the corresponding timestep $t$ must be broadcast to every element of the features, otherwise generation fails.
Most open-source code injects t into every layer. To simplify the model, here $t$ is injected only into the middle layer,
by embedding t with trigonometric (sinusoidal) functions, as described next.
4.3 Timestep Embedding Module (time_embedding)
4.3.1 Code
Given t and dim as input, it outputs the embedding:

```python
def timestep_embedding(t, dim, max_period=10000):
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim // 2, dtype=torch.float32) / (dim // 2)).to(device=t.device)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    return embedding
```

Input: t.shape = [batch_size], dim = layer_channels.
Output: embedding.shape = [batch_size, layer_channels].
4.3.2 Worked Example
Assume the inputs:

```python
timesteps = torch.tensor([0, 1, 2, 3])
dim = 8
```
1. Compute half:
$$\text{half} = \frac{\text{dim}}{2} = \frac{8}{2} = 4$$
2. Compute the frequencies freqs:
$$\text{freqs}[i] = e^{-\log(\text{max\_period})\cdot\frac{i}{\text{half}}}$$
For half = 4 and max_period = 10000:
$$\text{freqs} = \left[e^{-\log(10000)\cdot 0/4},\ e^{-\log(10000)\cdot 1/4},\ e^{-\log(10000)\cdot 2/4},\ e^{-\log(10000)\cdot 3/4}\right] = [1.0,\ 0.1,\ 0.01,\ 0.001]$$
3. Compute args from timesteps and freqs:
$$\text{args}[i, j] = \text{timesteps}[i]\cdot\text{freqs}[j]$$

```python
args = torch.tensor([0, 1, 2, 3])[:, None] * torch.tensor([1.0, 0.1, 0.01, 0.001])[None, :]
```

Result:
$$\text{args} = \begin{bmatrix} 0.00 & 0.00 & 0.00 & 0.000 \\ 1.00 & 0.10 & 0.01 & 0.001 \\ 2.00 & 0.20 & 0.02 & 0.002 \\ 3.00 & 0.30 & 0.03 & 0.003 \end{bmatrix}$$
4. Compute the sine and cosine embeddings by applying cos and sin to args:
Cosine part:
$$\text{cos\_part} = \cos(\text{args}) = \begin{bmatrix} 1.0000 & 1.0000 & 1.0000 & 1.0000 \\ 0.5403 & 0.9950 & 0.9999 & 1.0000 \\ -0.4161 & 0.9801 & 0.9998 & 1.0000 \\ -0.9899 & 0.9553 & 0.9996 & 1.0000 \end{bmatrix}$$
Sine part:
$$\text{sin\_part} = \sin(\text{args}) = \begin{bmatrix} 0.0000 & 0.0000 & 0.0000 & 0.0000 \\ 0.8415 & 0.0998 & 0.0100 & 0.0010 \\ 0.9093 & 0.1987 & 0.0200 & 0.0020 \\ 0.1411 & 0.2955 & 0.0300 & 0.0030 \end{bmatrix}$$
5. Concatenate cos_part and sin_part along the last dimension:
$$\text{embedding} = [\text{cos\_part},\ \text{sin\_part}]$$
6. Output:
$$\text{embedding} = \begin{bmatrix} 1.0000 & 1.0000 & 1.0000 & 1.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 \\ 0.5403 & 0.9950 & 0.9999 & 1.0000 & 0.8415 & 0.0998 & 0.0100 & 0.0010 \\ -0.4161 & 0.9801 & 0.9998 & 1.0000 & 0.9093 & 0.1987 & 0.0200 & 0.0020 \\ -0.9899 & 0.9553 & 0.9996 & 1.0000 & 0.1411 & 0.2955 & 0.0300 & 0.0030 \end{bmatrix}$$
To summarize the layer's time_embedding operation:
via trigonometric functions, it outputs, for step t, the two families of values $\cos(\alpha t)$ and $\sin(\alpha t)$ at multiple scales, with $\alpha \in \{1, 0.1, 0.01, 0.001, \dots\}$ (the number of scales is determined by the layer's channel count).
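The worked example can be checked directly against the function from 4.3.1 (assuming it is already defined):

```python
import math
import torch

t = torch.tensor([0, 1, 2, 3])
emb = timestep_embedding(t, dim=8)
print(emb.shape)  # torch.Size([4, 8])
print(emb[1])     # tensor([0.5403, 0.9950, 0.9999, 1.0000, 0.8415, 0.0998, 0.0100, 0.0010])
```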
4.3.3 Injecting the Embedding
The timestep embedding t_embedding is broadcast onto the feature values of x.
Since the middle layer's channel count is twice the base channel count, a fully connected layer (noise_embedding) maps the t_embedding features to twice the width:
```python
noise_embedding = nn.Linear(model_channels, model_channels*2)  # noise block

middle = middle_block(x)  # middle block
noise_t = F.relu(self.noise_embedding(timestep_embedding(t, self.model_channels)))
middle = middle + noise_t[:, :, None, None]
```
4.4 Upsampling Blocks and Summary
The upsampling blocks mirror the downsampling blocks. The overall U-Net structure is summarized as follows:
- It implements the U-Net core: each downsampling block's features are skip-connected to the equally sized features of the upsampling path
- Only 1 middle feature block (middle block), and time_embedding and the ResNet structure are used only in that block
- No attention blocks are used
The full code:
```python
class Upsample(nn.Module):
    def __init__(self, channels, num_groups=32):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.num_groups = num_groups
    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest")  # upsample
        x = self.conv(x)  # conv (GroupNorm + activation are applied by the caller)
        return x

class Downsample(nn.Module):
    def __init__(self, channels, num_groups=32):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)
        self.num_groups = num_groups
    def forward(self, x):
        x = self.conv(x)  # strided conv (GroupNorm + activation are applied by the caller)
        return x

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.shortcut = (nn.Conv2d(in_channels, out_channels, kernel_size=1)
                         if in_channels != out_channels else nn.Identity())
    def forward(self, x):
        h = F.relu(F.group_norm(self.conv1(x), num_groups=32))  # conv + GroupNorm + activation
        h = F.relu(F.group_norm(self.conv2(h), num_groups=32))
        return h + self.shortcut(x)  # residual connection

def timestep_embedding(t, dim, max_period=10000):
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim // 2, dtype=torch.float32) / (dim // 2)).to(device=t.device)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    return embedding

class UNetModel(nn.Module):
    def __init__(self, io_channels=3, model_channels=128):
        super().__init__()
        self.model_channels = model_channels
        # down blocks
        self.down_block1 = nn.Conv2d(io_channels, model_channels, kernel_size=3, padding=1)
        self.down_block2 = Downsample(model_channels)
        self.down_block3 = nn.Conv2d(model_channels, model_channels*2, kernel_size=3, padding=1)
        self.down_block4 = Downsample(model_channels*2)
        # middle block
        self.middle_block = ResidualBlock(model_channels*2, model_channels*2)
        self.noise_embedding = nn.Linear(model_channels, model_channels*2)  # noise block
        # up blocks
        self.up_block1 = Upsample(model_channels*2)
        self.up_block2 = nn.Conv2d(model_channels*2, model_channels, kernel_size=3, padding=1)
        self.up_block3 = Upsample(model_channels)
        self.up_block4 = nn.Conv2d(model_channels, io_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        # Encode: downsampling
        x1 = F.relu(F.group_norm(self.down_block1(x), num_groups=32))
        x2 = F.relu(F.group_norm(self.down_block2(x1), num_groups=32))
        x3 = F.relu(F.group_norm(self.down_block3(x2), num_groups=32))
        x4 = F.relu(F.group_norm(self.down_block4(x3), num_groups=32))
        # Middle block with timestep injection
        middle = self.middle_block(x4)
        noise_t = F.relu(self.noise_embedding(timestep_embedding(t, self.model_channels)))
        middle = middle + noise_t[:, :, None, None]
        # Decode: upsampling with skip connections
        x5 = F.relu(F.group_norm(self.up_block1(middle + x4), num_groups=32))
        x6 = F.relu(F.group_norm(self.up_block2(x5 + x3), num_groups=32))
        x7 = F.relu(F.group_norm(self.up_block3(x6 + x2), num_groups=32))
        out = self.up_block4(x7 + x1)
        return out
```
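A quick shape check (a sketch; MNIST-like shapes assumed):

```python
model = UNetModel(io_channels=1, model_channels=128)
x = torch.randn(8, 1, 28, 28)
t = torch.randint(0, 300, (8,)).long()
print(model(x, t).shape)  # torch.Size([8, 1, 28, 28])
```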
5. Conditional Generation (Classifier-Free)
Conditional generation already had several realizations in GANs, including CGAN, ACGAN/PCGAN, and InfoGAN, which generate content matching a semantic label; all of them take the label as a training input.
Broadly there are three approaches:
- Strongly supervised: the model also outputs a label prediction and is trained with an added multi-label classification loss
- Weakly supervised: the label is only added as an input feature, embedded into the image features (by addition), without changing the training procedure (no extra outputs, no extra losses)
- Unsupervised: pseudo labels, generated automatically from some feature regularity, are fed in, usually together with a matching loss (classification, clustering, contrastive learning, etc.)
Classifier-free guidance uses the weakly supervised approach: the label is simply embedded as a feature into the image $x$. Here the simplest scheme is used, analogous to the t embedding: inject label_embedding into the U-Net middle layer
and broadcast it to every element, i.e.:
x = x + t_embedding + label_embedding
Since labels form a discrete sequence, an embedding layer is used here instead of a fully connected one. The code:
```python
class UNetModel(nn.Module):
    def __init__(self, ...):
        ...
        self.down_block4 = Downsample(model_channels*2)
        self.middle_block = ResidualBlock(model_channels*2, model_channels*2)  # middle block
        self.noise_embedding = nn.Linear(model_channels, model_channels*2)     # noise block
        self.class_emb = nn.Embedding(class_num, model_channels*2)             # label embedding
        self.up_block1 = Upsample(model_channels*2)                            # up blocks
        ...

    def forward(self, x, t, label=None):
        ...
        x4 = F.relu(F.group_norm(self.down_block4(x3), num_groups=32))
        middle = self.middle_block(x4)  # middle block
        noise_t = F.relu(self.noise_embedding(timestep_embedding(t, self.model_channels)))
        c_emb = F.relu(self.class_emb(label))
        middle = middle + noise_t[:, :, None, None] + c_emb[:, :, None, None]
        x5 = F.relu(F.group_norm(self.up_block1(middle + x4), num_groups=32))  # Decode: upsampling
        ...
        out = self.up_block4(x7 + x1)
        return out
```
The c_emb in the code is the embedded label of the image.
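At sampling time, the desired label is passed along with t (a sketch; a label-aware variant of ddim_sample that forwards label to model(x_t, t, label) is assumed, it is not shown above):

```python
label = torch.full((8,), 3, dtype=torch.long, device=device)  # e.g. generate eight samples of class 3
xs = diffusion.ddim_sample(model, image_size=28, ddim_timesteps=10,
                           batch_size=8, channels=1, label=label)
```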
6. Experimental Results
Three datasets were tested: MNIST, Fashion-MNIST, and Cifar-10.
6.1 Settings
- MNIST, Fashion-MNIST
  - epoch = 200
  - timesteps = 300
- Cifar-10
  - epoch = 500
  - timesteps = 1000
  - For Cifar-10 the model was made deeper (2 extra conv layers each in the downsampling, middle, and upsampling parts)
6.2 DDPM
- MNIST
- Fashion-MNIST
- Cifar-10
Cifar-10 classes differ widely in their features (airplanes, frogs, cars, ...), and the images are not sharp (32x32), so synthesizing Cifar-10 has long been a challenge for GANs; handling it well is one of the highlights of diffusion models.
Since the model in this article is small, extra layers were added for training Cifar-10; the results at 450 epochs are shown in the figure.
(The model could be improved further for better results, e.g. by adding conv or attention layers.)
6.3 DDIM
DDIM was set to only 50 steps here; experiments show that as few as 10 steps can already produce results, though not as good as DDPM's.
Fashion-MNIST results are shown below:
6.4 Classifier-Free
Below are the label-conditioned generation results (DDIM, 10 denoising steps) after applying the classifier-free code changes:
7. References
7.1 Code for This Article
- DDPM (under 300 lines):
https://github.com/disanda/GM/blob/main/DDPM-DDIM-ClassifierFree/ddpm.py
- Pretrained models (MNIST and FashionMNIST, 300 timesteps):
https://github.com/disanda/GM/tree/main/DDPM-DDIM-ClassifierFree/pre-trained-models
- Conditional generation and the deeper Cifar network:
https://github.com/disanda/GM/tree/main/DDPM-DDIM-ClassifierFree
7.2 Reference Code
- https://github.com/ermongroup/ddim/blob/main/functions/denoising.py
- https://github.com/LinXueyuanStdio/PyTorch-DDPM
- https://github.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm
- https://github.com/BastianChen/ddpm-demo-pytorch
- https://github.com/tatakai1/classifier_free_ddim/blob/main/Classifier_Free_DDIM_Mnist.ipynb
7.3 Explanations on Zhihu
- https://zhuanlan.zhihu.com/p/666552214
- https://zhuanlan.zhihu.com/p/656757576
7.4 Original Papers
- https://arxiv.org/abs/2006.11239, Denoising Diffusion Probabilistic Models, 2020
- https://arxiv.org/abs/2102.09672, Improved Denoising Diffusion Probabilistic Models, 2021
- https://arxiv.org/abs/2010.02502, Denoising Diffusion Implicit Models, 2022
- https://arxiv.org/abs/2207.12598, Classifier-Free Diffusion Guidance, 2022