写在前面
VAE(Variational Autoencoder),中文译为变分自编码器。其中AE(Autoencoder)很好理解。那“变分”指的是什么呢?—其实是“变分推断”。变分推断主要用在VAE的损失函数中,那变分推断是什么,VAE的损失函数又是什么呢?下面我就来说一说!
可以先看一下 这篇文章,介绍了VAE的代码实现。
一、通俗理解损失函数
这篇文章已经整体介绍了VAE,这里我详细介绍一下VAE的损失函数:
每个变量的说明下面会有介绍,现在我们只关注VAE的损失函数有由两部分组成,第一部分是一个交叉熵,我们称之为“重构项”,其作用是确保训练时输入和输出间的相似性;第二部分是KL散度,叫做“KL散度项”,它其实是一个正则项,主要解决了两个AE模型的痛点,这也是VAE成功并流行的主要原因:
1.潜在空间的结构化:AE的潜在空间往往是无规则的,这意味着编码器学到的表征可能杂乱无章,不便于后续操作。VAE通过添加KL散度项来惩罚潜在变量分布与预设先验分布(就是p(z),是一个标准高斯分布)之间的偏差,从而迫使潜在空间呈现出一定的结构,使潜在变量的分布更加合理和连贯。说人话就是:VAE可以输入标准高斯分布的采样数据,生成精美的图像。
2.潜在空间的连续性:KL散度项要求潜在变量 z 的分布 q(z|x) 尽可能接近预设的先验分布 p(z) ,这个先验分布通常选择为标准正态分布。通过这种方式,潜在空间被组织成一个连续、平滑的多维空间,其中每一维上的值都能够自由变动而不产生剧烈变化。这种设计确保了在潜在空间中的小步长移动会导致解码结果的轻微变化,从而实现了连续性。说人话就是:VAE可以通过微调输入的采样数据,一定程度上修改生成图像的属性。这也是造成“抽卡”的原因之一。
损失函数的这两项可以简单的这么理解,但是它其实是推导出来的,这就说来话长。感兴趣的小伙伴继续往下看。
二、边际似然
1.边际似然的定义
VAE 是一种生成模型,生成模型的核心任务是计算在给定潜在变量 z 的情况下生成观测数据 x 的概率。我们希望模型能够生成与真实数据分布相似的新数据,这一目标可以通过边际似然 p(x) 来实现。
其中z就是Latent;x是训练用的图像;p(x)是边际似然,也就是VAE的损失函数。
p(x)可以很好的衡量模型的生成能力。p(x)直接衡量了模型在生成数据方面的整体能力,因为它考虑了所有潜在的隐变量 z 对观测数据 x 的影响。高的p(x)意味着模型可以很好地解释数据,并且在生成新数据时表现出较强的能力。
具体来说,如果模型的边际似然高,说明模型在所有可能的隐变量 z 下生成观测数据的概率累加起来后非常高,这意味着模型学到了数据的真实分布。
边际似然 p(x) 表示给定模型情况下生成观测数据 x 的概率,定义为:
(1)
其中,条件概率 p(x∣z):给定潜在变量 z 的情况下,生成观测数据 x 的概率。先验分布 p(z):潜在变量 z 的分布,反映了我们对 z 的先验知识。
2.边际似然的推导
使用全概率公式,边际似然可以用全概率公式来定义,具体为:
(2)
这里 p(x,z)是 x 和 z 的联合分布。根据条件概率的定义,联合分布可以表示为:
(3)
因此,我们可以将边际似然表示为:
(4)
我们要做的就是最大化p(x),这里多说一句,最大化p(x)的目标是使得模型生成的总体概率分布 p(x) 更接近于真实数据分布。这样,模型生成的新样本就会与训练数据的分布一致。
直观理解:假设我们在训练一个模型生成手写数字图片。如果真实的数据集中 80% 是“1”,20% 是“2”,那么一个好的生成模型应该能够生成 80% 的“1”和 20% 的“2”。而不是让p(x)趋近于1.
3.边际似然的挑战
但是计算边际似然通常是一个复杂且困难的任务,原因包括:
(1)高维积分:在实际的应用中,潜在变量 z 通常是高维的。例如,如果 z 是 100 维的向量,那么积分就需要在 100 维的空间上进行。这种高维积分是非常复杂的,解析解几乎不可能得到。
(2)分布形式复杂:在生成模型中,条件分布 p(x∣z)和先验分布 p(z) 可能并不是简单的概率分布。例如,p(x∣z) 可能由一个深度神经网络参数化,计算时需要经过非线性激活函数和复杂的网络结构,这会让这个积分无法直接求解。
(3)数值计算的困难:计算边际似然时,需要对 z 的所有可能值进行积分,也就是计算出在所有潜在表示 z 上,生成数据 x 的所有可能性。现实中,z 的范围非常大,即使是连续的,也可能取值无穷多个,直接求解所有 z 的可能性几乎是不可能的。
举个例子,假设我们有一个简单的生成模型,其中:p(z) 是标准正态分布N(0,I)。p(x∣z) 是由一个深度神经网络生成的图像。直接计算边际似然意味着我们需要知道所有 z 的取值如何影响 x。如果 z 是 100 维向量,那么在 空间上对 z 进行积分(或采样)需要极大的计算资源。神经网络的非线性使得每个 p(x∣z) 的计算都很复杂,最终让直接计算积分变得不可行。
为了解决上面的问题,让模型可以正常训练,我们引入变分推断。
三、变分推断
1.变分推断的定义
变分推断是一种通过引入近似分布来解决无法直接计算复杂积分的问题的方法。在生成模型中,我们的目标是最大化观测数据的边际似然 p(x):
(5)
如前所述,这个积分通常很难直接计算,因此我们引入一个 近似后验分布(也叫变分分布,就是训练时模型的输出) q(z∣x),来代替无法直接求解的真实后验 p(z∣x)。变分推断的目标是让 q(z∣x) 尽可能地接近真实的 p(z∣x)。
(6)
通过这种重写,我们引入了 q(z∣x) 作为一个权重,这样我们可以在期望的形式下进行优化。我们现在有一个可以计算的表达式:
(7)
尽管重写了表达式,计算 p(x)依然困难,因为积分本身依然难解。因此,我们应用 Jensen 不等式(log是凸函数),将对数操作从积分外移到期望内部(这里的期望是由积分转化来的):
(8)
其中,Eq(z∣x)[⋅]表示在 q(z∣x) 分布下对 z 取期望。这一不等式说明,我们得到了一个对数边际似然的下界,即变分下界 (ELBO)。
2.变分下界ELBO
式子(8)右边的表达式即为变分下界(Evidence Lower Bound,),通常记作 ELBO,至此我们的目标也变成了最大化ELBO,从而间接地最大化边际似然 p(x)。式子(8)可以写成:
(9)
式子(9)右边可以展开成:
(10)
因为KL散度公式:
(11)
可以看到,式子(10)右边的第二项和第三项可以用KL散度代替:
(12)
最终,ELBO 可以写成如下式子,这也是VAE需要优化的损失函数:
(13)
ELBO 公式展示了两个部分:
重构项:表示模型生成数据的能力。
KL 散度项:作为正则化项,控制 q(z∣x) 和 p(z) 之间的差异。最小化这个项有助于使近似后验 q(z∣x) 尽量接近先验 p(z),从而促进模型的泛化能力。p(z)一般被设置成标准高斯分布。
最大化 ELBO 的意义:
优化目标:最大化 ELBO 实际上是希望在重构能力和潜在分布的正则化之间取得平衡。通过调整这两个部分,可以确保模型既能够良好地重构输入数据,又能够学习到有意义的潜在空间。
间接最大化边际似然:由于 ELBO 是边际似然的下界,最大化 ELBO 也会使得边际似然 p(x) 的值增加。
ELBO 在 VAE 中扮演着至关重要的角色,它将生成模型的目标与优化过程结合起来,使得模型能够在重构能力和潜在空间的正则化之间找到最佳平衡。通过最大化 ELBO,VAE 能够学习到有效的潜在表示,从而生成新样本。
四、代码实现中的公式
这篇文章介绍了VAE的代码实现,其中的损失函数是ELBO的具体实现,我们来看一下,具体是怎么实现的。
我们的目标是最大化ELBO,相当于最小化其负值,因此 VAE 的损失函数可以表示为:
(14)
1.重构项
交叉熵的定义为:
(15)
如果我们将 p(x∣z) 视为模型生成 x 的概率分布(对应代码中的recon_x,即模型的输出),而将真实数据的分布视为 q(x)(对应代码中的x,即GT),则ELBO的第一项可以写成:
(16)
最大化 ELBO 的第一项(重构项)实际上是最小化交叉熵损失,代码如下:
BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
2.KL 散度项
对于高斯分布 和标准正态分布,我们可以将 KL散度计算分解为以下几个步骤:
(1)KL散度 的公式为:
(17)
解释一下变量的意义:
q(z∣x):这是给定输入 x时隐变量 z 的后验分布,通常由编码器生成。
p(z):这是隐变量 z 的先验分布,通常是标准高斯分布 N(0,1)。
比率 :这个比率表示后验分布与先验分布的相对关系,反映了后验分布相较于先验分布的“信息量”。
对数项:量化了 q(z∣x) 相较于 p(z) 的信息增益。正值表示后验分布相对于先验分布的增加的信息,而负值则表示信息的损失。
积分:通过对所有可能的 z进行积分,KL散度 计算了整个后验分布与先验分布之间的差异。
(2)将和带入(17):
(18)
(3)高斯分布的公式: 高斯分布的概率密度函数为:
(19)
而标准正态分布为:
(20)
(4)计算 KL散度: 将这些代入 K散度的公式中,最终可以简化得到:
(21)
(5)简化: 进一步简化后,得到:
(22)
(6)用对数方差表示: 在实现中,通常使用对数方差 来计算,这样可以避免数值稳定性问题,最终得到的 KL散度公式是:
(23)
KL散度代码实现:在代码实现的时候编码器的输出其实是均值mu和对数方差log_var:
KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
其中log_var
是对数方差,使用对数方差的形式可以保证数值稳定性、避免负值以及计算便利性,这种做法在许多深度学习模型中都得到了广泛应用,尤其是在处理概率分布时。;mu
是均值;
五、总结
1.VAE中的“变分”指的是“变分推断”;
2.VAE的损失函数值最大化边际似然;
3.最大化边际似然几乎做不到,所以使用变分推断来简化计算;
4.使用变分推断后,训练通过最大化ELBO实现;
5.ELBO有两项:重构项和KL散度项。重构项的作用是确保训练时输入和输出间的相似性,就是传统的损失函数常用的东西;KL散度项是一个正则项,能确保潜在空间的结构化和连续性。
VAE就介绍到这,关注不迷路(*^__^*)