经典神经网络(9)VAE模型原理及其在MNIST数据集上的应用
-
图片生成领域来说,有四大主流生成模型:生成对抗模型(GAN)、变分自动编码器(VAE)、流模型(Flow based Model)、扩散模型(Diffusion Model)。
-
VAE 的encoder是学习一个概率分布,所以VAE也可以随机采样生成图片,但VAE图片还原效果很弱,生成的图像模糊,效果不如diffusion model。但VAE可以减少训练和推理时间,降低GPU硬件要求。
-
从2022年开始,主要爆火的图片生成模型是Diffusion Model(扩散模型)为主。在Stable Diffusion中利用VAE将原图512x512x3压缩到64x64x4的潜空间(Latent Space),通过在隐式表征(而不是完整图像)上进行扩散,可以在使用更少的内存的同时,减少UNet层数并加速图片的生成。与此同时,我们仍能把结果输入VAE的解码器,从而解码得到高分辨率图像,隐式表征极大降低了训练和推理成本。
-
VAE损失函数的推导过程中,同样用到了KL散度的概念,可以参考:信息量、熵、KL散度、交叉熵概念理解
1 自编码器(Auto-encoder,AE)
1.1 自编码器概述
-
自编码器是一种无监督的神经网络模型,可以用于数据的降维、特征提取和数据重建等任务。
-
它由编码器和解码器两部分组成(如下图所示):
- 编码器将输入数据压缩成低维特征向量,即编码Code;
- 解码器则将低维特征向量还原成原始数据。
- 在自编码器整个训练过程中,目标是最小化输入数据和重建数据之间的差异,以学习到更加有效的特征表示。
-
最简单的自动编码器是由线性层构成的,叫做线性自编码器(如下图所示)。
- 输出层的神经元数量往往与输入层的神经元数量一致;
- 网络架构往往呈对称性,且中间结构简单、两边结构复杂。
1.2 自编码器存在的问题
- 如下图所示,假设有两张训练图片,一张是全月图,一张是半月图,经过训练我们的自编码器模型已经能无损地还原这两张图片。
- 接下来,我们在code空间上,两张图片的编码点中间处取一点,然后将这一点交给解码器,我们希望新的生成图片是一张清晰的图片(类似3/4全月的样子)。
- 但是,实际的结果是,生成图片是模糊且无法辨认的乱码图。
- 原因是:基本自编码器给定一张图片生成原始图片,从输入到输出都是确定的,没有任何随机的成分,为了使模型表现很好,在不断的迭代训练中,编码器的输出也就是解码器的输入会趋于确定,这样才能让解码器能生成与输入数据更接近的数据,以使损失变得更小。但是这就与生成器的初衷有悖了:
- 生成器的初衷实际上是为了生成更多
全新
的数据,而不是为了生成与输入数据更像
的数据。
- 生成器的初衷实际上是为了生成更多
1.3 AE模型在MNIST数据集上的应用
- 代码如下所示,ae_original_image.png是原始数据
- ae_image_encoder.png是经过编码器,解码器后得到的图片
- ae_image.png是将三个编码器得到的编码值进行平均得到的图片,可以看到是模糊且无法辨认的乱码图,这就是自编码器存在的问题。
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_imageclass AE(nn.Module):def __init__(self):super().__init__()# 编码器self.encoder = nn.Sequential(nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 128),nn.ReLU(),nn.Linear(128, 10))# 解码器self.decoder = nn.Sequential(nn.Linear(10, 128),nn.ReLU(),nn.Linear(128, 256),nn.ReLU(),nn.Linear(256, 784),)def forward(self, x):x = x.view(-1, 28 * 28) # [b,1,28,28] ———> [b,784]x = self.encoder(x) # [b,784] ———> [b,10]x = self.decoder(x) # [b,10] ———> [b,784]return xdef train(model, loss_fn, opt, epoch=200):for epoch in range(epoch):model.train()total_loss = 0.0for x, _ in trian_dl:x = x.to(device)y_pre = model(x)loss = loss_fn(y_pre, x.reshape(-1, 784))opt.zero_grad()loss.backward()opt.step()total_loss += lossbreakprint(f'epoch = {epoch + 1}, train loss = {total_loss / len(trian_dl): .4f}')torch.save(model.state_dict(), 'ae_model.pth')if __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 1、读取数据集trian_dl = torch.utils.data.DataLoader(datasets.MNIST("/root/autodl-fs/data/minist",train=True,download=False,transform=transforms.Compose([transforms.ToTensor(),]),),batch_size=256,shuffle=True,num_workers=8)# 2、创建AE模型及优化器model = AE().to(device)opt = torch.optim.Adam(model.parameters(), lr=0.001)loss_fn = nn.MSELoss(reduction='mean')# 4、模型训练train(model, loss_fn, opt, epoch=100)# 5、模型推理# 加载模型model.load_state_dict(torch.load('ae_model.pth', map_location=device))bs = 3text_dl = torch.utils.data.DataLoader(datasets.MNIST("/root/autodl-fs/data/minist",train=False,download=False,transform=transforms.Compose([transforms.ToTensor(),]),),batch_size=bs,shuffle=False)# 获取编码for x, y in text_dl:model.eval()save_image(x, "ae_original_image.png")x = x.to(device)sample_encoder = model.encoder(x.reshape(-1, 784))print(sample_encoder)image_encoder = model.decoder(sample_encoder).reshape(-1, 1, 28, 28)save_image(image_encoder, "ae_image_encoder.png")sample_sum = torch.sum(sample_encoder, dim=0) / 2sample = sample_sum.tile(bs, 1)image = model.decoder(sample).reshape(-1, 1, 28, 28)save_image(image, "ae_image.png")break
2 变分自编码器(VAE)
2.1 如何解决自编码器的缺点
- 我们现在已经知道,自编码器生成的图片是模糊且无法辨认的乱码图。如何解决这个问题呢?
- 如下图所示,现在在给两张图片编码的时候加上一点噪音,使得每张图片的编码点出现在绿色箭头所示范围内。
- 在训练模型的时候,绿色箭头范围内的点都有可能被采样到,这样解码器在训练时会把绿色范围内的点都尽可能还原成和原图相似的图片。
- 然后我们可以关注之前那个失真点,现在它处于全月图和半月图编码的交界上,于是解码器希望它既要尽量相似于全月图,又要尽量相似于半月图,于是它的还原结果就是两种图的折中(3/4全月图)。
- 由此我们发现,给编码器增添一些噪音,可以有效覆盖失真区域。
- 不过这还并不充分,因为在上图的距离训练区域很远的黄色点处,它依然不会被覆盖到,仍是个失真点
- 为了解决这个问题,我们可以试图把噪音无限拉长,使得对于每一个样本,它的编码会覆盖整个编码空间,不过我们得保证,在原编码附近编码的概率最高,离原编码点越远,编码概率越低。
- 在这种情况下,图像的编码就由原先离散的编码点变成了一条连续的编码分布曲线,如下图所示。
这种将图像编码由离散变为连续的方法,就是变分自编码的核心思想。
2.2 VAE的整体架构概述
- 如下图所示,与普通自动编码器一样,VAE有编码器Encoder与解码器Decoder两大部分组成,原始图像从编码器输入,经编码器后形成
隐式表示(Latent Representation)
,之后隐式表示被输入到解码器、再复原回原始输入的结构。 - 然而,与普通AE不同的是,我们不会直接将Encoder编码后的结果传递给Decoder,而是要使得隐式表示满足既定分布(例如:正态分布)。
- 变分自动编码器的Encoder在输出时,并不会直接输出原始数据的隐式表示,而是会输出从原始数据提炼出的均值 μ 和标准差 σ 。
- 之后,我们需要建立
均值为 μ 、标准差为 σ 的正态分布,并从该正态分布中抽样出隐式表示z
,再将隐式表示z输入到Decoder中进行解码。 - 对隐式表示z而言,它传递给Decoder的就不是原始数据的信息,而只是与原始数据同均值、同标准差的分布中的信息了。
2.3 VAE的前向过程
2.3.1 前向过程(一个样本产生一个正态分布)
-
在auto-encoder中,编码器是直接产生一个编码的,但是在VAE中,为了给编码添加合适的噪音,编码器会输出两个编码:
-
下图有三个样本,因此存在三组均值和标准差,一组均值和标准差只能生成一个正态分布,而一个正态分布中只能抽选一个数字,这是变分自编码器抽样的基本规则。下图中,每个样本经过Encoder后只输出了一组均值和标准差,那z自然只能有一列,隐式空间 ( c 1 , c 2 , c 3 ) (c_1,c_2,c_3) (c1,c2,c3)的结构为(3,1)。
-
一个是原有编码 ( m 1 , m 2 , m 3 ) (m_1,m_2,m_3) (m1,m2,m3)(均值),
注意这里是一个样本会输出一个均值
; -
一个是控制噪音干扰程度的编码 ( σ 1 、 σ 2 、 σ 3 ) (\sigma_1、\sigma_2、\sigma_3) (σ1、σ2、σ3)(
注意这里是方差,下面推导过程用的是标准差
),第二个编码其实很好理解,就是为随机噪音 ( e 1 , e 2 , e 3 ) (e_1,e_2,e_3) (e1,e2,e3)分配权重。加上exp的目的是为了保证这个分配的权重是个正值,最后将原编码与噪音编码相加,就得到了VAE在code层的输出结果,即 c i = e σ i ∗ e i + m i c_i=e^{\sigma_i}*e_i+m_i ci=eσi∗ei+mi。 -
得到了VAE在code层的输出结果 c c c后,进入解码器Decoder中,最后得到output。
-
- 由上图可知,在损失函数方面,除了必要的重构损失外(让输出和输入相近),VAE还另外增添了一个损失函数(
上图下方的L2损失函数
),我们后面会详细推导此损失函数。
2.3.2 前向过程(一个样本产生多个正态分布)
- 在变分自动编码器的流程当中,均值和标准差是通过第一个神经网络Encoder训练出来的。我们不可能知道当前样本服从的真实分布的状态,因此这一推断过程自然可以根据不同的规则(Encoder中不同的权重)得出不同的结果。
- 如下图所示,我们可以令Encoder的输出层存在3个神经元,这样Encoder就会对每一个样本推断出三对不同的均值和标准差。
这个行为相当于对样本数据所属的原始分布进行估计,但给出了三个可能的答案
。因此现在,在每个样本下,我们就可以基于三个均值和标准差的组合生成三个不同的正态分布了。 - 隐式空间越大,隐式表示z所携带的信息自然也会越多,自动编码器的表现就可能变得更好,因此在实际使用变分自动编码器的过程中,
一个样本上至少都会生成10~100组均值和标准差,隐式表示z的结构一般也是较高维的矩阵
。
2.4 VAE损失函数的推导
2.4.1 从高斯混合模型到VAE
- VAE的理论基础是高斯混合模型,即任何一个数据的分布,都可以看作是若干高斯分布的叠加。
- 如图所示,如果P(X)代表一种分布的话,存在一种拆分方法能让它表示成图中若干浅蓝色曲线对应的高斯分布的叠加。
- 如下图所示,我们将编码换成一个连续变量z,为了计算方便,我们规定z服从标准正态分布(实际上并不一定要选用)。正如2.3前向过程所示,对于每一个采样点z,会有两个函数 u u u和 σ \sigma σ,分别决定z对应到的高斯分布的均值和方差,然后在积分域上所有的高斯分布的累加就成为了原始分布P(X)。
- 我们使用 p p p代表解码器;
- p ( x ∣ z ) p(x|z) p(x∣z)代表给定z时解码器产生 x x x的概率;
- x x x并非一个具体的值,而可以看作是一类数据,比如 x x x可以代表某种风格的手写体数字, p ( x ∣ z ) p(x|z) p(x∣z)就是生成这些数字的概率,这里的概率也并非一个具体的值,而是某一风格的每个数字对应了一个概率,其输出的是概率分布。
- 那么VAE的优化目标是什么呢?其实就是
最大化解码器输出x的概率
,即最大化 p ( x ) p(x) p(x)。
2.4.2 损失函数的推导
我们现在的优化目标就是:最大化编码器输出x的概率
p ( x ) = ∫ z p ( x ∣ z ) p ( z ) d z p(x)=\int_zp(x|z)p(z)dz p(x)=∫zp(x∣z)p(z)dz
- 注意:这里的 p ( z ) p(z) p(z)可以是任意分布,在VAE中我们常常假设 p ( z ) p(z) p(z)服从标准正态分布
为了最大化 p ( x ) p(x) p(x),我们可以采用极大似然估计的方法来进行:
L = ∑ x l o g p ( x ) L=\sum_xlogp(x) L=x∑logp(x)
- 这里的每个x可以理解为代表了某一个风格的手写体,我们的目标是生成手写体数字,因此我们并不会局限其风格。
由于最大化L相当于最大化 l o g p ( x ) logp(x) logp(x),因此后续目标调整为最大化 l o g p ( x ) logp(x) logp(x)。我们假设q代表了编码器, q ( z ∣ x ) q(z|x) q(z∣x)就代表了给定x时编码器产生z的概率
给定任意 x ,其产生不同 z 的概率之和为 1 ,因此: ∫ z q ( z ∣ x ) = 1 而 p ( x ) 和 z 无关,那么 : l o g p ( x ) = ∫ z q ( z ∣ x ) l o g p ( x ) d z ( 公式一 ) 依据联合概率公式 p ( x ) = p ( x , z ) p ( z ∣ x ) = p ( x , z ) q ( z ∣ x ) q ( z ∣ x ) p ( z ∣ x ) 代入 ( 公式一 ) : l o g p ( x ) = ∫ z q ( z ∣ x ) l o g ( p ( x , z ) q ( z ∣ x ) q ( z ∣ x ) p ( z ∣ x ) ) d z 我们将 l o g 里的乘积拆开,变为两项之和 = ∫ z q ( z ∣ x ) l o g ( p ( x , z ) q ( z ∣ x ) ) d z + ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) p ( z ∣ x ) ) d z 给定任意x,其产生不同z的概率之和为1,因此:\\ \int_zq(z|x)=1\\ 而p(x)和z无关,那么:\\ logp(x)=\int_zq(z|x)logp(x)dz(公式一) \\ 依据联合概率公式p(x)=\frac{p(x,z)}{p(z|x)}=\frac{p(x,z)}{q(z|x)}\frac{q(z|x)}{p(z|x)}\\ 代入(公式一):logp(x)=\int_zq(z|x)log(\frac{p(x,z)}{q(z|x)}\frac{q(z|x)}{p(z|x)})dz\\ 我们将log里的乘积拆开,变为两项之和\\ =\int_zq(z|x)log(\frac{p(x,z)}{q(z|x)})dz+\int_zq(z|x)log(\frac{q(z|x)}{p(z|x)})dz 给定任意x,其产生不同z的概率之和为1,因此:∫zq(z∣x)=1而p(x)和z无关,那么:logp(x)=∫zq(z∣x)logp(x)dz(公式一)依据联合概率公式p(x)=p(z∣x)p(x,z)=q(z∣x)p(x,z)p(z∣x)q(z∣x)代入(公式一):logp(x)=∫zq(z∣x)log(q(z∣x)p(x,z)p(z∣x)q(z∣x))dz我们将log里的乘积拆开,变为两项之和=∫zq(z∣x)log(q(z∣x)p(x,z))dz+∫zq(z∣x)log(p(z∣x)q(z∣x))dz
结合KL散度公式,我们可以看出第二项其实就是 K L ( q ( z ∣ x ) ∣ ∣ p ( z ∣ x ) ) KL(q(z|x) || p(z|x)) KL(q(z∣x)∣∣p(z∣x))。因为该值为非负项,所以 l o g p ( x ) logp(x) logp(x)不可能小于第一项,我们使用 L b Lb Lb来代表第一项。
l o g p ( x ) = L b + ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) p ( z ∣ x ) ) d z = L b + K L ( q ( z ∣ x ) ∣ ∣ p ( z ∣ x ) ) 结合式式子 : p ( x ) = ∫ z p ( x ∣ z ) p ( z ) d z 分析如下: 当 p ( x ∣ z ) 不变时, p ( x ) 也不变,从而 l o g p ( x ) 也不变,那么 L b + K L ( q ( z ∣ x ) ∣ ∣ p ( z ∣ x ) ) 的值就不会变。 这时如果我们利用 q ( z ∣ x ) 来最大化 L b ,那么 L b 就会增大,而 K L ( q ( z ∣ x ) ∣ ∣ p ( z ∣ x ) ) 的值就会减小。 当我们调节到 q ( z ∣ x ) 与 p ( z ∣ x ) 完全相同时, K L 散度就为 0 , L b 和 l o g p ( x ) 完全一致 logp(x)=Lb+\int_zq(z|x)log(\frac{q(z|x)}{p(z|x)})dz\\ =Lb+KL(q(z|x) || p(z|x))\\ 结合式式子:p(x)=\int_zp(x|z)p(z)dz\\ 分析如下:\\ 当p(x|z)不变时,p(x)也不变,从而log p(x)也不变,那么Lb+KL(q(z|x) || p(z|x))的值就不会变。\\ 这时如果我们利用q(z|x)来最大化Lb,那么Lb就会增大,而KL(q(z|x) || p(z|x))的值就会减小。\\ 当我们调节到q(z|x)与p(z|x)完全相同时,KL散度就为0,Lb和logp(x)完全一致\\ logp(x)=Lb+∫zq(z∣x)log(p(z∣x)q(z∣x))dz=Lb+KL(q(z∣x)∣∣p(z∣x))结合式式子:p(x)=∫zp(x∣z)p(z)dz分析如下:当p(x∣z)不变时,p(x)也不变,从而logp(x)也不变,那么Lb+KL(q(z∣x)∣∣p(z∣x))的值就不会变。这时如果我们利用q(z∣x)来最大化Lb,那么Lb就会增大,而KL(q(z∣x)∣∣p(z∣x))的值就会减小。当我们调节到q(z∣x)与p(z∣x)完全相同时,KL散度就为0,Lb和logp(x)完全一致
那么如果 q ( z ∣ x ) 不变呢?此时当我们增大 p ( x ∣ z ) 时, L b 会增大且 p ( x ) 会增大,即 l o g p ( x ) 也会增大。 由此我们可以得出结论,只要我们最大化 L b 就能使 l o g p ( x ) 最大化。 那么如果q(z|x)不变呢?此时当我们增大p(x|z)时,Lb会增大且p(x)会增大,即log p(x)也会增大。\\ 由此我们可以得出结论,只要我们最大化Lb就能使log p(x)最大化。 那么如果q(z∣x)不变呢?此时当我们增大p(x∣z)时,Lb会增大且p(x)会增大,即logp(x)也会增大。由此我们可以得出结论,只要我们最大化Lb就能使logp(x)最大化。
此时我们的优化目标就变成了最大化 L b Lb Lb
L b = ∫ z q ( z ∣ x ) l o g ( p ( x , z ) q ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g ( p ( x ∣ z ) p ( z ) q ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g p ( z ) q ( z ∣ x ) d z + ∫ z q ( z ∣ x ) l o g p ( x ∣ z ) d z = − K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) + ∫ z q ( z ∣ x ) l o g p ( x ∣ z ) d z = − K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) + E q ( z ∣ x ) l o g p ( x ∣ z ) 我们加负号,转化为最小化问题: L o s s = K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) − E q ( z ∣ x ) l o g p ( x ∣ z ) Lb=\int_zq(z|x)log(\frac{p(x,z)}{q(z|x)})dz\\ =\int_zq(z|x)log(\frac{p(x|z)p(z)}{q(z|x)})dz\\ =\int_zq(z|x)log\frac{p(z)}{q(z|x)}dz+\int_zq(z|x)logp(x|z)dz\\ =-KL(q(z|x)||p(z)) + \int_zq(z|x)logp(x|z)dz\\ =-KL(q(z|x)||p(z)) + E_{q(z|x)}logp(x|z)\\ 我们加负号,转化为最小化问题:\\ Loss=KL(q(z|x)||p(z)) - E_{q(z|x)}logp(x|z) Lb=∫zq(z∣x)log(q(z∣x)p(x,z))dz=∫zq(z∣x)log(q(z∣x)p(x∣z)p(z))dz=∫zq(z∣x)logq(z∣x)p(z)dz+∫zq(z∣x)logp(x∣z)dz=−KL(q(z∣x)∣∣p(z))+∫zq(z∣x)logp(x∣z)dz=−KL(q(z∣x)∣∣p(z))+Eq(z∣x)logp(x∣z)我们加负号,转化为最小化问题:Loss=KL(q(z∣x)∣∣p(z))−Eq(z∣x)logp(x∣z)
- 此时VAE的最终目标就一目了然了,VAE的训练目标有两个:
- 第一,最大化在 q ( z ∣ x ) q(z|x) q(z∣x)这个分布下 l o g p ( x ∣ z ) logp(x|z) logp(x∣z)的期望(L1损失),其中 q ( z ∣ x ) q(z|x) q(z∣x)为编码器输入 x x x时产生 z z z的概率。假设解码器利用 z z z生成出了 x ’ x’ x’,我们就需要使 x ’ x’ x’尽可能向 x x x靠近,以最大化 l o g p ( x ∣ z ) logp(x|z) logp(x∣z)。
- 第二,最小化 K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) KL(q(z|x)||p(z)) KL(q(z∣x)∣∣p(z))(L2损失),使 q ( z ∣ x ) q(z|x) q(z∣x)的分布尽量向 p ( z ) p(z) p(z)靠近;
根据上述的两个训练目标,VAE的损失函数也被设计为两个:
-
L1损失目的是输出的 x ’ x’ x’尽可能向原始 x x x靠近(即重构损失),我们可以最小化x’和x之间的MSE Loss或者BCE Loss。
-
L2用于最小化 K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) KL(q(z|x)||p(z)) KL(q(z∣x)∣∣p(z)),VAE假设 q ( z ∣ x ) q(z|x) q(z∣x)的分布为正态分布,而 p ( z ) p(z) p(z)为标准正态分布。计算两个正态分布之间的KL散度的公式如下(这里就不推导了,直接给出):
由于此处 p ( z ) p(z) p(z)为标准正态分布,因此其 μ 2 μ_2 μ2为0, σ 2 σ_2 σ2为1,那么我们带入后可得
K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) = − 1 2 ( l o g σ 2 − σ 2 − u 2 + 1 ) 其中 σ 为 q ( z ∣ x ) 的标准差, μ 为 q ( z ∣ x ) 的均值 KL(q(z|x)||p(z))=-\frac{1}{2}(log\sigma^2-\sigma^2-u^2+1) \\ 其中σ为q(z|x)的标准差,μ为q(z|x)的均值 KL(q(z∣x)∣∣p(z))=−21(logσ2−σ2−u2+1)其中σ为q(z∣x)的标准差,μ为q(z∣x)的均值
2.4.3 VAE的重参数化
蓝色节点为采样节点,左侧由于采样的存在无法对 μ μ μ和 σ σ σ求导,反向传播无法进行,而右侧由于采用的重参数化技巧,灰色节点全部打通,使得网络能够正常进行反向传播。
2.4.4 VAE和GAN的区别、VAE的本质
可以参考:变分自编码器(一):原来是这么一回事 - 科学空间
2.5 VAE在MNIST数据集上的应用
- 解码器输出的不是方差 σ 2 σ^2 σ2 ,而是对数方差 l o g σ 2 logσ^2 logσ2,详见下面代码的encoder函数,这么做的原因就是,神经网络的输出是 [−∞,+∞] 的任意数值,但是方差不可能为负数,所以对方差取对数以满足神经网络输出值域的要求。
- 重参数化详见reparameter函数。
- 训练时候不要忘了损失函数的KL散度部分,利用上面推导的公式计算。这里由于模型的输出是对数方差 l o g σ 2 logσ^2 logσ2而不是方差,所以原始的计算公式需要做一个转换,如下:
K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) = − 1 2 ( l o g σ 2 − σ 2 − u 2 + 1 ) 【其中 σ 为 q ( z ∣ x ) 的标准差, μ 为 q ( z ∣ x ) 的均值】 = − 1 2 ( l o g _ v a r − e l o g _ v a r − m u 2 + 1 ) KL(q(z|x)||p(z))\\ =-\frac{1}{2}(log\sigma^2-\sigma^2-u^2+1)【其中σ为q(z|x)的标准差,μ为q(z|x)的均值】 \\ =-\frac{1}{2}(log\_var-e^{log\_var}-mu^2+1) KL(q(z∣x)∣∣p(z))=−21(logσ2−σ2−u2+1)【其中σ为q(z∣x)的标准差,μ为q(z∣x)的均值】=−21(log_var−elog_var−mu2+1)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_imageclass VAE(nn.Module):def __init__(self):super().__init__()# encoderself.encoder_layer = nn.Sequential(nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 128),)self.fc1 = nn.Linear(128, 10) # 均值self.fc2 = nn.Linear(128, 10) # log方差# decoderself.decoder = nn.Sequential(nn.Linear(10, 128),nn.ReLU(),nn.Linear(128, 256),nn.ReLU(),nn.Linear(256, 784),)def encoder(self, x):x = F.relu(self.encoder_layer(x)) # [b,784] ———> [b,128]mu = self.fc1(x) # [b,128] ———> [b,10]log_var = self.fc2(x) # [b,128] ———> [b,10]return mu, log_vardef reparameter(self, mu, log_var):device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 重参数化技巧 z = mu(均值) + eps(0-1正态分布) * sigma(方差)std = torch.sqrt(torch.exp(log_var))eps = torch.rand_like(std)z = mu + eps * std # [b, 10]return zdef forward(self, x):x = x.view(-1, 28 * 28) # [b,1,28,28] ———> [b,784]mu, log_var = self.encoder(x) # [b,784] ———> [b,10] mu sigmaz = self.reparameter(mu, log_var) # [b,10] ———> [b,10]x = self.decoder(z) # [b,10] ———> [b,784]x_hat = x.reshape(-1, 1, 28, 28)return x_hat, mu, log_vardef train(model, epoch=200):for epoch in range(epoch):model.train()total_loss = 0.0total_bce_loss = 0.0total_kl_loss = 0.0for x, y in trian_dl:x, y = x.to(device), y.to(device)x_hat, mu, log_var = model(x)# 3-1、bce_lossbce_loss = F.binary_cross_entropy(torch.sigmoid(x_hat.view(-1, 784)), x.view(-1, 784), reduction='sum')# 3-2、kl_losskl_loss = torch.sum(-0.5 * (log_var - torch.exp(log_var) - mu ** 2 + 1))loss = bce_loss + kl_lossopt.zero_grad()loss.backward()opt.step()total_loss += losstotal_bce_loss += bce_losstotal_kl_loss += kl_lossprint(f'epoch = {epoch + 1},bce_loss = {(total_bce_loss / len(trian_dl)):.4f}, kl_loss = {(total_kl_loss / len(trian_dl)):.4f}, train loss = {total_loss / len(trian_dl):.4f}')torch.save(model.state_dict(), 'vae_model.pth')if __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 1、读取数据集trian_dl = torch.utils.data.DataLoader(datasets.MNIST("/root/autodl-fs/data/minist",train=True,download=False,transform=transforms.Compose([transforms.ToTensor(),]),),batch_size=256,shuffle=True,num_workers=12)# 2、创建VAE模型及优化器model = VAE().to(device)opt = torch.optim.Adam(model.parameters(), lr=0.001)# 4、模型训练train(model, epoch=300)# 5、模型推理# 加载模型model.load_state_dict(torch.load('vae_model.pth', map_location=device))model.eval()z = torch.randn(3, 10) # 生成一个形状为 (3, 10) 的随机数张量image = model.decoder(z).reshape(-1, 1, 28, 28)save_image(image, "vae_image_random.png")
参考:
图片来自李宏毅老师的教程视频:https://www.bilibili.com/video/av15889450/?p=33
https://kexue.fm/archives/5253
https://blog.csdn.net/weixin_42491648/article/details/132384913
http://www.gwylab.com/note-vae.html