经典神经网络(9)VAE模型原理及其在MNIST数据集上的应用

经典神经网络(9)VAE模型原理及其在MNIST数据集上的应用

  • 图片生成领域来说,有四大主流生成模型:生成对抗模型(GAN)、变分自动编码器(VAE)、流模型(Flow based Model)、扩散模型(Diffusion Model)。

  • VAE 的encoder是学习一个概率分布,所以VAE也可以随机采样生成图片,但VAE图片还原效果很弱,生成的图像模糊,效果不如diffusion model。但VAE可以减少训练和推理时间,降低GPU硬件要求。

  • 从2022年开始,主要爆火的图片生成模型是Diffusion Model(扩散模型)为主。在Stable Diffusion中利用VAE将原图512x512x3压缩到64x64x4的潜空间(Latent Space),通过在隐式表征(而不是完整图像)上进行扩散,可以在使用更少的内存的同时,减少UNet层数并加速图片的生成。与此同时,我们仍能把结果输入VAE的解码器,从而解码得到高分辨率图像,隐式表征极大降低了训练和推理成本。

    在这里插入图片描述

  • VAE损失函数的推导过程中,同样用到了KL散度的概念,可以参考:信息量、熵、KL散度、交叉熵概念理解

1 自编码器(Auto-encoder,AE)

1.1 自编码器概述

  • 自编码器是一种无监督的神经网络模型,可以用于数据的降维、特征提取和数据重建等任务。

  • 它由编码器和解码器两部分组成(如下图所示):

    • 编码器将输入数据压缩成低维特征向量,即编码Code;
    • 解码器则将低维特征向量还原成原始数据。
    • 在自编码器整个训练过程中,目标是最小化输入数据和重建数据之间的差异,以学习到更加有效的特征表示。
  • 最简单的自动编码器是由线性层构成的,叫做线性自编码器(如下图所示)。

    • 输出层的神经元数量往往与输入层的神经元数量一致;
    • 网络架构往往呈对称性,且中间结构简单、两边结构复杂。

在这里插入图片描述

1.2 自编码器存在的问题

  • 如下图所示,假设有两张训练图片,一张是全月图,一张是半月图,经过训练我们的自编码器模型已经能无损地还原这两张图片。
  • 接下来,我们在code空间上,两张图片的编码点中间处取一点,然后将这一点交给解码器,我们希望新的生成图片是一张清晰的图片(类似3/4全月的样子)。
  • 但是,实际的结果是,生成图片是模糊且无法辨认的乱码图。
  • 原因是:基本自编码器给定一张图片生成原始图片,从输入到输出都是确定的,没有任何随机的成分,为了使模型表现很好,在不断的迭代训练中,编码器的输出也就是解码器的输入会趋于确定,这样才能让解码器能生成与输入数据更接近的数据,以使损失变得更小。但是这就与生成器的初衷有悖了:
    • 生成器的初衷实际上是为了生成更多全新的数据,而不是为了生成与输入数据更像的数据。

在这里插入图片描述

1.3 AE模型在MNIST数据集上的应用

  • 代码如下所示,ae_original_image.png是原始数据

在这里插入图片描述

  • ae_image_encoder.png是经过编码器,解码器后得到的图片

在这里插入图片描述

  • ae_image.png是将三个编码器得到的编码值进行平均得到的图片,可以看到是模糊且无法辨认的乱码图,这就是自编码器存在的问题。

