扩散模型系列笔记(一)——DDPM

图1 扩散模型的直观理解

直观理解

扩散模型分为前向过程(扩散过程,Data → \to Noise)和后向过程(生成过程或逆扩散过程,Noise → \to Data)。在前向过程中,对于每一个观测样本,不断向样本中添加少量噪声,直到该样本完全被摧毁成为完全随机的噪声,该噪声属于正态分布;从正态分布中随机采样一个噪声样本,不断的对该样本去除少量的噪声,直到该噪声样本成为像真实样本的生成样本。前向过程的噪声是事先设置好的,该噪声还被用于神经网络训练,然后使用神经网络生成后向过程的噪声。
图2 扩散模型的有向图模型

生成模型

给定从一个感兴趣的分布中观察到的样本 x x x(比如图像数据集),生成模型的目标是学习建模其真实的数据分布 p ( x ) p(x) p(x)。一旦学习到,我们就可以从我们的近似模型中生成新的样本。
我们可以认为我们观察到的数据 x x x是由相关的未见的潜在变量表示或生成的,我们可以用随机变量 z z z来表示这种潜在变量。表达这一思想的最佳直觉来源于柏拉图的《洞穴寓言》。在这个寓言中,一群人终生被锁链困在洞穴中,他们只能看到面前墙上投射的二维影子,而这些影子是由经过火前的未见的三维物体生成的。对于这些人来说,他们所观察到的一切实际上都是由他们永远无法看到的高维抽象概念决定的。
类似地,我们在现实世界中遇到的物体也可能是一些更高层次表征的函数生成的;例如,这些表征可能包含抽象属性,如颜色、大小、形状等。然后,我们所观察到的可以被解释为这些抽象概念的三维投影或实例化,就像洞穴中的人们观察到的实际上是三维物体的二维投影一样。虽然洞穴中的人们永远无法看到(甚至完全理解)隐藏的物体,但他们仍然可以推理并得出关于这些物体的结论;同样地,我们也可以近似描述我们所观察到的数据的潜在表征。
尽管柏拉图的洞穴寓言阐明了潜在变量作为决定观测结果的潜在不可观察表征的概念,但这个类比的一个警告是,在生成建模中,我们通常寻求学习低维潜在表征,而不是高维表征。这是因为在没有强先验知识的情况下,尝试学习比观测值更高维度的表征是一项徒劳的努力。另一方面,学习低维潜在表征也可以视为一种压缩形式,并且有可能揭示描述观测结果的语义上有意义的结构。
在数学上,我们可以想象潜在的变量和我们观察到的数据是由一个联合分布 p ( x , z ) p(x,z) p(x,z)建模的。这个过程可以强调数据生成过程 p ( x , z ) = p ( z ) p ( x ∣ z ) p(x,z)=p(z)p(x|z) p(x,z)=p(z)p(xz)。即从潜在变量分布中 p ( z ) p(z) p(z)采样潜在变量 z z z p ( z ) p(z) p(z)是一个数学上明确定义的随机变量,比如高斯分布、均匀分布等;对于潜在变量 z z z,根据条件分布 p ( x ∣ z ) p(x|z) p(xz)生成观测样本,在变分自编码器中,条件分布 p ( x ∣ z ) p(x∣z) p(xz)通常由一个神经网络(解码器)进行参数化,在生成对抗模型中,条件分布 p ( x ∣ z ) p(x∣z) p(xz)由生成器 G ( z ) G(z) G(z) 定义。
目前生成模型可以根据他们概率分布的方式分为两类,显式生成模型和隐式生成模型。

显式生成模型

通过(近似)最大似然直接学习分布的概率密度(或质量)函数(即学习一个为观察到的数据样本分配高可能性的模型),所以显式生成模型又被称为基于似然的模型。典型的显式生成模型包括自回归模型、标准化流模型、能量基模型、高斯混合模型、自回归模型、变分自编码器、隐马尔可夫模型等 ,本文的扩散模型也属于这一类。
图3 贝叶斯网络、马尔可夫随机场 (MRF)、自回归模型和正则化流模型都是基于似然的模型示例。所有这些模型都表示分布的概率密度或质量函数。

隐式生成模型

概率分布由其采样过程的模型隐式表示。最突出的例子是生成对抗网络(GAN),通过使用神经网络变换随机高斯向量来合成来自数据分布的新样本。
图4 生成对抗网络

数学理解

