Paper Reading Notes: Denoising Diffusion Probabilistic Models (3)

Paper Reading Notes: Denoising Diffusion Probabilistic Models (1)
Paper Reading Notes: Denoising Diffusion Probabilistic Models (2)
Paper Reading Notes: Denoising Diffusion Probabilistic Models (3)

4. Term-by-Term Analysis of the Loss Function

As shown previously, $L$ splits into three terms. First consider $L_1$:
$$
\begin{equation}
\begin{split}
L_1&=E_{x_{1:T} \sim q(x_{1:T}|x_0)}\Bigg(\log\Big[\frac{q(x_T|x_0)}{p(x_T)}\Big]\Bigg)\\
&=\int dx_{1:T}\cdot q(x_{1:T}|x_0)\cdot \log\Big[\frac{q(x_T|x_0)}{p(x_T)}\Big]\\
&=\int dx_{1:T}\cdot \frac{q(x_{1:T}|x_0)}{q(x_T|x_0)}\cdot q(x_T|x_0)\cdot \log\Big[\frac{q(x_T|x_0)}{p(x_T)}\Big]\\
&=\int dx_{1:T}\cdot \underbrace{q(x_{1:T-1}|x_0,x_T)}_{q(x_{1:T}|x_0)=q(x_T|x_0)\cdot q(x_{1:T-1}|x_0,x_T)}\cdot q(x_T|x_0)\cdot \log\Big[\frac{q(x_T|x_0)}{p(x_T)}\Big]\\
&=\int \Bigg(\underbrace{\int q(x_{1:T-1}|x_0,x_T)\prod_{k=1}^{T-1} dx_k}_{\text{inner integral over }x_{1:T-1}\text{, equals }1}\Bigg)\cdot q(x_T|x_0)\cdot \log\Big[\frac{q(x_T|x_0)}{p(x_T)}\Big]\, dx_T\\
&=\int q(x_T|x_0)\cdot \log\Big[\frac{q(x_T|x_0)}{p(x_T)}\Big]\, dx_T\\
&=E_{x_T\sim q(x_T|x_0)}\log\Big[\frac{q(x_T|x_0)}{p(x_T)}\Big]\\
&=KL\Big(q(x_T|x_0)\,\big\|\,p(x_T)\Big)
\end{split}
\end{equation}
$$

So $L_1$ is the KL divergence between $q(x_T|x_0)$ and $p(x_T)$. $q(x_T|x_0)$ is the endpoint of the forward noising process, which converges to a standard normal distribution as $T$ grows, while $p(x_T)$ is defined to be a Gaussian, as stated in the fourth line of Section 2 (Background) of "Denoising Diffusion Probabilistic Models". Using the closed-form KL divergence between two Gaussians, $L_1$ can be evaluated directly; it contains no learnable parameters, so it is a constant and can be dropped from the loss.
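To make this concrete, here is a minimal numpy sketch (an illustration, not part of the original code; it assumes the DDPM linear schedule with $\beta$ from 1e-4 to 0.02 and $T=1000$) that evaluates this constant per dimension:

```python
import numpy as np

# Per-dimension KL( N(mu, sigma^2) || N(0, 1) ) = -log(sigma) + (sigma^2 + mu^2 - 1) / 2.
betas = np.linspace(1e-4, 0.02, 1000)          # assumed linear schedule
alphas_cumprod = np.cumprod(1.0 - betas)

mu = np.sqrt(alphas_cumprod[-1])               # mean coefficient of q(x_T | x_0), times x_0
sigma2 = 1.0 - alphas_cumprod[-1]              # variance of q(x_T | x_0)

x0 = 1.0                                       # a pixel at the edge of the [-1, 1] data range
kl = -0.5 * np.log(sigma2) + (sigma2 + (mu * x0) ** 2 - 1.0) / 2.0
print(kl)  # ~2e-5 per dimension: effectively zero, so L_1 is negligible
```

Even for the extreme pixel value $x_0=1$, the divergence is on the order of $10^{-5}$ per dimension, which justifies ignoring $L_1$.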

Next, consider the second term $L_2$:

$$
\begin{equation}
\begin{split}
L_2&=E_{x_{1:T}\sim q(x_{1:T}|x_0)}\Bigg(\sum_{t=2}^{T}\log\Big[\frac{q(x_{t-1}|x_t,x_0)}{p_{\theta}(x_{t-1}|x_t)}\Big]\Bigg)\\
&=\sum_{t=2}^{T} E_{x_{1:T}\sim q(x_{1:T}|x_0)}\Bigg(\log\Big[\frac{q(x_{t-1}|x_t,x_0)}{p_{\theta}(x_{t-1}|x_t)}\Big]\Bigg)\\
&=\sum_{t=2}^{T}\Bigg(\int dx_{1:T}\cdot q(x_{1:T}|x_0)\cdot \log\Big[\frac{q(x_{t-1}|x_t,x_0)}{p_{\theta}(x_{t-1}|x_t)}\Big]\Bigg)\\
&=\sum_{t=2}^{T}\Bigg(\int dx_{1:T}\cdot \frac{q(x_{1:T}|x_0)}{q(x_{t-1}|x_t,x_0)}\cdot q(x_{t-1}|x_t,x_0)\cdot \log\Big[\frac{q(x_{t-1}|x_t,x_0)}{p_{\theta}(x_{t-1}|x_t)}\Big]\Bigg)\\
&=\sum_{t=2}^{T}\Bigg(\int dx_{1:T}\cdot \underbrace{\frac{q(x_{0:T})}{q(x_0)}}_{q(x_{0:T})=q(x_0)\cdot q(x_{1:T}|x_0)}\cdot \underbrace{\frac{q(x_t,x_0)}{q(x_t,x_{t-1},x_0)}}_{q(x_t,x_{t-1},x_0)=q(x_t,x_0)\cdot q(x_{t-1}|x_t,x_0)}\cdot q(x_{t-1}|x_t,x_0)\cdot \log\Big[\frac{q(x_{t-1}|x_t,x_0)}{p_{\theta}(x_{t-1}|x_t)}\Big]\Bigg)\\
&=\sum_{t=2}^{T}\Bigg(\int dx_{1:T}\cdot \frac{q(x_{0:T})}{q(x_0)}\cdot \frac{q(x_t,x_0)}{q(x_{t-1},x_0)\cdot q(x_t|x_{t-1},x_0)}\cdot q(x_{t-1}|x_t,x_0)\cdot \log\Big[\frac{q(x_{t-1}|x_t,x_0)}{p_{\theta}(x_{t-1}|x_t)}\Big]\Bigg)\\
&=\sum_{t=2}^{T}\Bigg(\int \bigg[\int \frac{q(x_{0:T})}{q(x_0)}\cdot \frac{q(x_t,x_0)}{q(x_{t-1},x_0)\cdot q(x_t|x_{t-1},x_0)}\prod_{k\geq 1,k\neq t-1} dx_k\bigg]\cdot q(x_{t-1}|x_t,x_0)\cdot \log\Big[\frac{q(x_{t-1}|x_t,x_0)}{p_{\theta}(x_{t-1}|x_t)}\Big]\, dx_{t-1}\Bigg)\\
&=\sum_{t=2}^{T}\Bigg(\int \bigg[\int \frac{q(x_{0:T})}{q(x_{t-1},x_0)}\cdot \frac{q(x_t,x_0)}{q(x_0)\cdot q(x_t|x_{t-1},x_0)}\prod_{k\geq 1,k\neq t-1} dx_k\bigg]\cdot q(x_{t-1}|x_t,x_0)\cdot \log\Big[\frac{q(x_{t-1}|x_t,x_0)}{p_{\theta}(x_{t-1}|x_t)}\Big]\, dx_{t-1}\Bigg)\\
&=\sum_{t=2}^{T}\Bigg(\int \bigg[\int \underbrace{q(x_{k:k\geq 1,k\neq t-1}|x_{t-1},x_0)}_{q(x_{0:T})=q(x_{t-1},x_0)\cdot q(x_{k:k\geq 1,k\neq t-1}|x_{t-1},x_0)}\cdot \underbrace{\frac{q(x_t|x_0)}{q(x_t|x_{t-1},x_0)}}_{q(x_t,x_0)=q(x_0)\cdot q(x_t|x_0)}\prod_{k\geq 1,k\neq t-1} dx_k\bigg]\cdot q(x_{t-1}|x_t,x_0)\cdot \log\Big[\frac{q(x_{t-1}|x_t,x_0)}{p_{\theta}(x_{t-1}|x_t)}\Big]\, dx_{t-1}\Bigg)\\
&=\sum_{t=2}^{T}\Bigg(\int \bigg[\underbrace{\int q(x_{k:k\geq 1,k\neq t-1}|x_{t-1},x_0)\cdot \frac{q(x_t|x_0)}{q(x_t|x_{t-1},x_0)}\prod_{k\geq 1,k\neq t-1} dx_k}_{=1}\bigg]\cdot q(x_{t-1}|x_t,x_0)\cdot \log\Big[\frac{q(x_{t-1}|x_t,x_0)}{p_{\theta}(x_{t-1}|x_t)}\Big]\, dx_{t-1}\Bigg)\\
&=\sum_{t=2}^{T}\Bigg(\int q(x_{t-1}|x_t,x_0)\cdot \log\Big[\frac{q(x_{t-1}|x_t,x_0)}{p_{\theta}(x_{t-1}|x_t)}\Big]\, dx_{t-1}\Bigg)\\
&=\sum_{t=2}^{T}\Bigg(E_{x_{t-1}\sim q(x_{t-1}|x_t,x_0)}\log\Big[\frac{q(x_{t-1}|x_t,x_0)}{p_{\theta}(x_{t-1}|x_t)}\Big]\Bigg)\\
&=\sum_{t=2}^{T} KL\Big(q(x_{t-1}|x_t,x_0)\,\big\|\,p_{\theta}(x_{t-1}|x_t)\Big)
\end{split}
\end{equation}
$$

The bracketed inner integral equals 1: integrating $q(x_{k:k\geq 1,k\neq t-1}|x_{t-1},x_0)$ over every variable except $x_t$ yields $q(x_t|x_{t-1},x_0)$, which cancels the denominator, and the remaining $\int q(x_t|x_0)\,dx_t=1$. Strictly speaking, $x_t$ also appears in the log term, so its integral cannot be pulled into the bracket; carrying it through properly gives $\sum_{t=2}^{T} E_{x_t\sim q(x_t|x_0)}\,KL\big(q(x_{t-1}|x_t,x_0)\,\|\,p_{\theta}(x_{t-1}|x_t)\big)$. The expectation over $x_t$ is left implicit here and below, which matches the $E_q$ form used in the paper.
Finally, consider $L_3$. In fact, the paper "Deep Unsupervised Learning using Nonequilibrium Thermodynamics" notes that, to prevent boundary effects, one forces $p_{\theta}(x_0|x_1)=q(x_1|x_0)$, so this term is also a constant.

From the analysis above, the loss function can be written as Equation (3).
$$
\begin{equation}
\begin{split}
L&:=L_1+L_2+L_3\\
&=KL\Big(q(x_T|x_0)\,\big\|\,p(x_T)\Big)+\sum_{t=2}^{T} KL\Big(q(x_{t-1}|x_t,x_0)\,\big\|\,p_{\theta}(x_{t-1}|x_t)\Big)-\log\Big[p_{\theta}(x_0|x_1)\Big]
\end{split}
\end{equation}
$$

Ignoring $L_1$ and $L_3$, the loss function reduces to Equation (4).
$$
\begin{equation}
\begin{split}
L:=\sum_{t=2}^{T} KL\Big(q(x_{t-1}|x_t,x_0)\,\big\|\,p_{\theta}(x_{t-1}|x_t)\Big)
\end{split}
\end{equation}
$$

The loss $L$ is thus a sum of KL divergences between two Gaussians, $q(x_{t-1}|x_t,x_0)$ and $p_{\theta}(x_{t-1}|x_t)$. From Paper Reading Notes: Denoising Diffusion Probabilistic Models (1), the standard deviation and mean of $q(x_{t-1}|x_t,x_0)$ are

$$
\begin{equation}
\begin{split}
\sigma_1&=\sqrt{\frac{\beta_t\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}}\\
\mu_1&=\frac{1}{\sqrt{\alpha_t}}\Big(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\, z_t\Big)\\
\text{or}\quad \mu_1&=\frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\, x_t+\frac{\beta_t\,\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_t}\, x_0
\end{split}
\end{equation}
$$
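As a quick illustration (a sketch reusing the assumed linear schedule above, not the paper's code), these coefficients depend only on the schedule and can be precomputed for every $t$; this is exactly what the reference implementation below does with its `posterior_mean_coef1`, `posterior_mean_coef2`, and `posterior_variance` arrays:

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)          # assumed linear schedule
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

# Coefficients of Equation (5): mu_1 = coef_x0 * x_0 + coef_xt * x_t
coef_x0 = betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
coef_xt = np.sqrt(alphas) * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
sigma1 = np.sqrt(betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod))

t = 500
print(coef_x0[t], coef_xt[t], sigma1[t])
```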

The mean and standard deviation of $p_{\theta}(x_{t-1}|x_t)$ are estimated by a model (a deep network or otherwise) and are denoted $\mu_2$ and $\sigma_2$.
Using the closed form of the KL divergence between two Gaussians, the loss $L$ can therefore be written as Equation (6).
$$
\begin{equation}
\begin{split}
L:=\log\Big[\frac{\sigma_2}{\sigma_1}\Big]+\frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2}-\frac{1}{2}
\end{split}
\end{equation}
$$
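As a sanity check of Equation (6), the following sketch (with arbitrary illustrative values for the means and standard deviations) compares the closed form against a Monte Carlo estimate of the same divergence:

```python
import numpy as np

mu1, sigma1 = 0.3, 0.8   # illustrative values, not from the paper
mu2, sigma2 = 0.0, 1.0

closed_form = np.log(sigma2 / sigma1) + (sigma1**2 + (mu1 - mu2)**2) / (2 * sigma2**2) - 0.5

# Monte Carlo estimate of E_{x ~ N(mu1, sigma1^2)}[log q1(x) - log q2(x)];
# the -0.5*log(2*pi) constants cancel in the difference.
rng = np.random.default_rng(0)
x = rng.normal(mu1, sigma1, size=1_000_000)
log_q1 = -np.log(sigma1) - (x - mu1)**2 / (2 * sigma1**2)
log_q2 = -np.log(sigma2) - (x - mu2)**2 / (2 * sigma2**2)
print(closed_form, (log_q1 - log_q2).mean())  # the two values should agree closely
```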

5. Code Walkthrough

Finally, let's walk through the official code (diffusion: https://github.com/hojonathanho/diffusion) to understand the training and sampling procedures.
First, the training procedure.

```python
class GaussianDiffusion2:
  """
  Contains utilities for the diffusion model.

  Arguments:
  - what the network predicts (x_{t-1}, x_0, or epsilon)
  - which loss function (kl or unweighted MSE)
  - what is the variance of p(x_{t-1}|x_t) (learned, fixed to beta, or fixed to weighted beta)
  - what type of decoder, and how to weight its loss? is its variance learned too?
  """

  # Definitions used throughout the model
  def __init__(self, *, betas, model_mean_type, model_var_type, loss_type):
    self.model_mean_type = model_mean_type  # xprev, xstart, eps
    self.model_var_type = model_var_type  # learned, fixedsmall, fixedlarge
    self.loss_type = loss_type  # kl, mse

    assert isinstance(betas, np.ndarray)
    self.betas = betas = betas.astype(np.float64)  # computations here in float64 for accuracy
    assert (betas > 0).all() and (betas <= 1).all()
    timesteps, = betas.shape
    self.num_timesteps = int(timesteps)

    alphas = 1. - betas
    self.alphas_cumprod = np.cumprod(alphas, axis=0)
    self.alphas_cumprod_prev = np.append(1., self.alphas_cumprod[:-1])
    assert self.alphas_cumprod_prev.shape == (timesteps,)

    # calculations for diffusion q(x_t | x_{t-1}) and others
    self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
    self.sqrt_one_minus_alphas_cumprod = np.sqrt(1. - self.alphas_cumprod)
    self.log_one_minus_alphas_cumprod = np.log(1. - self.alphas_cumprod)
    self.sqrt_recip_alphas_cumprod = np.sqrt(1. / self.alphas_cumprod)
    self.sqrt_recipm1_alphas_cumprod = np.sqrt(1. / self.alphas_cumprod - 1)

    # calculations for posterior q(x_{t-1} | x_t, x_0)
    self.posterior_variance = betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
    # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
    self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
    self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
    self.posterior_mean_coef2 = (1. - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1. - self.alphas_cumprod)

  # A method of the Model class (shown here for context)
  def train_fn(self, x, y):
    B, H, W, C = x.shape
    if self.randflip:
      x = tf.image.random_flip_left_right(x)
      assert x.shape == [B, H, W, C]
    # Sample a random timestep t
    t = tf.random_uniform([B], 0, self.diffusion.num_timesteps, dtype=tf.int32)
    # Compute the loss for step t
    losses = self.diffusion.training_losses(
        denoise_fn=functools.partial(self._denoise, y=y, dropout=self.dropout), x_start=x, t=t)
    assert losses.shape == t.shape == [B]
    return {'loss': tf.reduce_mean(losses)}

  # Sample the noisy image at step t starting from x_start
  def q_sample(self, x_start, t, noise=None):
    """
    Diffuse the data (t == 0 means diffused for 1 step)
    """
    if noise is None:
      noise = tf.random_normal(shape=x_start.shape)
    assert noise.shape == x_start.shape
    return (
        self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
        self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
    )

  # Compute the mean and variance of the posterior q(x_{t-1} | x_t, x_0)
  def q_posterior_mean_variance(self, x_start, x_t, t):
    """
    Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
    """
    assert x_start.shape == x_t.shape
    posterior_mean = (
        self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
        self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
    )
    posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
    posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
    assert (posterior_mean.shape[0] == posterior_variance.shape[0] ==
            posterior_log_variance_clipped.shape[0] == x_start.shape[0])
    return posterior_mean, posterior_variance, posterior_log_variance_clipped

  # Estimate the mean and variance of p(x_{t-1} | x_t) with the denoising network (UNet)
  def p_mean_variance(self, denoise_fn, *, x, t, clip_denoised: bool, return_pred_xstart: bool):
    B, H, W, C = x.shape
    assert t.shape == [B]
    model_output = denoise_fn(x, t)

    # Learned or fixed variance?
    if self.model_var_type == 'learned':
      assert model_output.shape == [B, H, W, C * 2]
      model_output, model_log_variance = tf.split(model_output, 2, axis=-1)
      model_variance = tf.exp(model_log_variance)
    elif self.model_var_type in ['fixedsmall', 'fixedlarge']:
      # below: only log_variance is used in the KL computations
      model_variance, model_log_variance = {
          # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
          'fixedlarge': (self.betas, np.log(np.append(self.posterior_variance[1], self.betas[1:]))),
          'fixedsmall': (self.posterior_variance, self.posterior_log_variance_clipped),
      }[self.model_var_type]
      model_variance = self._extract(model_variance, t, x.shape) * tf.ones(x.shape.as_list())
      model_log_variance = self._extract(model_log_variance, t, x.shape) * tf.ones(x.shape.as_list())
    else:
      raise NotImplementedError(self.model_var_type)

    # Mean parameterization
    _maybe_clip = lambda x_: (tf.clip_by_value(x_, -1., 1.) if clip_denoised else x_)
    if self.model_mean_type == 'xprev':  # the model predicts x_{t-1}
      pred_xstart = _maybe_clip(self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output))
      model_mean = model_output
    elif self.model_mean_type == 'xstart':  # the model predicts x_0
      pred_xstart = _maybe_clip(model_output)
      model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
    elif self.model_mean_type == 'eps':  # the model predicts epsilon
      pred_xstart = _maybe_clip(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
      model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
    else:
      raise NotImplementedError(self.model_mean_type)

    assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
    if return_pred_xstart:
      return model_mean, model_variance, model_log_variance, pred_xstart
    else:
      return model_mean, model_variance, model_log_variance

  # Loss computation
  def training_losses(self, denoise_fn, x_start, t, noise=None):
    assert t.shape == [x_start.shape[0]]
    # Draw random noise
    if noise is None:
      noise = tf.random_normal(shape=x_start.shape, dtype=x_start.dtype)
    assert noise.shape == x_start.shape and noise.dtype == x_start.dtype
    # Add the noise to x_start to get the noisy image at step t
    x_t = self.q_sample(x_start=x_start, t=t, noise=noise)

    # Two loss types are supported, 'kl' and 'mse'; in practice they differ little.
    if self.loss_type == 'kl':  # the variational bound
      losses = self._vb_terms_bpd(
          denoise_fn=denoise_fn, x_start=x_start, x_t=x_t, t=t,
          clip_denoised=False, return_pred_xstart=False)
    elif self.loss_type == 'mse':  # unweighted MSE
      assert self.model_var_type != 'learned'
      target = {
          'xprev': self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
          'xstart': x_start,
          'eps': noise,
      }[self.model_mean_type]
      model_output = denoise_fn(x_t, t)
      assert model_output.shape == target.shape == x_start.shape
      losses = nn.meanflat(tf.squared_difference(target, model_output))
    else:
      raise NotImplementedError(self.loss_type)

    assert losses.shape == t.shape
    return losses

  # Loss under the 'kl' option
  def _vb_terms_bpd(self, denoise_fn, x_start, x_t, t, *, clip_denoised: bool, return_pred_xstart: bool):
    true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
    model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
        denoise_fn, x=x_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
    kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
    kl = nn.meanflat(kl) / np.log(2.)

    decoder_nll = -utils.discretized_gaussian_log_likelihood(
        x_start, means=model_mean, log_scales=0.5 * model_log_variance)
    assert decoder_nll.shape == x_start.shape
    decoder_nll = nn.meanflat(decoder_nll) / np.log(2.)

    # At the first timestep return the decoder NLL,
    # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
    assert kl.shape == decoder_nll.shape == t.shape == [x_start.shape[0]]
    output = tf.where(tf.equal(t, 0), decoder_nll, kl)
    return (output, pred_xstart) if return_pred_xstart else output


# KL divergence between two Gaussians; logvar1 and logvar2 are log-variances
def normal_kl(mean1, logvar1, mean2, logvar2):
  """
  KL divergence between normal distributions parameterized by mean and log-variance.
  """
  return 0.5 * (-1.0 + logvar2 - logvar1 + tf.exp(logvar1 - logvar2)
                + tf.squared_difference(mean1, mean2) * tf.exp(-logvar2))
```
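To see the 'eps'/'mse' path in isolation, here is a framework-free numpy sketch of one training step (a minimal illustration: the schedule and `toy_denoise_fn` are stand-in assumptions, not the repository's network):

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)          # assumed linear schedule
alphas_cumprod = np.cumprod(1.0 - betas)

def toy_denoise_fn(x_t, t):
    # Stand-in for the UNet: any function mapping (x_t, t) -> predicted noise.
    return np.zeros_like(x_t)

def training_loss(x_start, rng):
    B = x_start.shape[0]
    t = rng.integers(0, len(betas), size=B)              # random timestep per sample
    noise = rng.standard_normal(x_start.shape)           # epsilon ~ N(0, I)
    sqrt_ac = np.sqrt(alphas_cumprod[t])[:, None]
    sqrt_omac = np.sqrt(1.0 - alphas_cumprod[t])[:, None]
    x_t = sqrt_ac * x_start + sqrt_omac * noise          # q_sample in one shot
    eps_pred = toy_denoise_fn(x_t, t)
    return np.mean((noise - eps_pred) ** 2, axis=1)      # unweighted MSE per sample

rng = np.random.default_rng(0)
x_start = rng.uniform(-1, 1, size=(4, 32 * 32))          # a toy batch of flattened images
print(training_loss(x_start, rng).shape)                 # (4,)
```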

Next, the sampling (inference) procedure.

```python
def p_sample(self, denoise_fn, *, x, t, noise_fn, clip_denoised=True, return_pred_xstart: bool):
  """
  Sample from the model
  """
  # Use the network to estimate the mean and variance of x_{t-1} from x_t and t
  model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
      denoise_fn, x=x, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
  noise = noise_fn(shape=x.shape, dtype=x.dtype)
  assert noise.shape == x.shape
  # no noise when t == 0
  nonzero_mask = tf.reshape(1 - tf.cast(tf.equal(t, 0), tf.float32),
                            [x.shape[0]] + [1] * (len(x.shape) - 1))
  # For t > 0, Gaussian noise is added to the estimated mean because the loop continues;
  # at t == 0 the loop stops, so no noise is added and the final result is returned.
  sample = model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise
  assert sample.shape == pred_xstart.shape
  return (sample, pred_xstart) if return_pred_xstart else sample

def p_sample_loop(self, denoise_fn, *, shape, noise_fn=tf.random_normal):
  """
  Generate samples
  """
  assert isinstance(shape, (tuple, list))
  # Index of the final step, T - 1
  i_0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32)
  # Draw random noise as p(x_T)
  img_0 = noise_fn(shape=shape, dtype=tf.float32)
  # Loop T times to obtain the final image
  _, img_final = tf.while_loop(
      cond=lambda i_, _: tf.greater_equal(i_, 0),
      body=lambda i_, img_: [
          i_ - 1,
          self.p_sample(denoise_fn=denoise_fn, x=img_, t=tf.fill([shape[0]], i_),
                        noise_fn=noise_fn, return_pred_xstart=False)
      ],
      loop_vars=[i_0, img_0],
      shape_invariants=[i_0.shape, img_0.shape],
      back_prop=False)
  assert img_final.shape == shape
  return img_final
```
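And a matching numpy sketch of the sampling loop (again an illustration under assumptions: `toy_denoise_fn` stands in for the trained UNet, and the 'fixedsmall' posterior variance is used):

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)          # assumed linear schedule
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

def toy_denoise_fn(x_t, t):
    return np.zeros_like(x_t)  # stand-in for the trained epsilon-predictor

def p_sample_loop(shape, rng):
    x = rng.standard_normal(shape)             # x_T ~ N(0, I)
    for t in range(len(betas) - 1, -1, -1):
        eps = toy_denoise_fn(x, t)
        # Posterior mean via Equation (5): mu = (x_t - beta_t/sqrt(1-abar_t)*eps)/sqrt(alpha_t)
        mean = (x - betas[t] / np.sqrt(1.0 - alphas_cumprod[t]) * eps) / np.sqrt(alphas[t])
        noise = rng.standard_normal(shape) if t > 0 else 0.0  # no noise at t == 0
        x = mean + np.sqrt(posterior_variance[t]) * noise
    return x

rng = np.random.default_rng(0)
sample = p_sample_loop((1, 32 * 32), rng)
print(sample.shape)
```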