在这里插入图片描述

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_imageclass AE(nn.Module):def __init__(self):super().__init__()# 编码器self.encoder = nn.Sequential(nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 128),nn.ReLU(),nn.Linear(128, 10))# 解码器self.decoder = nn.Sequential(nn.Linear(10, 128),nn.ReLU(),nn.Linear(128, 256),nn.ReLU(),nn.Linear(256, 784),)def forward(self, x):x = x.view(-1, 28 * 28)  # [b,1,28,28] ———> [b,784]x = self.encoder(x)      # [b,784] ———> [b,10]x = self.decoder(x)      # [b,10] ———> [b,784]return xdef train(model, loss_fn, opt, epoch=200):for epoch in range(epoch):model.train()total_loss = 0.0for x, _ in trian_dl:x = x.to(device)y_pre = model(x)loss = loss_fn(y_pre, x.reshape(-1, 784))opt.zero_grad()loss.backward()opt.step()total_loss += lossbreakprint(f'epoch = {epoch + 1}, train loss = {total_loss / len(trian_dl): .4f}')torch.save(model.state_dict(), 'ae_model.pth')if __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 1、读取数据集trian_dl = torch.utils.data.DataLoader(datasets.MNIST("/root/autodl-fs/data/minist",train=True,download=False,transform=transforms.Compose([transforms.ToTensor(),]),),batch_size=256,shuffle=True,num_workers=8)# 2、创建AE模型及优化器model = AE().to(device)opt = torch.optim.Adam(model.parameters(), lr=0.001)loss_fn = nn.MSELoss(reduction='mean')# 4、模型训练train(model, loss_fn, opt, epoch=100)# 5、模型推理# 加载模型model.load_state_dict(torch.load('ae_model.pth', map_location=device))bs = 3text_dl = torch.utils.data.DataLoader(datasets.MNIST("/root/autodl-fs/data/minist",train=False,download=False,transform=transforms.Compose([transforms.ToTensor(),]),),batch_size=bs,shuffle=False)# 获取编码for x, y in text_dl:model.eval()save_image(x, "ae_original_image.png")x = x.to(device)sample_encoder = model.encoder(x.reshape(-1, 784))print(sample_encoder)image_encoder = model.decoder(sample_encoder).reshape(-1, 1, 28, 28)save_image(image_encoder, "ae_image_encoder.png")sample_sum = torch.sum(sample_encoder, dim=0) / 2sample = sample_sum.tile(bs, 1)image = model.decoder(sample).reshape(-1, 1, 28, 28)save_image(image, "ae_image.png")break

2 变分自编码器(VAE)

2.1 如何解决自编码器的缺点

  • 我们现在已经知道,自编码器生成的图片是模糊且无法辨认的乱码图。如何解决这个问题呢?
    • 如下图所示,现在在给两张图片编码的时候加上一点噪音,使得每张图片的编码点出现在绿色箭头所示范围内。
    • 在训练模型的时候,绿色箭头范围内的点都有可能被采样到,这样解码器在训练时会把绿色范围内的点都尽可能还原成和原图相似的图片。
    • 然后我们可以关注之前那个失真点,现在它处于全月图和半月图编码的交界上,于是解码器希望它既要尽量相似于全月图,又要尽量相似于半月图,于是它的还原结果就是两种图的折中(3/4全月图)。

在这里插入图片描述

  • 由此我们发现,给编码器增添一些噪音,可以有效覆盖失真区域。
    • 不过这还并不充分,因为在上图的距离训练区域很远的黄色点处,它依然不会被覆盖到,仍是个失真点
    • 为了解决这个问题,我们可以试图把噪音无限拉长,使得对于每一个样本,它的编码会覆盖整个编码空间,不过我们得保证,在原编码附近编码的概率最高,离原编码点越远,编码概率越低。
    • 在这种情况下,图像的编码就由原先离散的编码点变成了一条连续的编码分布曲线,如下图所示。

在这里插入图片描述

  • 这种将图像编码由离散变为连续的方法,就是变分自编码的核心思想。

2.2 VAE的整体架构概述

  • 如下图所示,与普通自动编码器一样,VAE有编码器Encoder与解码器Decoder两大部分组成,原始图像从编码器输入,经编码器后形成隐式表示(Latent Representation),之后隐式表示被输入到解码器、再复原回原始输入的结构。
  • 然而,与普通AE不同的是,我们不会直接将Encoder编码后的结果传递给Decoder,而是要使得隐式表示满足既定分布(例如:正态分布)。
    • 变分自动编码器的Encoder在输出时,并不会直接输出原始数据的隐式表示,而是会输出从原始数据提炼出的均值 μ 和标准差 σ 。
    • 之后,我们需要建立均值为 μ 、标准差为 σ 的正态分布,并从该正态分布中抽样出隐式表示z,再将隐式表示z输入到Decoder中进行解码。
    • 对隐式表示z而言,它传递给Decoder的就不是原始数据的信息,而只是与原始数据同均值、同标准差的分布中的信息了。

在这里插入图片描述

2.3 VAE的前向过程