首先明确一点: x 0 x_0 x0是可观测的干净数据,我们有其样本(比如我们数据集中的图片、音频等),但是我们不知道其真实分布 p ( x 0 ) p(x_0) p(x0) x T x_T xT是纯噪声数据,其先验分布为标准正态分布 N ( x T ; 0 , I ) \mathcal{N}(x_T;\mathrm{0},\mathrm{I}) N(xT;0,I) x 1 , x 2 , . . . , x T − 1 x_1,x_2,...,x_{T-1} x1,x2,...,xT1为带噪声的数据。 x 1 , x 2 , . . . , x T x_{1},x_{2},...,x_{T} x1,x2,...,xT都是潜在变量。如图2所示。注意 p ( x t − 1 ∣ x t ) p(x_{t-1}|x_t) p(xt1xt) p θ ( x t − 1 ∣ x t ) p_\theta(x_{t-1}|x_t) pθ(xt1xt)相同。

扩散过程

扩散模型在正向扩散过程中,通过逐渐向数据样本中加入噪声,使其最终转化为近似标准正态分布的样本。这个过程通常被描述为一个马尔可夫链。对于扩散过程我们可扩散模型在正向扩散过程中,通过逐渐向数据样本中加入噪声,使其最终转化为近似标准正态分布的样本。这个过程通常被描述为一个马尔可夫链以表示为联合分布 q ( x 1 : T ∣ x 0 ) q(x_{1:T}|x_0) q(x1:Tx0),如公式(1)所示,表示了给定观测数据 x 0 x_0 x0,生成潜在变量 x 1 , x 2 , . . . , x T x_1,x_2,...,x_T x1,x2,...,xT的过程。其中潜在变量 x t ∼ q ( x t ∣ x t − 1 ) x_t \sim q(x_t|x_{t-1}) xtq(xtxt1)如公式(2)所示,也被称为转移核, β 1 , β 2 , . . . , β T \beta_1,\beta_2, ..., \beta_T β1,β2,...,βT是噪声方差表,是人工设置的固定值。所以知道了噪声数据 x t − 1 x_{t-1} xt1和方差表,公式2的分布就是已知的,所以就能通过采样得到噪声数据 x t x_t xt。这里 x t x_t xt的获取过程与我们直观理解“向样本中加入噪声”似乎不太一样。但是根据重采样技巧,我们可以继续将公式(2)写作 x t = 1 − β t x t − 1 + β t ϵ t x_t = \sqrt{1-\beta_t}x_{t-1}+\sqrt{\beta_t}\epsilon_t xt=1βt xt1+βt ϵt,其中 ϵ t ∼ N ( 0 , I ) \epsilon_t \sim \mathcal{N}(\mathrm{0},\mathrm{I}) ϵtN(0,I)。这样就可以很好的理解加噪步骤,首先从标准正态分布中采样随机噪声 ϵ t \epsilon_t ϵt,然后对于样本 x t − 1 x_{t-1} xt1加入噪声样本 ( 1 − β t − 1 ) x t − 1 + β t ϵ t (\sqrt{1-\beta_t}-1)x_{t-1}+\sqrt{\beta_t}\epsilon_t (1βt 1)xt1+βt ϵt得到样本 x t x_t xt。这与我们直观理解的加噪过程是符合的,然而这仍然不是最终代码实现的方式。可以进一步将 x t x_t xt 表示为初始数据样本 x 0 x_0 x0 和一系列独立同分布的高斯噪声之和的形式。具体如公式(3)或(4)。其中: α t = 1 − β t \alpha_t = 1 - \beta_t αt=1βt α ˉ t = ∏ i = 1 t α i \bar{\alpha}_t = \prod_{i=1}^t \alpha_i αˉt=i=1tαi是累积乘积、 ϵ ∼ N ( 0 , I ) \epsilon \sim \mathcal{N}(0, I) ϵN(0,I)是服从标准正态分布的高斯噪声。这个公式表示了在第 t t t 步直接从初始数据样本 x 0 x_0 x0 生成 x t x_t xt的过程。推导过程如下图所示。
q ( x 1 : T ∣ x 0 ) = ∏ i = 1 T q ( x t ∣ x t − 1 ) q(x_{1:T}|x_0)={\textstyle \prod_{i=1}^{T}}q(x_t|x_{t-1}) q(x1:Tx0)=i=1Tq(xtxt1) (1)
q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(x_t|x_{t-1})=\mathcal{N}(x_t;\sqrt{1-\beta_t}x_{t-1},\beta_t\mathrm{I}) q(xtxt1)=N(xt;1βt xt1,βtI) (2)
x t = α ˉ t x 0 + 1 − α ˉ t ϵ x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon xt=αˉt x0+1αˉt ϵ (3)
q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) q(x_t|x_0)=\mathcal{N}(x_t;\sqrt{\bar{\alpha}_t}x_0,(1-\bar{\alpha}_t)\mathrm {I}) q(xtx0)=N(xt;αˉt x0,(1αˉt)I) (4)
image.png

生成过程

