文章目录
- 前言
- 问题建模
- 条件扩散模型的前向过程
- 条件扩散模型的反向过程
- 条件扩散模型的训练目标
前言
几乎所有的生成式模型,发展到后期都需要引入"控制"的概念,可控制的生成式模型才能更好应用于实际场景。本文将总结《Diffusion Models Beat GANs on Image Synthesis》中提出的Classifier Guidance Diffusion(即条件扩散模型),其往Diffusion Model中引入了控制的概念,可以控制DDPM、DDIM生成指定类别(条件)的图片。
问题建模
本章节所有符号定义与DDPM一致,在条件 y y y下的Diffusion Model的前向与反向过程可以定义为
q ^ ( x t + 1 ∣ x t , y ) q ^ ( x t ∣ x t + 1 , y ) \begin{aligned} \hat q(x_{t+1}|x_{t},y)\\ \hat q(x_t|x_{t+1},y) \end{aligned} q^(xt+1∣xt,y)q^(xt∣xt+1,y)
只要求出上述两个概率密度函数,我们即可按条件生成图像。
我们利用 q ^ \hat q q^表示条件扩散模型的概率密度函数, q q q表示扩散模型的概率密度函数。
条件扩散模型的前向过程
对于前向过程,作者定义了以下等式
q ^ ( x 0 ) = q ( x 0 ) q ^ ( x t + 1 ∣ x t , y ) = q ( x t + 1 ∣ x t ) q ^ ( x 1 : T ∣ x 0 , y ) = ∏ t = 1 T q ^ ( x t ∣ x t − 1 , y ) \begin{aligned} \hat q(x_0)&=q(x_0)\\ \hat q(x_{t+1}|x_t,y)&=q(x_{t+1}|x_t)\\ \hat q(x_{1:T}|x_0,y)&=\prod_{t=1}^T\hat q(x_t|x_{t-1},y) \end{aligned} q^(x0)q^(xt+1∣xt,y)q^(x1:T∣x0,y)=q(x0)=q(xt+1∣xt)=t=1∏Tq^(xt∣xt−1,y)
基于上述第二行定义,可知基于条件 y y y的diffusion model的前向过程与普通的diffusion model一致,即 q ^ ( x t + 1 ∣ x t ) = q ( x t + 1 ∣ x t ) \hat q(x_{t+1}|x_t)=q(x_{t+1}|x_t) q^(xt+1∣xt)=q(xt+1∣xt)。即加噪过程与条件 y y y无关,这种定义也是合理的。
条件扩散模型的反向过程
对于反向过程,我们有
q ^ ( x t ∣ x t + 1 , y ) = q ^ ( x t , x t + 1 , y ) q ^ ( x t + 1 , y ) = q ^ ( x t , x t + 1 , y ) q ^ ( y ∣ x t + 1 ) q ^ ( x t + 1 ) = q ^ ( x t , y ∣ x t + 1 ) q ^ ( y ∣ x t + 1 ) = q ^ ( y ∣ x t , x t + 1 ) q ^ ( x t ∣ x t + 1 ) q ^ ( y ∣ x t + 1 ) (1.0) \begin{aligned} \hat q(x_t|x_{t+1},y)&=\frac{\hat q(x_t,x_{t+1},y)}{\hat q(x_{t+1},y)}\\ &=\frac{\hat q(x_t,x_{t+1},y)}{\hat q(y|x_{t+1})\hat q(x_{t+1})}\\ &=\frac{\hat q(x_t,y|x_{t+1})}{\hat q(y|x_{t+1})}\\ &=\frac{\hat q(y|x_t,x_{t+1})\hat q(x_{t}|x_{t+1})}{\hat q(y|x_{t+1})} \end{aligned}\tag{1.0} q^(xt∣xt+1,y)=q^(xt+1,y)q^(xt,xt+1,y)=q^(y∣xt+1)q^(xt+1)q^(xt,xt+1,y)=q^(y∣xt+1)q^(xt,y∣xt+1)=q^(y∣xt+1)q^(y∣xt,xt+1)q^(xt∣xt+1)(1.0)
已知条件扩散模型的前向过程与扩散模型一致,则有
q ^ ( x 1 : T ∣ x 0 ) = q ( x 1 : T ∣ x 0 ) \hat q(x_{1:T}|x_0)=q(x_{1:T}|x_0) q^(x1:T∣x0)=q(x1:T∣x0)
进而有
q ^ ( x t ) = ∫ q ^ ( x 0 , . . . , x t ) d x 0 : t − 1 = ∫ q ^ ( x 0 ) q ^ ( x 1 : t ∣ x 0 ) d x 0 : t − 1 = ∫ q ( x 0 ) q ( x 1 : t ∣ x 0 ) d x 0 : t − 1 = q ( x t ) \begin{aligned} \hat q(x_{t})&=\int \hat q(x_0,...,x_t) dx_{0:t-1}\\ &=\int \hat q(x_0)\hat q(x_{1:t}|x_0)dx_{0:t-1}\\ &=\int q(x_0)q(x_{1:t}|x_0)dx_{0:t-1}\\ &=q(x_t) \end{aligned} q^(xt)=∫q^(x0,...,xt)dx0:t−1=∫q^(x0)q^(x1:t∣x0)dx0:t−1=∫q(x0)q(x1:t∣x0)dx0:t−1=q(xt)
对于 q ^ ( x t ∣ x t + 1 ) \hat q(x_t|x_{t+1}) q^(xt∣xt+1),则有
q ^ ( x t ∣ x t + 1 ) = q ^ ( x t , x t + 1 ) q ^ ( x t + 1 ) = q ^ ( x t + 1 ∣ x t ) q ^ ( x t ) q ^ ( x t + 1 ) = q ( x t + 1 ∣ x t ) q ( x t ) q ( x t + 1 ) = q ( x t ∣ x t + 1 ) \begin{aligned} \hat q(x_t|x_{t+1})&=\frac{\hat q(x_t,x_{t+1})}{\hat q(x_{t+1})}\\ &=\frac{\hat q(x_{t+1}|x_t)\hat q(x_{t})}{\hat q(x_{t+1})}\\ &=\frac{q(x_{t+1}|x_t)q(x_{t})}{q(x_{t+1})}\\ &=q(x_{t}|x_{t+1}) \end{aligned} q^(xt∣xt+1)=q^(xt+1)q^(xt,xt+1)=q^(xt+1)q^(xt+1∣xt)q^(xt)=q(xt+1)q(xt+1∣xt)q(xt)=q(xt∣xt+1)
对于 q ^ ( y ∣ x t , x x t + 1 ) \hat q(y|x_t,x_{x_{t+1}}) q^(y∣xt,xxt+1),我们有
q ^ ( y ∣ x t , x x t + 1 ) = q ^ ( x t + 1 ∣ x t , y ) q ^ ( y ∣ x t ) q ^ ( x t + 1 ∣ x t ) = q ^ ( x t + 1 ∣ x t ) q ^ ( y ∣ x t ) q ^ ( x t + 1 ∣ x t ) = q ^ ( y ∣ x t ) \begin{aligned} \hat q(y|x_t,x_{x_{t+1}})&=\frac{\hat q(x_{t+1}|x_t,y)\hat q(y|x_t)}{\hat q(x_{t+1}|x_t)}\\ &=\frac{\hat q(x_{t+1}|x_t)\hat q(y|x_t)}{\hat q(x_{t+1}|x_t)}\\ &=\hat q(y|x_t) \end{aligned} q^(y∣xt,xxt+1)=q^(xt+1∣xt)q^(xt+1∣xt,y)q^(y∣xt)=q^(xt+1∣xt)q^(xt+1∣xt)q^(y∣xt)=q^(y∣xt)
因此式1.0为
q ^ ( x t ∣ x t + 1 , y ) = q ^ ( y ∣ x t , x t + 1 ) q ^ ( x t ∣ x t + 1 ) q ^ ( y ∣ x t + 1 ) = q ^ ( y ∣ x t ) q ( x t ∣ x t + 1 ) q ^ ( y ∣ x t + 1 ) \begin{aligned} \hat q(x_t|x_{t+1},y)&=\frac{\hat q(y|x_t,x_{t+1})\hat q(x_{t}|x_{t+1})}{\hat q(y|x_{t+1})}\\ &=\frac{\hat q(y|x_t)q(x_{t}|x_{t+1})}{\hat q(y|x_{t+1})} \end{aligned} q^(xt∣xt+1,y)=q^(y∣xt+1)q^(y∣xt,xt+1)q^(xt∣xt+1)=q^(y∣xt+1)q^(y∣xt)q(xt∣xt+1)
由于在反向过程中, x t + 1 x_{t+1} xt+1是已知的,因此 q ^ ( y ∣ x t + 1 ) \hat q(y|x_{t+1}) q^(y∣xt+1)也可看成已知值,设其倒数为 Z Z Z,则有
q ^ ( x t ∣ x t + 1 , y ) = Z q ^ ( y ∣ x t ) q ( x t ∣ x t + 1 ) \begin{aligned} \hat q(x_t|x_{t+1},y) = Z\hat q(y|x_t)q(x_{t}|x_{t+1}) \end{aligned} q^(xt∣xt+1,y)=Zq^(y∣xt)q(xt∣xt+1)
取log可得
log q ^ ( x t ∣ x t + 1 , y ) = log Z + log q ^ ( y ∣ x t ) + log q ^ ( x t ∣ x t + 1 ) (1.1) \begin{aligned} \log \hat q(x_{t}|x_{t+1},y)=\log Z+\log \hat q(y|x_t)+\log \hat q(x_t|x_{t+1})\tag{1.1} \end{aligned} logq^(xt∣xt+1,y)=logZ+logq^(y∣xt)+logq^(xt∣xt+1)(1.1)
设 q ^ ( x t ∣ x t + 1 ) = N ( μ t , ∑ t 2 ) \hat q(x_t|x_{t+1})=\mathcal N(\mu_t,\sum_t^2) q^(xt∣xt+1)=N(μt,∑t2),则有
log q ^ ( x t ∣ x t + 1 ) = − 1 2 ( x t − μ t ) T ( ∑ t ) − 1 ( x t − μ t ) + C (1.2) \log \hat q(x_{t}|x_{t+1})=-\frac{1}{2}(x_t-\mu_t)^T({\sum}_t)^{-1}(x_t-\mu_t)+C\tag{1.2} logq^(xt∣xt+1)=−21(xt−μt)T(∑t)−1(xt−μt)+C(1.2)
对于 log q ^ ( y ∣ x t ) \log \hat q(y|x_t) logq^(y∣xt),在 x t = μ t x_t=\mu_t xt=μt处做泰勒展开,则有
log q ^ ( y ∣ x t ) ≈ log q ^ ( y ∣ x t ) ∣ x t = μ t + ( x t − μ t ) ∇ x t log q ^ ( y ∣ x t ) ∣ x t = μ t = C 1 + ( x t − μ t ) g (1.3) \begin{aligned} \log \hat q(y|x_t) &\approx \log \hat q(y|x_t)|_{x_t=\mu_t}+(x_t-\mu_t)\nabla_{x_t}\log\hat q(y|x_t)|_{x_t=\mu_t}\\ &=C_1+(x_t-\mu_t)g \end{aligned}\tag{1.3} logq^(y∣xt)≈logq^(y∣xt)∣xt=μt+(xt−μt)∇xtlogq^(y∣xt)∣xt=μt=C1+(xt−μt)g(1.3)
其中 g = ∇ x t log q ^ ( y ∣ x t ) ∣ x t = μ t g=\nabla_{x_t}\log\hat q(y|x_t)|_{x_t=\mu_t} g=∇xtlogq^(y∣xt)∣xt=μt,结合式1.1、1.2、1.3,有
log q ^ ( x t ∣ x t + 1 , y ) ≈ C 1 + ( x t − μ t ) g + log Z − 1 2 ( x t − μ t ) T ( ∑ t ) − 1 ( x t − μ t ) + C = ( x t − μ t ) g − 1 2 ( x t − μ t ) T ( ∑ t ) − 1 ( x t − μ t ) + C 2 = − 1 2 ( x t − μ t − ∑ t g ) T ( ∑ t ) − 1 ( x t − μ t − ∑ t g ) + C 3 \begin{aligned} \log \hat q(x_{t}|x_{t+1},y)&\approx C_1+(x_t-\mu_t)g+\log Z-\frac{1}{2}(x_t-\mu_t)^T(\sum{_t})^{-1}(x_t-\mu_t)+C\\ &=(x_t-\mu_t)g-\frac{1}{2}(x_t-\mu_t)^T(\sum{_t})^{-1}(x_t-\mu_t)+C_2\\ &=-\frac{1}{2}(x_t-\mu_t-\sum{_t} g)^T(\sum{_t})^{-1}(x_t-\mu_t-\sum{_t}g)+C_3 \end{aligned} logq^(xt∣xt+1,y)≈C1+(xt−μt)g+logZ−21(xt−μt)T(∑t)−1(xt−μt)+C=(xt−μt)g−21(xt−μt)T(∑t)−1(xt−μt)+C2=−21(xt−μt−∑tg)T(∑t)−1(xt−μt−∑tg)+C3
最终有
q ^ ( x t ∣ x t + 1 , y ) ≈ N ( μ t + ∑ t g , ( ∑ t ) 2 ) g = ∇ x t log q ^ ( y ∣ x t ) ∣ x t = μ t (1.4) \begin{aligned} \hat q(x_t|x_{t+1},y)\approx \mathcal N(\mu_t+{\sum}_{t}g,({\sum}_t)^2)\\ g=\nabla_{x_t}\log\hat q(y|x_t)|_{x_t=\mu_t} \end{aligned}\tag{1.4} q^(xt∣xt+1,y)≈N(μt+∑tg,(∑t)2)g=∇xtlogq^(y∣xt)∣xt=μt(1.4)
为了获得 ∇ x t log q ^ ( y ∣ x t ) \nabla_{x_t}\log\hat q(y|x_t) ∇xtlogq^(y∣xt),Classifier Guidance Diffusion在训练好的Diffusion model的基础上额外训练了一个分类头。
假设 x t ≈ μ t x_t \approx\mu_t xt≈μt,则Classifier Guidance Diffusion的反向过程为:
其中 p ϕ ( y ∣ x t ) = q ^ ( y ∣ x t ) p_ \phi(y|x_t)=\hat q(y|x_t) pϕ(y∣xt)=q^(y∣xt), s s s为一个超参数。
式1.4有个问题,当方差 ∑ \sum ∑取值为0时, ∑ ∇ x t log q ^ ( y ∣ x t ) {\sum}\nabla_{x_t}\log\hat q(y|x_t) ∑∇xtlogq^(y∣xt)取值将为0,无法控制生成指定条件的图像。因此式1.4不适用于DDIM等确定性采样的扩散模型。
在推导DDIM的采样公式前,我们先了解一下用Tweedie方法做参数估计的流程。
Tweedie方法主要用于指数族概率分布的参数估计,而高斯分布属于指数族概率分布,自然也适用。假设有一批样本 z z z,则利用样本 z z z估计高斯分布 N ( Z ; μ , ∑ 2 ) \mathcal N(Z;\mu,{\sum}^2) N(Z;μ,∑2)的均值 μ \mu μ的公式为
E [ μ ∣ z ] = z + ∑ 2 ∇ z log p ( z ) (1.5) E[\mu|z]=z+{\sum}^2\nabla_z\log p(z)\tag{1.5} E[μ∣z]=z+∑2∇zlogp(z)(1.5)
已知DDPM、DDIM的前向过程有
q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) (1.6) q(x_t|x_0)=\mathcal N(x_t;\sqrt{\bar \alpha_t}x_0,(1-\bar\alpha_t)\mathcal I)\tag{1.6} q(xt∣x0)=N(xt;αˉtx0,(1−αˉt)I)(1.6)
结合式1.5、1.6可得
α ˉ t x 0 = x t + ( 1 − α ˉ t ) ∇ x t log p ( x t ) \begin{aligned} \sqrt{\bar \alpha_t}x_0=x_t+(1-\bar\alpha_t)\nabla_{x_t}\log p(x_t) \end{aligned} αˉtx0=xt+(1−αˉt)∇xtlogp(xt)
进而有
x t = α ˉ t x 0 − ( 1 − α ˉ t ) ∇ x t log p ( x t ) (1.7) x_t=\sqrt{\bar \alpha_t}x_0-(1-\bar\alpha_t)\nabla_{x_t}\log p(x_t)\tag{1.7} xt=αˉtx0−(1−αˉt)∇xtlogp(xt)(1.7)
设 ϵ t \epsilon_t ϵt服从标准正态分布,则从式1.6可知
x t = α ˉ t x 0 + 1 − α ˉ t ϵ t (1.8) x_t=\sqrt{\bar \alpha_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon_t\tag{1.8} xt=αˉtx0+1−αˉtϵt(1.8)
结合式1.7、1.8,则有
∇ x t log p ( x t ) = − 1 1 − α ˉ t ϵ t (1.9) \nabla_{x_t}\log p(x_t)=-\frac{1}{\sqrt{1-\bar\alpha_t}}\epsilon_t\tag{1.9} ∇xtlogp(xt)=−1−αˉt1ϵt(1.9)
已知DDIM的采样公式为
x t − 1 = α ˉ t − 1 x t − 1 − α ˉ t ϵ θ ( x t ) α ˉ t + 1 − α ˉ t − δ t 2 ϵ θ ( x t ) (2.0) x_{t-1}=\sqrt{\bar \alpha_{t-1}}\frac{x_t-\sqrt{1-\bar \alpha_t}\epsilon_\theta(x_t)}{\sqrt{\bar\alpha_t}}+\sqrt{1-\bar\alpha_{t}-\delta_t^2}\epsilon_\theta(x_t)\tag{2.0} xt−1=αˉt−1αˉtxt−1−αˉtϵθ(xt)+1−αˉt−δt2ϵθ(xt)(2.0)
结合式1.9、2.0可将DDIM的采样公式转变为
x t − 1 = α ˉ t − 1 x t − 1 − α ˉ t ( − 1 − α ˉ t ∇ x t log p ( x t ) ) α ˉ t + 1 − α ˉ t − δ t 2 ( − 1 − α ˉ t ∇ x t log p ( x t ) ) (2.1) x_{t-1}=\sqrt{\bar \alpha_{t-1}}\frac{x_t-\sqrt{1-\bar \alpha_t}(-\sqrt{1-\bar\alpha_t}\nabla_{x_t}\log p(x_t))}{\sqrt{\bar\alpha_t}}+\sqrt{1-\bar\alpha_{t}-\delta_t^2}(-\sqrt{1-\bar\alpha_t}\nabla_{x_t}\log p(x_t))\tag{2.1} xt−1=αˉt−1αˉtxt−1−αˉt(−1−αˉt∇xtlogp(xt))+1−αˉt−δt2(−1−αˉt∇xtlogp(xt))(2.1)
我们只需要将其中的 ∇ x t log p ( x t ) \nabla_{x_t}\log p(x_t) ∇xtlogp(xt)替换为 ∇ x t log p ( x t ∣ y ) \nabla_{x_t}\log p(x_t|y) ∇xtlogp(xt∣y),即可引入条件 y y y来控制DDIM的生成过程,利用贝叶斯定理,我们有
log p ( x t ∣ y ) = log p ( y ∣ x t ) + log p ( x t ) − log p ( y ) ∇ x t log p ( x t ∣ y ) = ∇ x t log p ( y ∣ x t ) + ∇ x t log p ( x t ) − ∇ x t log p ( y ) = ∇ x t log p ( y ∣ x t ) + ∇ x t log p ( x t ) = ∇ x t log p ( y ∣ x t ) − 1 1 − α ˉ t ϵ t (2.2) \begin{aligned} \log p(x_t|y)&=\log p(y|x_t)+\log p(x_t)-\log p(y)\\ \nabla_{x_t}\log p(x_t|y)&=\nabla_{x_t}\log p(y|x_t)+\nabla_{x_t}\log p(x_t)-\nabla_{x_t}\log p(y)\\ &=\nabla_{x_t}\log p(y|x_t)+\nabla_{x_t}\log p(x_t)\\ &=\nabla_{x_t}\log p(y|x_t)-\frac{1}{\sqrt{1-\bar\alpha_t}}\epsilon_t \end{aligned}\tag{2.2} logp(xt∣y)∇xtlogp(xt∣y)=logp(y∣xt)+logp(xt)−logp(y)=∇xtlogp(y∣xt)+∇xtlogp(xt)−∇xtlogp(y)=∇xtlogp(y∣xt)+∇xtlogp(xt)=∇xtlogp(y∣xt)−1−αˉt1ϵt(2.2)
则有
− 1 − α ˉ t ∇ x t log p ( x t ∣ y ) = ϵ t − 1 − α ˉ t ∇ x t log p ( y ∣ x t ) (2.3) -\sqrt{1-\bar\alpha_t}\nabla_{x_t}\log p(x_t|y)=\epsilon_t-\sqrt{1-\bar\alpha_t}\nabla_{x_t}\log p(y|x_t)\tag{2.3} −1−αˉt∇xtlogp(xt∣y)=ϵt−1−αˉt∇xtlogp(y∣xt)(2.3)
至此,我们可以得到DDIM的采样流程为
对于DDIM等确定性采样的扩散模型,其应在训练好的Diffusion model的基础上额外训练了一个分类头,从而转变为Classifier Guidance Diffusion。
条件扩散模型的训练目标
注意到 q ^ ( x t ∣ x t + 1 ) = q ( x t ∣ x t + 1 ) \hat q(x_t|x_{t+1})=q(x_t|x_{t+1}) q^(xt∣xt+1)=q(xt∣xt+1),并且上述的推导过程并没有改变 q ( x t ∣ x t + 1 ) 、 q ( x t + 1 ∣ x t ) q(x_t|x_{t+1})、q(x_{t+1}|x_t) q(xt∣xt+1)、q(xt+1∣xt)的形式,因此Classifier Guidance Diffusion的训练目标与DDPM、DDIM是一致的,都可以拟合训练数据。