2.3.1 前向过程(一个样本产生一个正态分布)

  • 在auto-encoder中,编码器是直接产生一个编码的,但是在VAE中,为了给编码添加合适的噪音,编码器会输出两个编码:

    • 下图有三个样本,因此存在三组均值和标准差,一组均值和标准差只能生成一个正态分布,而一个正态分布中只能抽选一个数字,这是变分自编码器抽样的基本规则。下图中,每个样本经过Encoder后只输出了一组均值和标准差,那z自然只能有一列,隐式空间 ( c 1 , c 2 , c 3 ) (c_1,c_2,c_3) (c1,c2,c3)的结构为(3,1)。

    • 一个是原有编码 ( m 1 , m 2 , m 3 ) (m_1,m_2,m_3) (m1,m2,m3)(均值),注意这里是一个样本会输出一个均值

    • 一个是控制噪音干扰程度的编码 ( σ 1 、 σ 2 、 σ 3 ) (\sigma_1、\sigma_2、\sigma_3) (σ1σ2σ3)(注意这里是方差,下面推导过程用的是标准差),第二个编码其实很好理解,就是为随机噪音 ( e 1 , e 2 , e 3 ) (e_1,e_2,e_3) (e1,e2,e3)分配权重。加上exp的目的是为了保证这个分配的权重是个正值,最后将原编码与噪音编码相加,就得到了VAE在code层的输出结果,即 c i = e σ i ∗ e i + m i c_i=e^{\sigma_i}*e_i+m_i ci=eσiei+mi

    • 得到了VAE在code层的输出结果 c c c后,进入解码器Decoder中,最后得到output。

在这里插入图片描述

  • 由上图可知,在损失函数方面,除了必要的重构损失外(让输出和输入相近),VAE还另外增添了一个损失函数(上图下方的L2损失函数),我们后面会详细推导此损失函数。

2.3.2 前向过程(一个样本产生多个正态分布)

  • 在变分自动编码器的流程当中,均值和标准差是通过第一个神经网络Encoder训练出来的。我们不可能知道当前样本服从的真实分布的状态,因此这一推断过程自然可以根据不同的规则(Encoder中不同的权重)得出不同的结果。
  • 如下图所示,我们可以令Encoder的输出层存在3个神经元,这样Encoder就会对每一个样本推断出三对不同的均值和标准差。这个行为相当于对样本数据所属的原始分布进行估计,但给出了三个可能的答案。因此现在,在每个样本下,我们就可以基于三个均值和标准差的组合生成三个不同的正态分布了。
  • 隐式空间越大,隐式表示z所携带的信息自然也会越多,自动编码器的表现就可能变得更好,因此在实际使用变分自动编码器的过程中,一个样本上至少都会生成10~100组均值和标准差,隐式表示z的结构一般也是较高维的矩阵

在这里插入图片描述

在这里插入图片描述

2.4 VAE损失函数的推导

2.4.1 从高斯混合模型到VAE

  • VAE的理论基础是高斯混合模型,即任何一个数据的分布,都可以看作是若干高斯分布的叠加。
  • 如图所示,如果P(X)代表一种分布的话,存在一种拆分方法能让它表示成图中若干浅蓝色曲线对应的高斯分布的叠加。

在这里插入图片描述

  • 如下图所示,我们将编码换成一个连续变量z,为了计算方便,我们规定z服从标准正态分布(实际上并不一定要选用)。正如2.3前向过程所示,对于每一个采样点z,会有两个函数 u u u σ \sigma σ,分别决定z对应到的高斯分布的均值和方差,然后在积分域上所有的高斯分布的累加就成为了原始分布P(X)。
    • 我们使用 p p p代表解码器;
    • p ( x ∣ z ) p(x|z) p(xz)代表给定z时解码器产生 x x x的概率;
    • x x x并非一个具体的值,而可以看作是一类数据,比如 x x x可以代表某种风格的手写体数字, p ( x ∣ z ) p(x|z) p(xz)就是生成这些数字的概率,这里的概率也并非一个具体的值,而是某一风格的每个数字对应了一个概率,其输出的是概率分布。

在这里插入图片描述

  • 那么VAE的优化目标是什么呢?其实就是最大化解码器输出x的概率,即最大化 p ( x ) p(x) p(x)

2.4.2 损失函数的推导

我们现在的优化目标就是:最大化编码器输出x的概率
p ( x ) = ∫ z p ( x ∣ z ) p ( z ) d z p(x)=\int_zp(x|z)p(z)dz p(x)=zp(xz)p(z)dz

  • 注意:这里的 p ( z ) p(z) p(z)可以是任意分布,在VAE中我们常常假设 p ( z ) p(z) p(z)服从标准正态分布

为了最大化 p ( x ) p(x) p(x),我们可以采用极大似然估计的方法来进行:
L = ∑ x l o g p ( x ) L=\sum_xlogp(x) L=xlogp(x)

  • 这里的每个x可以理解为代表了某一个风格的手写体,我们的目标是生成手写体数字,因此我们并不会局限其风格。