对于生成过程我们可以表示为联合分布 p ( x 0 : T ) p(x_{0:T}) p(x0:T)。如公式(5)所示, p ( x T ) = N ( x T ; 0 , I ) p(x_T)=\mathcal{N}(x_T;\mathrm{0},\mathrm{I}) p(xT)=N(xT;0,I),从标准正态分布不断生成 x T , x T − 1 , . . . , x t , . . , x 2 , x 1 , x 0 x_T,x_{T-1},...,x_t,..,x_2,x_1,x_0 xT,xT1,...,xt,..,x2,x1,x0的过程,首先从 p ( x T ) p({x_T}) p(xT)采样一个随机噪声 X T X_T XT,然后生成的样本和中间样本如公式(6)所示, θ \theta θ表示神经网络的参数, x t x_t xt t t t是神经网络的输入, μ \mu μ Σ \Sigma Σ是模型的输出。但是我们并不知道 p ( x t − 1 ∣ x t ) p(x_{t-1}|x_t) p(xt1xt),又应该如何训练神经网络呢?这可以通过目标函数的推导得到。
p ( x 0 : T ) = p ( x T ) ∏ t = 1 T p ( x t − 1 ∣ x t ) p(x_{0:T})= p(x_T){\textstyle \prod_{t=1}^{T}}p(x_t-1|x_t) p(x0:T)=p(xT)t=1Tp(xt1∣xt) (5)
p ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p(x_{t-1}|x_t)=\mathcal{N}(x_{t-1};\mu_\theta(x_t,t),\Sigma_\theta(x_t,t)) p(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t)) (6)

目标函数

扩散模型也是通过(近似)最大似然直接学习分布的概率密度(或质量)函数,所以扩散模型是显式生成模型。
如图5所示, log ⁡ p ( x ) \log p(x) logp(x)推导过程如下,可以分解为三个子项,所以最大化 log ⁡ p ( x ) \log p(x) logp(x)等价于对三个子项的处理。。第四行采用的詹森不等式变换,并且右侧式子就是ELBO(证据下界)。

  1. reconstruction term: E q ( x 1 ∣ x 0 ) [ log ⁡ p θ ( x 0 ∣ x 1 ) ] ≈ ∑ i = 1 n log ⁡ p θ ( x 0 i ∣ x 1 i ) \mathbb{E}_{q(x_1|x_0)}[\log p_\theta(x_0|x_1)] \approx \sum_{i=1}^{n}\log p_\theta(x_0^i|x_1^i) Eq(x1x0)[logpθ(x0x1)]i=1nlogpθ(x0ix1i)。右侧为采样的期望,n表示样本量。 p θ ( x 0 ∣ x 1 ) p_\theta(x_0|x_1) pθ(x0x1)为预测最终的生成样本 x 0 x_0 x0,所以被称为重建项,需要最大化该项。
  2. prior matching term: E q ( x T − 1 ∣ x 0 ) [ D K L ( q ( x T ∣ x T − 1 ) ∣ ∣ p ( x T ) ) ] \mathbb{E}_{q(x_{T-1}|x_0)}[D_{KL}(q(x_T|x_{T-1})||p(x_T))] Eq(xT1x0)[DKL(q(xTxT1)∣∣p(xT))]是先验匹配项。当最终的潜在分布与高斯先验相匹配时,它被最小化。这个项不需要优化,因为它没有可训练的参数;此外,由于我们假设一个足够大的 T T T,最终分布是高斯分布,这个项有效地变成零。
  3. consistency term: E q ( x t − 1 , x t + 1 ∣ x 0 ) [ D K L ( q ( x t ∣ x t − 1 ) ∣ ∣ p θ ( x t ∣ x t + 1 ) ) ] \mathbb{E}_{q(x_{t-1},x_{t+1}|x_0)}[D_{KL}(q(x_t|x_{t-1})||p_\theta(x_t|x_{t+1}))] Eq(xt1,xt+1x0)[DKL(q(xtxt1)∣∣pθ(xtxt+1))]是一致性项,需要最小化。它努力使得在 x t x_t xt处的分布与前向和后向过程都一致。也就是说,从一个噪声较大的图像去噪步骤应该与一个噪声较小的图像的加噪步骤相匹配,对于每一个中间时间步都是如此;这在数学上通过KL散度反映出来。当我们训练 p θ ( x t ∣ x t + 1 ) p_{\theta}(x_t|x_{t+1}) pθ(xtxt+1) 以匹配高斯分布 q ( x t ∣ x t − 1 ) q(x_t|x_{t−1}) q(xtxt1) 时,这一项被最小化,对应图5中绿色箭头和粉色箭头的最小化。可以通过蒙特卡罗估计和重参数化技巧后,再通过梯度下降进行优化,但是最终的代码也不是这样实现的。