由于最大化L相当于最大化 l o g p ( x ) logp(x) logp(x),因此后续目标调整为最大化 l o g p ( x ) logp(x) logp(x)。我们假设q代表了编码器, q ( z ∣ x ) q(z|x) q(zx)就代表了给定x时编码器产生z的概率
给定任意 x ,其产生不同 z 的概率之和为 1 ,因此: ∫ z q ( z ∣ x ) = 1 而 p ( x ) 和 z 无关,那么 : l o g p ( x ) = ∫ z q ( z ∣ x ) l o g p ( x ) d z ( 公式一 ) 依据联合概率公式 p ( x ) = p ( x , z ) p ( z ∣ x ) = p ( x , z ) q ( z ∣ x ) q ( z ∣ x ) p ( z ∣ x ) 代入 ( 公式一 ) : l o g p ( x ) = ∫ z q ( z ∣ x ) l o g ( p ( x , z ) q ( z ∣ x ) q ( z ∣ x ) p ( z ∣ x ) ) d z 我们将 l o g 里的乘积拆开,变为两项之和 = ∫ z q ( z ∣ x ) l o g ( p ( x , z ) q ( z ∣ x ) ) d z + ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) p ( z ∣ x ) ) d z 给定任意x,其产生不同z的概率之和为1,因此:\\ \int_zq(z|x)=1\\ 而p(x)和z无关,那么:\\ logp(x)=\int_zq(z|x)logp(x)dz(公式一) \\ 依据联合概率公式p(x)=\frac{p(x,z)}{p(z|x)}=\frac{p(x,z)}{q(z|x)}\frac{q(z|x)}{p(z|x)}\\ 代入(公式一):logp(x)=\int_zq(z|x)log(\frac{p(x,z)}{q(z|x)}\frac{q(z|x)}{p(z|x)})dz\\ 我们将log里的乘积拆开,变为两项之和\\ =\int_zq(z|x)log(\frac{p(x,z)}{q(z|x)})dz+\int_zq(z|x)log(\frac{q(z|x)}{p(z|x)})dz 给定任意x,其产生不同z的概率之和为1,因此:zq(zx)=1p(x)z无关,那么:logp(x)=zq(zx)logp(x)dz(公式一)依据联合概率公式p(x)=p(zx)p(x,z)=q(zx)p(x,z)p(zx)q(zx)代入(公式一)logp(x)=zq(zx)log(q(zx)p(x,z)p(zx)q(zx))dz我们将log里的乘积拆开,变为两项之和=zq(zx)log(q(zx)p(x,z))dz+zq(zx)log(p(zx)q(zx))dz
结合KL散度公式,我们可以看出第二项其实就是 K L ( q ( z ∣ x ) ∣ ∣ p ( z ∣ x ) ) KL(q(z|x) || p(z|x)) KL(q(zx)∣∣p(zx))。因为该值为非负项,所以 l o g p ( x ) logp(x) logp(x)不可能小于第一项,我们使用 L b Lb Lb来代表第一项。
l o g p ( x ) = L b + ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) p ( z ∣ x ) ) d z = L b + K L ( q ( z ∣ x ) ∣ ∣ p ( z ∣ x ) ) 结合式式子 : p ( x ) = ∫ z p ( x ∣ z ) p ( z ) d z 分析如下: 当 p ( x ∣ z ) 不变时, p ( x ) 也不变,从而 l o g p ( x ) 也不变,那么 L b + K L ( q ( z ∣ x ) ∣ ∣ p ( z ∣ x ) ) 的值就不会变。 这时如果我们利用 q ( z ∣ x ) 来最大化 L b ,那么 L b 就会增大,而 K L ( q ( z ∣ x ) ∣ ∣ p ( z ∣ x ) ) 的值就会减小。 当我们调节到 q ( z ∣ x ) 与 p ( z ∣ x ) 完全相同时, K L 散度就为 0 , L b 和 l o g p ( x ) 完全一致 logp(x)=Lb+\int_zq(z|x)log(\frac{q(z|x)}{p(z|x)})dz\\ =Lb+KL(q(z|x) || p(z|x))\\ 结合式式子:p(x)=\int_zp(x|z)p(z)dz\\ 分析如下:\\ 当p(x|z)不变时,p(x)也不变,从而log p(x)也不变,那么Lb+KL(q(z|x) || p(z|x))的值就不会变。\\ 这时如果我们利用q(z|x)来最大化Lb,那么Lb就会增大,而KL(q(z|x) || p(z|x))的值就会减小。\\ 当我们调节到q(z|x)与p(z|x)完全相同时,KL散度就为0,Lb和logp(x)完全一致\\ logp(x)=Lb+zq(zx)log(p(zx)q(zx))dz=Lb+KL(q(zx)∣∣p(zx))结合式式子:p(x)=zp(xz)p(z)dz分析如下:p(xz)不变时,p(x)也不变,从而logp(x)也不变,那么Lb+KL(q(zx)∣∣p(zx))的值就不会变。这时如果我们利用q(zx)来最大化Lb,那么Lb就会增大,而KL(q(zx)∣∣p(zx))的值就会减小。当我们调节到q(zx)p(zx)完全相同时,KL散度就为0Lblogp(x)完全一致
在这里插入图片描述