图5  目标函数的推导

图6 扩散模型的有向图模型
在这个推导下,ELBO(Evidence Lower Bound)的所有项都是作为期望值计算的,因此可以使用蒙特卡洛估计来近似。然而,实际上使用我们刚刚推导出的项来优化ELBO可能是次优的;因为一致性项是作为每个时间步两个随机变量 x t − 1 , x t + 1 {x_{t−1}, x_{t+1}} xt1,xt+1的期望值计算的,其蒙特卡洛估计的方差可能比每个时间步只使用一个随机变量估计的项要高。由于它是通过求和 T − 1 T−1 T1个一致性项计算的,对于较大的 T T T值,ELBO的最终估计值可能会有高方差。
如图8所示,我们如果最小化绿色和粉色箭头,这样每个项都被计算一次只对一个随机变量的期望。首先我们进行一个变换,如公式(7)所示。
q ( x t ∣ x t − 1 , x 0 ) = q ( x t − 1 ∣ x t , x 0 ) q ( x t ∣ x 0 ) q ( x t − 1 ∣ x 0 ) q(x_t|x_{t-1},x_0)=\frac{q(x_{t-1}|x_t,x_0)q(x_t|x_0)}{q(x_{t-1}|x_0)} q(xtxt1,x0)=q(xt1x0)q(xt1xt,x0)q(xtx0) (7)
如图7所示, log ⁡ p ( x ) \log p(x) logp(x)另外一个推导过程如下,也可以分解为三个子项。
图7 目标函数的推导
因此,我们成功地推导出一种可以用较低方差估计ELBO的解释,因为每一项最多是作为一个随机变量的期望值计算的。这种形式也有一个优雅的解释,当我们检查每个单独的项时,这种解释就会显现出来:重构项和先验匹配项和前面介绍的意义,虽然先验匹配项公式不相同。
对于去噪匹配项(denoising matching term), E q ( x t ∣ x 0 ) [ D K L ( q ( x t − 1 ∣ x t , x 0 ) ∣ ∣ p θ ( x t − 1 ∣ x t ) ) ] \mathbb{E}_{q(x_t|x_0)}[D_{KL}(q(x_{t-1}|x_t,x_0)||p_\theta(x_{t-1}|x_t))] Eq(xtx0)[DKL(q(xt1xt,x0)∣∣pθ(xt1xt))]。我们学习期望的去噪转移步骤 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1} | x_{t}) pθ(xt1xt),作为可处理的真实去噪转移步骤 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1} | x_{t}, x_{0}) q(xt1xt,x0)的近似。这个 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1} | x_{t}, x_{0}) q(xt1xt,x0) 转移步骤可以作为一个真实信号,因为它定义了在知道最终完全去噪的图像 x 0 x_{0} x0的情况下,如何对一个噪声图像 x t x_{t} xt进行去噪。因此,当两个去噪步骤的匹配尽可能接近时,这一项是最小化的,这通过它们的KL散度(KL Divergence)来衡量。

图8 扩散模型的有向图模型
q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1} | x_{t}, x_{0}) q(xt1xt,x0)的具体数学表达可以由下图9的公式推理出来:
图9  公式推导
因此,我们已经证明,在每个步骤中, x t − 1 ∼ q ( x t − 1 ∣ x t , x 0 ) x_{t-1} \sim q(x_{t-1}|x_t, x_0) xt1q(xt1xt,x0) 是正态分布的,其均值 μ q ( x t , x 0 ) \mu_q(x_t, x_0) μq(xt,x0) x t x_t xt x 0 x_0 x0 的函数,而方差 Σ q ( t ) \Sigma_q(t) Σq(t) α \alpha α 系数的函数。这些 α \alpha α 系数在每个时间步是已知的并且固定;它们要么在模型中作为超参数永久设定,要么被视为当前神经网络推理输出来建模它们。我们可以将方差方程重写为 Σ q ( t ) = σ q 2 ( t ) I \Sigma_q(t) = \sigma_q^2(t)I Σq(t)=σq2(t)I,其中: σ q 2 ( t ) = ( 1 − α t ) ( 1 − α ˉ t − 1 ) 1 − α ˉ t \sigma_q^2(t) = \frac{(1 - \alpha_t)(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} σq2(t)=1αˉt(1αt)(1αˉt1)
为了尽可能地将近似去噪转换步骤 p θ ( x t − 1 ∣ x t ) p_\theta(x_{t-1}|x_t) pθ(xt1xt) 与真实去噪转换步骤 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t, x_0) q(xt1xt,x0) 匹配得更为接近,我们也可以将 p θ ( x t − 1 ∣ x t ) p_\theta(x_{t-1}|x_t) pθ(xt1xt)建模为高斯分布。此外,由于所有的 α \alpha α 项在每个时间步都是已知的固定值,我们可以立即构建近似去噪转换步骤的方差,使其也为 Σ q ( t ) = σ q 2 ( t ) I \Sigma_q(t) = \sigma_q^2(t)I Σq(t)=σq2(t)I。然而,由于 p θ ( x t − 1 ∣ x t ) p_\theta(x_{t-1}|x_t) pθ(xt1xt) 不以 x 0 x_0 x0 为条件,我们必须将它的均值 μ θ ( x t , t ) \mu_\theta(x_t, t) μθ(xt,t) 参数化为 x t x_t xt 的函数。
然后如图10公式推导,可以得到一个时间步的损失函数。
图10 公式推导
我们继续参数化。根据公式3变换可得公式7如下:
x 0 = x t − 1 − α t ˉ ϵ 0 α ˉ t x_0=\frac{x_t-\sqrt{1-\bar{\alpha_t}}\epsilon_0}{\sqrt{\bar{\alpha}_t}} x0=αˉt xt1αtˉ ϵ0 (7)