那么如果 q ( z ∣ x ) 不变呢?此时当我们增大 p ( x ∣ z ) 时, L b 会增大且 p ( x ) 会增大,即 l o g p ( x ) 也会增大。 由此我们可以得出结论,只要我们最大化 L b 就能使 l o g p ( x ) 最大化。 那么如果q(z|x)不变呢?此时当我们增大p(x|z)时,Lb会增大且p(x)会增大,即log p(x)也会增大。\\ 由此我们可以得出结论,只要我们最大化Lb就能使log p(x)最大化。 那么如果q(zx)不变呢?此时当我们增大p(xz)时,Lb会增大且p(x)会增大,即logp(x)也会增大。由此我们可以得出结论,只要我们最大化Lb就能使logp(x)最大化。
此时我们的优化目标就变成了最大化 L b Lb Lb
L b = ∫ z q ( z ∣ x ) l o g ( p ( x , z ) q ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g ( p ( x ∣ z ) p ( z ) q ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g p ( z ) q ( z ∣ x ) d z + ∫ z q ( z ∣ x ) l o g p ( x ∣ z ) d z = − K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) + ∫ z q ( z ∣ x ) l o g p ( x ∣ z ) d z = − K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) + E q ( z ∣ x ) l o g p ( x ∣ z ) 我们加负号,转化为最小化问题: L o s s = K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) − E q ( z ∣ x ) l o g p ( x ∣ z ) Lb=\int_zq(z|x)log(\frac{p(x,z)}{q(z|x)})dz\\ =\int_zq(z|x)log(\frac{p(x|z)p(z)}{q(z|x)})dz\\ =\int_zq(z|x)log\frac{p(z)}{q(z|x)}dz+\int_zq(z|x)logp(x|z)dz\\ =-KL(q(z|x)||p(z)) + \int_zq(z|x)logp(x|z)dz\\ =-KL(q(z|x)||p(z)) + E_{q(z|x)}logp(x|z)\\ 我们加负号,转化为最小化问题:\\ Loss=KL(q(z|x)||p(z)) - E_{q(z|x)}logp(x|z) Lb=zq(zx)log(q(zx)p(x,z))dz=zq(zx)log(q(zx)p(xz)p(z))dz=zq(zx)logq(zx)p(z)dz+zq(zx)logp(xz)dz=KL(q(zx)∣∣p(z))+zq(zx)logp(xz)dz=KL(q(zx)∣∣p(z))+Eq(zx)logp(xz)我们加负号,转化为最小化问题:Loss=KL(q(zx)∣∣p(z))Eq(zx)logp(xz)

  • 此时VAE的最终目标就一目了然了,VAE的训练目标有两个:
    • 第一,最大化在 q ( z ∣ x ) q(z|x) q(zx)这个分布下 l o g p ( x ∣ z ) logp(x|z) logp(xz)的期望(L1损失),其中 q ( z ∣ x ) q(z|x) q(zx)为编码器输入 x x x时产生 z z z的概率。假设解码器利用 z z z生成出了 x ’ x’ x,我们就需要使 x ’ x’ x尽可能向 x x x靠近,以最大化 l o g p ( x ∣ z ) logp(x|z) logp(xz)
    • 第二,最小化 K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) KL(q(z|x)||p(z)) KL(q(zx)∣∣p(z))(L2损失),使 q ( z ∣ x ) q(z|x) q(zx)的分布尽量向 p ( z ) p(z) p(z)靠近;