我们可以通过公式(7)消去 μ q \mu_q μq x 0 x_0 x0,公式推导如图11所示:
图11 公式推导
因此,我们可以将我们的近似去噪转移均值 µ θ ( x t , t ) µ_θ(x_t, t) µθ(xt,t) 设置为:
μ θ ( x t , t ) = 1 α t x t − 1 − α t 1 − α ˉ t α t ϵ ^ θ ( x t , t ) \mu_\theta(x_t,t)=\frac{1}{\sqrt{\alpha}_t}x_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}\sqrt{\alpha_t} }\hat{\epsilon}_\theta(x_t,t) μθ(xt,t)=α t1xt1αˉt αt 1αtϵ^θ(xt,t) (8)
相应的优化问题变为:
图12 公式推导
所以,我们最终的损失函数为:
L t − 1 = arg ⁡ min ⁡ θ E x 0 , ϵ 1 2 σ q 2 ( t ) ( 1 − α t ) 2 ( 1 − α ˉ t ) α t [ ∣ ∣ ϵ 0 − ϵ ^ θ ( x t , t ) ∣ ∣ 2 2 ] , t ∼ U ( 2 , T ) L_{t-1} = \arg \min_\theta \mathbb{E}_{x_0,\epsilon}\frac{1}{2\sigma^2_q(t)}\frac{(1-\alpha_t)^2}{(1-\bar{\alpha}_t)\alpha_t}[||\epsilon_0-\hat{\epsilon}_\theta(x_t,t)||^2_2],t \sim U(2,T) Lt1=argminθEx0,ϵ2σq2(t)1(1αˉt)αt(1αt)2[∣∣ϵ0ϵ^θ(xt,t)22],tU(2,T) (9)。
对于重建项 L 0 = arg ⁡ min ⁡ θ − E q log ⁡ p θ ( x 0 ∣ x 1 ) L_0 = \arg \min_\theta -\mathbb{E}_q \log p_\theta(x_0|x_1) L0=argminθEqlogpθ(x0x1)。在DDPM的原论文中,将这个生成过程的最后一项设置为了一个独立的离散编码器,由高斯分布 N ( x 0 ; μ θ ( x 1 , 1 ) , σ 2 I ) \mathcal{N}(x_0;\mu_\theta(x_1,1),\sigma^2 \mathrm{I}) N(x0;μθ(x1,1),σ2I)构成。具体而言,它将连续的高斯分布通过积分变换成离散的概率分布,用于计算给定条件 x 1 x_1 x1 下观测值 x 0 x_0 x0 的概率密度。展开 p ( x 0 ∣ x 1 ) p(x_0|x_1) p(x0x1)如公式10所示,忽略常数和权重,我们只关注 ∣ ∣ x 0 − μ θ ( x 1 , 1 ) ∣ ∣ 2 ||x_0-\mu_\theta(x_1,1)||^2 ∣∣x0μθ(x1,1)2
− log ⁡ p ( x 0 ∣ x 1 ) = − log ⁡ 1 2 π σ 1 2 + ∣ ∣ x 0 − μ θ ( x 1 , 1 ) ∣ ∣ 2 2 σ 1 2 -\log p(x_0|x_1)=-\log \frac{1}{\sqrt{2\pi} \sigma_1^2} + \frac{||x_0-\mu_\theta(x_1,1)||^2}{2\sigma_1^2} logp(x0x1)=log2π σ121+2σ12∣∣x0μθ(x1,1)2 (10)
和之前的思路一样,消去 x 0 x_0 x0和重参数化,所以根据公式(7)得到 x 0 = x 1 − 1 − α 1 ϵ 0 α 1 x_0 = \frac{x_1-\sqrt{1-\alpha_1}\epsilon_0}{\sqrt{\alpha_1}} x0=α1 x11α1 ϵ0,根据公式(8)可以得到 μ θ ( x 1 , 1 ) = 1 α 1 x 1 − 1 − α 1 1 − α 1 α 1 ϵ ^ θ ( x 1 , 1 ) \mu_\theta(x_1,1)=\frac{1}{\sqrt{\alpha}_1}x_1-\frac{1-\alpha_1}{\sqrt{1-\alpha}_1\sqrt{\alpha_1}}\hat{\epsilon}_\theta(x_1,1) μθ(x1,1)=α 11x11α 1α1 1α1ϵ^θ(x1,1),带入 ∣ ∣ x 0 − μ θ ( x 1 , 1 ) ∣ ∣ 2 ||x_0-\mu_\theta(x_1,1)||^2 ∣∣x0μθ(x1,1)2得到:
∣ ∣ x 0 − μ θ ( x 1 , 1 ) ∣ ∣ 2 = 1 − α 1 α 1 ∣ ∣ ϵ 0 − ϵ ^ θ ( x 1 , 1 ) ∣ ∣ … … 2 ||x_0-\mu_\theta(x_1,1)||^2 = \frac{\sqrt{1-\alpha_1}}{\sqrt{\alpha_1}}||\epsilon_0-\hat{\epsilon}_\theta(x_1,1)||……2 ∣∣x0μθ(x1,1)2=α1 1α1 ∣∣ϵ0ϵ^θ(x1,1)∣∣……2 (11)
忽略公式(9)和公式(11)的权重系数,我们可以得到最终化简的损失函数如公式(12)所示,且t取值为1到T:
L s i m p l e ( θ ) = E t , x 0 , ϵ [ ∣ ∣ ϵ − ϵ θ ( x t , t ) ∣ ∣ 2 ] = E t , x 0 , ϵ [ ∣ ∣ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∣ ∣ 2 ] L_{simple}(\theta) = \mathbb{E}_{t,x_0,\epsilon}[||\epsilon-\epsilon_\theta(\sqrt{x_t},t )||^2] = \mathbb{E}_{t,x_0,\epsilon}[||\epsilon-\epsilon_\theta(\sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon,t )||^2] Lsimple(θ)=Et,x0,ϵ[∣∣ϵϵθ(xt ,t)2]=Et,x0,ϵ[∣∣ϵϵθ(αˉt x0+1αˉt ϵ,t)2] (12)
这里, ϵ ^ θ ( x t , t ) \hat{\epsilon}_\theta(x_t,t) ϵ^θ(xt,t)是一个神经网络,它学习预测源噪声 ϵ 0 ∼ N ( ϵ ; 0 , I ) \epsilon_0 \sim \mathcal{N}(\epsilon;0,I) ϵ0N(ϵ;0,I),该噪声决定了从 x 0 x_0 x0 x t x_t xt 的转变。这与我们直接上的理解是相同的。