根据上述的两个训练目标,VAE的损失函数也被设计为两个:

  • L1损失目的是输出的 x ’ x’ x尽可能向原始 x x x靠近(即重构损失),我们可以最小化x’和x之间的MSE Loss或者BCE Loss。

  • L2用于最小化 K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) KL(q(z|x)||p(z)) KL(q(zx)∣∣p(z)),VAE假设 q ( z ∣ x ) q(z|x) q(zx)的分布为正态分布,而 p ( z ) p(z) p(z)为标准正态分布。计算两个正态分布之间的KL散度的公式如下(这里就不推导了,直接给出):

    在这里插入图片描述

    由于此处 p ( z ) p(z) p(z)为标准正态分布,因此其 μ 2 μ_2 μ2为0, σ 2 σ_2 σ2为1,那么我们带入后可得
    K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) = − 1 2 ( l o g σ 2 − σ 2 − u 2 + 1 ) 其中 σ 为 q ( z ∣ x ) 的标准差, μ 为 q ( z ∣ x ) 的均值 KL(q(z|x)||p(z))=-\frac{1}{2}(log\sigma^2-\sigma^2-u^2+1) \\ 其中σ为q(z|x)的标准差,μ为q(z|x)的均值 KL(q(zx)∣∣p(z))=21(logσ2σ2u2+1)其中σq(zx)的标准差,μq(zx)的均值

2.4.3 VAE的重参数化

蓝色节点为采样节点,左侧由于采样的存在无法对 μ μ μ σ σ σ求导,反向传播无法进行,而右侧由于采用的重参数化技巧,灰色节点全部打通,使得网络能够正常进行反向传播。

在这里插入图片描述

2.4.4 VAE和GAN的区别、VAE的本质

可以参考:变分自编码器(一):原来是这么一回事 - 科学空间

2.5 VAE在MNIST数据集上的应用

  • 解码器输出的不是方差 σ 2 σ^2 σ2 ,而是对数方差 l o g σ 2 logσ^2 logσ2,详见下面代码的encoder函数,这么做的原因就是,神经网络的输出是 [−∞,+∞] 的任意数值,但是方差不可能为负数,所以对方差取对数以满足神经网络输出值域的要求。
  • 重参数化详见reparameter函数。
  • 训练时候不要忘了损失函数的KL散度部分,利用上面推导的公式计算。这里由于模型的输出是对数方差 l o g σ 2 logσ^2 logσ2而不是方差,所以原始的计算公式需要做一个转换,如下:

K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) = − 1 2 ( l o g σ 2 − σ 2 − u 2 + 1 ) 【其中 σ 为 q ( z ∣ x ) 的标准差, μ 为 q ( z ∣ x ) 的均值】 = − 1 2 ( l o g _ v a r − e l o g _ v a r − m u 2 + 1 ) KL(q(z|x)||p(z))\\ =-\frac{1}{2}(log\sigma^2-\sigma^2-u^2+1)【其中σ为q(z|x)的标准差,μ为q(z|x)的均值】 \\ =-\frac{1}{2}(log\_var-e^{log\_var}-mu^2+1) KL(q(zx)∣∣p(z))=21(logσ2σ2u2+1)【其中σq(zx)的标准差,μq(zx)的均值】=21(log_varelog_varmu2+1)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_imageclass VAE(nn.Module):def __init__(self):super().__init__()# encoderself.encoder_layer = nn.Sequential(nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 128),)self.fc1 = nn.Linear(128, 10)  # 均值self.fc2 = nn.Linear(128, 10)  # log方差# decoderself.decoder = nn.Sequential(nn.Linear(10, 128),nn.ReLU(),nn.Linear(128, 256),nn.ReLU(),nn.Linear(256, 784),)def encoder(self, x):x = F.relu(self.encoder_layer(x))  # [b,784] ———> [b,128]mu = self.fc1(x)                   # [b,128] ———> [b,10]log_var = self.fc2(x)              # [b,128] ———> [b,10]return mu, log_vardef reparameter(self, mu, log_var):device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 重参数化技巧 z = mu(均值) + eps(0-1正态分布) *  sigma(方差)std = torch.sqrt(torch.exp(log_var))eps = torch.rand_like(std)z = mu + eps * std     # [b, 10]return zdef forward(self, x):x = x.view(-1, 28 * 28)                # [b,1,28,28] ———> [b,784]mu, log_var = self.encoder(x)          # [b,784] ———> [b,10]  mu  sigmaz = self.reparameter(mu, log_var)      # [b,10] ———> [b,10]x = self.decoder(z)                    # [b,10] ———> [b,784]x_hat = x.reshape(-1, 1, 28, 28)return x_hat, mu, log_vardef train(model, epoch=200):for epoch in range(epoch):model.train()total_loss = 0.0total_bce_loss = 0.0total_kl_loss = 0.0for x, y in trian_dl:x, y = x.to(device), y.to(device)x_hat, mu, log_var = model(x)# 3-1、bce_lossbce_loss = F.binary_cross_entropy(torch.sigmoid(x_hat.view(-1, 784)), x.view(-1, 784), reduction='sum')# 3-2、kl_losskl_loss = torch.sum(-0.5 * (log_var - torch.exp(log_var) - mu ** 2 + 1))loss = bce_loss + kl_lossopt.zero_grad()loss.backward()opt.step()total_loss += losstotal_bce_loss += bce_losstotal_kl_loss += kl_lossprint(f'epoch = {epoch + 1},bce_loss = {(total_bce_loss / len(trian_dl)):.4f}, kl_loss = {(total_kl_loss / len(trian_dl)):.4f}, train loss = {total_loss / len(trian_dl):.4f}')torch.save(model.state_dict(), 'vae_model.pth')if __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 1、读取数据集trian_dl = torch.utils.data.DataLoader(datasets.MNIST("/root/autodl-fs/data/minist",train=True,download=False,transform=transforms.Compose([transforms.ToTensor(),]),),batch_size=256,shuffle=True,num_workers=12)# 2、创建VAE模型及优化器model = VAE().to(device)opt = torch.optim.Adam(model.parameters(), lr=0.001)# 4、模型训练train(model, epoch=300)# 5、模型推理# 加载模型model.load_state_dict(torch.load('vae_model.pth', map_location=device))model.eval()z = torch.randn(3, 10)  # 生成一个形状为 (3, 10) 的随机数张量image = model.decoder(z).reshape(-1, 1, 28, 28)save_image(image, "vae_image_random.png")