代码编程

DDPM原论文伪代码如图13所示,分为训练和采样两个过程。需要注意的是采样阶段的第四行,对于 x t x_t xt去噪得到 x t − 1 x_{t-1} xt1时加入了一个随机噪声 σ t z \sigma_t z σtz。这样的目的是加入噪声扰动,是生成的数据更具多样性。
图13 伪代码
在所有实验中,DDPM原论文将T设为1000。将前向过程的方差设置为常数,从 β 1 = 1 0 − 4 β_1 = 10^{−4} β1=104线性增加到 β T = 0.02 β_T = 0.02 βT=0.02。选择这些常数时相对于缩放到 [ − 1 , 1 ] [−1, 1] [1,1]范围的数据保持较小,以确保逆过程和前向过程具有大致相同的分布,同时使得在 x T x_T xT处的信噪比尽可能小(在DDPM的实验中, L T = D K L ( q ( x T ∣ x 0 ) ∣ ∣ N ( 0 , I ) ≈ 1 0 − 5 ) L_T=D_{KL}(q(x_T|x_0)||\mathcal{N}(0,I)\approx 10^{-5}) LT=DKL(q(xTx0)∣∣N(0,I)105))。神经网络适用U-net模型。

这里给出labmlai实现的代码:

from typing import Tuple, Optionalimport torch
import torch.nn.functional as F
import torch.utils.data
from torch import nnfrom labml_nn.diffusion.ddpm.utils import gatherclass DenoiseDiffusion:def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):super().__init__()self.eps_model = eps_model #预测噪声的神经网络模型,比如u-net模型 self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device) # n_steps,为总的时间步T,创建噪声betaself.alpha = 1. - self.beta  #beta为添加的噪声的方差self.alpha_bar = torch.cumprod(self.alpha, dim=0)self.n_steps = n_stepsself.sigma2 = self.betadef q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: #获取q(x_t|x_0)分布的均值和方差mean = gather(self.alpha_bar, t) ** 0.5 * x0var = 1 - gather(self.alpha_bar, t)return mean, vardef q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):#q(x_t|x_0)分布的采样,使用重采样技巧if eps is None:eps = torch.randn_like(x0) #从标准正态分布采样一个样本mean, var = self.q_xt_x0(x0, t) #获取q(x_t|x_0)分布的均值和方差return mean + (var ** 0.5) * eps #返回从q(x_t|x_0)分布采样的样本def p_sample(self, xt: torch.Tensor, t: torch.Tensor):#从p(x_{t-1}|x_t)的分布中采样,生成过程eps_theta = self.eps_model(xt, t) #使用神经网络预测噪声alpha_bar = gather(self.alpha_bar, t)alpha = gather(self.alpha, t)eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta) #采样伪代码中第四行右侧前半部分var = gather(self.sigma2, t) #采样伪代码中第四行右侧后半部分eps = torch.randn(xt.shape, device=xt.device)  return mean + (var ** .5) * eps #采样伪代码中第四行def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):batch_size = x0.shape[0]t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)#伪代码1的第三行if noise is None:noise = torch.randn_like(x0) # 从标准正态分布提取样本,用于重参数化xt = self.q_sample(x0, t, eps=noise) # 返回从q(x_t|x_0)分布采样的样本eps_theta = self.eps_model(xt, t) # 伪代码1的第5行return F.mse_loss(noise, eps_theta)