在这里插入图片描述

参考:
图片来自李宏毅老师的教程视频:https://www.bilibili.com/video/av15889450/?p=33
https://kexue.fm/archives/5253
https://blog.csdn.net/weixin_42491648/article/details/132384913
http://www.gwylab.com/note-vae.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/333907.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【智能家居入门1】环境信息监测(STM32、ONENET云平台、微信小程序、HTTP协议)

作为入门本篇只实现微信小程序接收下位机上传的数据,之后会持续发布如下项目:①可以实现微信小程序控制下位机动作,真正意义上的智能家居;②将网络通讯协议换成MQTT协议再实现上述功能,此时的服务器也不再是ONENET&…

数据结构—队列(C语言实现)

文章目录 前言一、队列的概念二、队列的实现Queue.hQueue.c 三、设计循环队列问题数组实现链表实现 总结 前言 嗨喽喽!!小伙伴们,大家好哇,欢迎来到我的博客! 今天将要分享的是另一种数据结构—队列,以及…

五分钟搭建一个Suno AI音乐站点

五分钟搭建一个Suno AI音乐站点 在这个数字化时代,人工智能技术正以惊人的速度改变着我们的生活方式和创造方式。音乐作为一种最直接、最感性的艺术形式,自然也成为了人工智能技术的应用场景之一。今天,我们将以Vue和Node.js为基础&#xff…

MySQL触发器实战:自动执行的秘密

欢迎来到我的博客,代码的世界里,每一行都是一个故事 🎏:你只管努力,剩下的交给时间 🏠 :小破站 MySQL触发器实战:自动执行的秘密 前言触发器的定义和作用触发器的定义和作用触发器的…

leetCode.82. 删除排序链表中的重复元素 II