:::info

参考文献

  1. Denoising Diffusion Probabilistic Models
  2. Understanding Diffusion Models: A Unified Perspective
  3. Generative Modeling by Estimating Gradients of the Data Distribution | Yang Song (yang-song.net)
  4. What are Diffusion Models? | Lil’Log (lilianweng.github.io)
  5. https://nn.labml.ai/diffusion/ddpm/index.html
  6. https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master
    :::

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/394136.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Leetcode每日刷题之字符串中的第一个唯一字符(C++)

在学习的过程中对代码的熟练运用至关重要,练习解决实际问题就可以很好的锻炼自己的编程能力,接下来让我们练习这道 387.字符串中的第一个唯一字符 思路解析 根据题意我们可以知道这个字符串只有小写字母,并且可能包含多个唯一字符&#xff0…

Java---字符串string练习

目录: 1.将数字转换成罗马数字 2.键盘输入任意字符串,打乱里面的内容 3.返回字符串中最后一个单词长度 4.调整A字符串 看是否可与B字符串匹配 一: //键盘录入一个字符串// 长度小于等于9 只能是数字// -将内容变成罗马数字// Ⅰ Ⅱ Ⅲ Ⅳ…

智慧水务项目(二)django(drf)+angular 18 创建通用model,并对orm常用字段进行说明

一、说明 上一篇文章建立一个最简单的项目,现在我们建立一个公共模型,抽取公共字段,以便于后续模块继承,过程之中会对orm常用字段进行说明,用到的介绍一下 二、创建一个db.py 目录如下图 1、代码 from importlib im…

基于QT实现的简易WPS(已开源)

一、开发工具及开源地址: 开发工具:QTCreator ,QT 5 开源地址: GitHub - Whale-xh/WPS_official: Simple WPS based on QTSimple WPS based on QT. Contribute to Whale-xh/WPS_official development by creating an acc…

推荐 3个实用且完全免费的在线工具,每天都会用到,无需登录打开即用

100font 100font是一个专业的免费商用字体下载网站,专注于收集、整理和分享各种免费无版权的商用字体。用户可以在这个平台上找到并下载简体中文、繁体中文、英文、日文、韩文等多种语言类型的字体。 该网站的特点包括清晰的分类和直观的下载流程,用户可…

进阶SpringBoot之 Spring 官网或 IDEA 快速构建项目

SpringBoot 就是一个 JavaWeb 的开发框架,约定大于配置 程序 数据结构 算法 微服务架构是把每个功能元素独立出来,再动态组合,是对功能元素的复制 这样做可以节省调用资源,每个功能元素的服务都是一个可替代、可独立升级的软…

算法混合杂项

基础类型 可用template 投影 是有方向的 求俩直线交点 推公式 q我们不知道,已知p1 p2,正弦定理,α可以用叉积表示出来 β同理 所以我们能求出p1q 已知piq 回归到我们上一个问题,已知方向和长度,我们就能够求出Voq …

C语言 ——— 学习并使用字符分类函数