leetCode.82. 删除排序链表中的重复元素 II 题目思路: 代码 class Solution { public:ListNode* deleteDuplicates(ListNode* head) {auto dummy new ListNode(-1);dummy->next head;auto p dummy;while(p->next){auto q p->next->next;while(q …

插件“猫抓”使用方法 - 浏览器下载m3u8视频 - 合并 - 视频检测下载 - 网课下载神器

前言 浏览器下载m3u8视频 - 合并 - 网课下载神器 chrome插件-猫抓 https://chrome.zzzmh.cn/info/jfedfbgedapdagkghmgibemcoggfppbb 步骤: P.s. 推荐大佬的学习视频! 《WEB前端大师课》超级棒! https://ke.qq.com/course/5892689#term_id…

使用Python操作Jenkins

大家好,Python作为一种简洁、灵活且功能丰富的编程语言,可以与各种API轻松集成,Jenkins的API也不例外。借助于Python中的python-jenkins模块,我们可以轻松地编写脚本来连接到Jenkins服务器,并执行各种操作,…

C语言中的位段

位段是通过结构体实现的,可以在一定程度上减小空间浪费,位段的声明和结构体类似,有以下几个不同: ①位段的成员必须是整形(int,char,short等)。 ②成员后边有冒号和数字,表示该成员占几个bit位…

【译】MySQL复制入门: 探索不同类型的MySQL复制解决方案

原文地址:An Introduction to MySQL Replication: Exploring Different Types of MySQL Replication Solutions 在这篇博文中,我将深入介绍 MySQL 复制,回答它是什么、如何工作、它的优势和挑战,并回顾作为 MySQL 环境&#xff0…

“智能体时代:探索无限可能——零代码构建智能教练智能体“

随着智能体技术的飞速发展,各个领域正经历着空前的变革和新的发展机遇。作为人工智能的一个关键组成部分,智能体以其自我驱动、智能响应和适应能力,逐渐深入到我们日常生活的各个层面,成为促进社会发展和科技进步的新引擎。 顺应这…

深度神经网络——贝叶斯与朴素贝叶斯定理

概述 贝叶斯定理是概率论中一个非常重要的概念,它提供了一种在已知某些相关事件的概率时,计算另一个事件发生概率的方法。在你提供的内容中,贝叶斯定理被描述为一种“魔法”,因为它能够使计算机通过分析大量的数据来预测人们可能…

十四天学会Vue——Vue核心(理论+实战)中篇(第二天)

声明:是接着上篇讲的哦,感兴趣可以去看一看~ 这里一些代码就不写了,为了缩减代码量,大家知道就可以了: Vue.config.productionTip false //阻止 vue 在启动时生成生产提示。热身小tips,可以安装这个插件&…

【LeetCode】【9】回文数(1047字)

文章目录 [toc]题目描述样例输入输出与解释样例1样例2样例3 提示进阶Python实现 个人主页:丷从心 系列专栏:LeetCode 刷题指南:LeetCode刷题指南 题目描述 给一个整数x,如果x是一个回文整数,返回true;否…

MIT6.828 Lab2-1 Using gdb

Using gdb gdb使用: xv6 gdb调试方法 问题1: Looking at the backtrace output, which function called syscall? 按照提示开启gdb后键入: b syscall c layout src backtrace输出结果: (gdb) backtrace #0 syscall () at k…

Python + adb 实现打电话功能

前言 其实很多年前写过一篇python打电话的功能,链接如下: Python twilio 实现打电话和发短信功能_自动发短信代码-CSDN博客 今天由于工作需要,又用python写了个关于打电话的小工具,主要是通过ADB方式实现的 实现过程 1.先利用…

YOLOv8+PyQt5鸟类检测系统完整资源集合(yolov8模型,从图像、视频和摄像头三种路径识别检测,包含登陆页面、注册页面和检测页面)

资源包含可视化的鸟类检测系统,基于最新的YOLOv8训练的鸟类检测模型,和基于PyQt5制作的可视化鸟类检测系统,包含登陆页面、注册页面和检测页面,该系统可自动检测和识别图片或视频当中出现的各种鸟类,以及自动开启摄像头…

Putty: 随心御剑——远程启动服务工具plink

一、引言:如何远程控制 也许你会有这样的场景,交互程序(以下简称UI程序)跑在windows端,而控制程序跑在Linux上。我们想要通过windows端 UI程序来启动Linux下面的服务,来一场酣畅淋漓的御剑飞行咋办,难道要自己十年磨一剑,在Linux下编写一个受控服务程序么.计算机科技发…

如何创建一个vue项目?详细教程,如何创建第一个vue项目?

已经安装node.js在自己找的到的地方新建一个文件夹用于存放项目,记住文件夹的存放路径,以我为例,我的文件夹路径为D:\tydic 打开cmd命令窗口,进入刚刚的新建文件夹 切换硬盘: D: 进入文件夹:cd tydic 使…

重学java 49 List接口

但逢良辰,顺颂时宜 —— 24.5.28 一、List接口 1.概述: 是collection接口的子接口 2.常见的实现类: ArrayList LinkedList Vector 二、List集合下的实现类 1.ArrayList集合的使用及源码分析 1.概述 ArrayList是List接口的实现类 2.特点 a.元素有序 —> 按照什么顺…

红外成像人员检测数据集VOC+YOLO格式5838张1类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):5838 标注数量(xml文件个数):5838 标注数量(txt文件个数):5838 标注…