目录 学习isupper函数 学习isdigit函数 学习tolower函数 将输入的字符串中把大写字母转换为小写字母并输出 学习isupper函数 参数部分: 形参需要传递的是一个字母,字符在ASCII码表上是以整型存储的,所以实参部分用(int c)没有问题 返回…

【iOS】AutoreleasePool自动释放池的实现原理

目录 ARC与MRC项目中的main函数自动释放池autoreleasepool {}实现原理AutoreleasePoolPage总结 objc_autoreleasePoolPush的源码分析autoreleaseNewPageautoreleaseFullPageautoreleaseNoPage autoreleaseFast总结 autorelease方法源码分析objc_autoreleasePoolPop的源码分析po…

谁来做引领企业精益变革的舵手最合适?

在这个瞬息万变的商业时代,企业如同航行在波涛汹涌的大海中的巨轮,既需面对未知的挑战,也要抓住稍纵即逝的机遇。而在这场没有终点的航行中,引领企业实现精益变革的舵手,无疑是推动企业破浪前行、稳健致远的关键角色。…

FFmpeg Windows安装教程

一. 下载ffmpeg 进入Download FFmpeg网址,点击下载windows版ffmpeg。 下载第一个essentials版本就行。 二. 环境配置 上面源码解压后如下 将bin添加到系统环境变量 验证安装是否成功,输入ffmpeg –version,显示版本即为安装成功。

Python学习(1):使用Python的Dask库实现并行计算

目录 一、Dask介绍 二、使用说明 安装 三、测试 1、单个文件中实现功能 2、运行多个可执行文件 最近在写并行计算相关部分,用到了python的Dask库。 Dask官网:Dask | Scale the Python tools you love 一、Dask介绍 Dask是一个灵活的并行和分布式…

网工内推 | 国企运维工程师,华为认证优先,最高年薪20w

01 上海陆家嘴物业管理有限公司 🔷招聘岗位:IT运维工程师 🔷岗位职责: 1、负责对公司软、硬件系统、周边设备、桌面系统、服务器、网络基础环境运行维护、故障排除。 2、负责对各部门软件操作、网络安全进行检查、指导。 3、负责…

Mysql——update更新数据的方式

注:文章参考: MySQL 更新数据 不同条件(批量)更新不同值_update批量更新同一列不同值-CSDN博客文章浏览阅读2w次,点赞20次,收藏70次。一般在更新时会遇到以下场景:1.全部更新;2.根据条件更新字段中的某部分…

vivado OPT_SKIPPED

当跳过候选基元单元的逻辑优化时,OPT_skipped属性 更新单元格以反映跳过的优化。当跳过多个优化时 在同一单元格上,OPT_SKIPPED值包含跳过的优化列表。 架构支持 所有架构。 适用对象 OPT_SKIPPED属性放置在单元格上。 价值观 下表列出了各种OPT_design选…

【CSDN平台BUG】markdown图片链接格式被手机端编辑器自动破坏(8.6 已修复)

文章目录 bug以及解决方法bug原理锐评后续 bug以及解决方法 现在是2024年8月,我打开csdn手机编辑器打算修改一下2023年12月的一篇文章,结果一进入编辑器,源码就变成了下面这个样子,我起初不以为意,就点击了发布&#…

Revit二次开发选择过滤器,SelectionFilter

过滤器分为选择过滤器与规则过滤器 规则过滤器可以看我之前写的这一篇文章: Revit二次开发在项目中给链接模型附加过滤器 选择过滤器顾名思义就是可以将选择的构件ID集合传入并加入到视图过滤器中,有一些场景需要对某些构件进行过滤选择,但是没有共同的逻辑规则进行筛选的情况…

Golang | Leetcode Golang题解之第313题超级丑数

题目&#xff1a; 题解&#xff1a; func nthSuperUglyNumber(n int, primes []int) int {dp : make([]int, n1)m : len(primes)pointers : make([]int, m)nums : make([]int, m)for i : range nums {nums[i] 1}for i : 1; i < n; i {minNum : math.MaxInt64for j : range…

力扣面试150 基本计算器 双栈模拟

Problem: 224. 基本计算器 &#x1f468;‍&#x1f3eb; 参考题解 Code class Solution {public int calculate(String s) {// 存放所有的数字&#xff0c;用于计算LinkedList<Integer> nums new LinkedList<>();// 为了防止第一个数为负数&#xff0c;先往 nu…

开源免费的wiki知识库

开源的Wiki知识库有多种选择&#xff0c;它们各自具有不同的特点和优势&#xff0c;适用于不同的场景和需求。以下是一些主流的开源Wiki知识库系统&#xff1a; MediaWiki 简介&#xff1a;MediaWiki是使用PHP编写的免费开源Wiki软件包&#xff0c;是Wikipedia和其他Wikimedia…