变分自编码器(VAE)PyTorch Lightning 实现

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。
🍎个人主页:小嗷犬的个人主页
🍊个人网站:小嗷犬的技术小站
🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。


本文目录

    • VAE 简介
      • 基本原理
      • 应用与优点
      • 缺点与挑战
    • 使用 VAE 生成 MNIST 手写数字
      • 忽略警告
      • 导入必要的库
      • 设置随机种子
      • cuDNN 设置
      • 超参数设置
      • 数据加载
      • 定义 VAE 模型
      • 定义损失函数
      • 定义 Lightning 模型
      • 训练模型
      • 绘制训练过程
      • 随机生成新样本
      • 根据潜变量插值生成新样本


VAE 简介

变分自编码器(Variational Autoencoder,VAE)是一种深度学习中的生成模型,它结合了自编码器(Autoencoder, AE)和概率建模的思想,在无监督学习环境中表现出了强大的能力。VAE 在 2013 年由 Diederik P. Kingma 和 Max Welling 首次提出,并迅速成为生成模型领域的重要组成部分。

基本原理

自编码器(AE)基础:
自编码器是一种神经网络结构,通常由两部分组成:编码器(Encoder)和解码器(Decoder)。原始数据通过编码器映射到一个低维的潜在空间(或称为隐空间),这个低维向量被称为潜变量(latent variable)。然后,潜变量再通过解码器重构回原始数据的近似版本。在训练过程中,自编码器的目标是使得输入数据经过编码-解码过程后能够尽可能地恢复原貌,从而学习到数据的有效表示。

VAE的引入与扩展:
VAE 将自编码器的概念推广到了概率框架下。在 VAE 中,潜变量不再是确定性的,而是被赋予了概率分布。具体来说,对于给定的输入数据,编码器不直接输出一个点估计值,而是输出潜变量的均值和方差(假设潜变量服从高斯分布)。这样,每个输入数据可以被视为是从某个潜在的概率分布中采样得到的。

变分推断(Variational Inference):
训练 VA E时,由于真实的后验概率分布难以直接计算,因此采用变分推断来近似后验分布。编码器实际上输出的是一个参数化的概率分布 q ( z ∣ x ) q(z|x) q(zx),即给定输入 x x x 时潜变量 z z z 的概率分布。然后通过最小化 KL 散度(Kullback-Leibler divergence)来优化这个近似分布,使其尽可能接近真实的后验分布 p ( z ∣ x ) p(z|x) p(zx)

目标函数 - Evidence Lower Bound (ELBO):
VAE 的目标函数是证据下界(ELBO),它是原始数据 log-likelihood 的下界。优化该目标函数既鼓励编码器找到数据的高效潜在表示,又促使解码器基于这些表示重建出类似原始数据的新样本。

数学表达上,ELBO 通常分解为两个部分:

  1. 重构损失(Reconstruction Loss):衡量从潜变量重构出来的数据与原始数据之间的差异。
  2. KL散度损失(KL Divergence Loss):衡量编码器产生的潜变量分布与预设的标准正态分布(或其他先验分布)之间的距离。

应用与优点

  • VAE 可以用于生成新数据,例如图像、文本、音频等。
  • 由于其对潜变量进行概率建模,所以它可以提供连续的数据生成,并且能够探索数据的不同模式。
  • 在处理连续和离散数据时具有一定的灵活性。
  • 可以用于特征学习,提取数据的有效低维表示。

缺点与挑战

  • 训练 VAE 可能需要大量的计算资源和时间。
  • 生成的样本有时可能不够清晰或细节模糊,尤其是在复杂数据集上。
  • 对于某些复杂的分布形式,VAE 可能无法完美捕获所有细节。

使用 VAE 生成 MNIST 手写数字

下面我们将使用 PyTorch Lightning 来实现一个简单的 VAE 模型,并使用 MNIST 数据集来进行训练和生成。

在线 Notebook:https://www.kaggle.com/code/marquis03/vae-mnist

忽略警告

import warnings
warnings.filterwarnings("ignore")

导入必要的库

import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as snssns.set_theme(style="darkgrid", font_scale=1.5, font="SimHei", rc={"axes.unicode_minus":False})import torch
import torchmetrics
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasetsimport lightning.pytorch as pl
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

设置随机种子

seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

cuDNN 设置

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

超参数设置

batch_size = 64epochs = 10
KLD_weight = 1
lr = 0.001input_dim = 784  # 28 * 28
h_dim = 256  # 隐藏层维度  
z_dim = 2  # 潜变量维度

数据加载

train_dataset = datasets.MNIST(root="data", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

定义 VAE 模型

class VAE(nn.Module):def __init__(self, input_dim=784, h_dim=400, z_dim=20):super(VAE, self).__init__()self.input_dim = input_dimself.h_dim = h_dimself.z_dim = z_dim# Encoderself.fc1 = nn.Linear(input_dim, h_dim)self.fc21 = nn.Linear(h_dim, z_dim)  # muself.fc22 = nn.Linear(h_dim, z_dim)  # log_var# Decoderself.fc3 = nn.Linear(z_dim, h_dim)self.fc4 = nn.Linear(h_dim, input_dim)def encode(self, x):h = torch.relu(self.fc1(x))mean = self.fc21(h)log_var = self.fc22(h)return mean, log_vardef reparameterize(self, mu, logvar):std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):h = torch.relu(self.fc3(z))out = torch.sigmoid(self.fc4(h))return outdef forward(self, x):mean, log_var = self.encode(x)z = self.reparameterize(mean, log_var)reconstructed_x = self.decode(z)return reconstructed_x, mean, log_varvae = VAE(input_dim, h_dim, z_dim)
x = torch.randn((10, input_dim))
reconstructed_x, mean, log_var = vae(x)
print(reconstructed_x.shape, mean.shape, log_var.shape)
# torch.Size([10, 784]) torch.Size([10, 2]) torch.Size([10, 2])

定义损失函数

def loss_function(x_hat, x, mu, log_var, KLD_weight=1):BCE_loss = F.binary_cross_entropy(x_hat, x, reduction="sum") # 重构损失KLD_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) # KL 散度损失loss = BCE_loss + KLD_loss * KLD_weightreturn loss, BCE_loss, KLD_loss

定义 Lightning 模型

class LitModel(pl.LightningModule):def __init__(self, input_dim=784, h_dim=400, z_dim=20):super().__init__()self.model = VAE(input_dim, h_dim, z_dim)def forward(self, x):x = self.model(x)return xdef configure_optimizers(self):optimizer = optim.Adam(self.parameters(), lr=lr, betas=(0.9, 0.99), eps=1e-08, weight_decay=1e-5)return optimizerdef training_step(self, batch, batch_idx):x, y = batchx = x.view(x.size(0), -1)reconstructed_x, mean, log_var = self(x)loss, BCE_loss, KLD_loss = loss_function(reconstructed_x, x, mean, log_var, KLD_weight=KLD_weight)self.log("loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)self.log_dict({"BCE_loss": BCE_loss,"KLD_loss": KLD_loss,},on_step=False,on_epoch=True,logger=True,)return lossdef decode(self, z):out = self.model.decode(z)return out

训练模型

model = LitModel(input_dim, h_dim, z_dim)
logger = CSVLogger("./")
early_stop_callback = EarlyStopping(monitor="loss", min_delta=0.00, patience=5, verbose=False, mode="min")
trainer = pl.Trainer(max_epochs=epochs,enable_progress_bar=True,logger=logger,callbacks=[early_stop_callback],
)
trainer.fit(model, train_loader)

绘制训练过程

log_path = logger.log_dir + "/metrics.csv"
metrics = pd.read_csv(log_path)
x_name = "epoch"plt.figure(figsize=(8, 6), dpi=100)
sns.lineplot(x=x_name, y="loss", data=metrics, label="Loss", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="BCE_loss", data=metrics, label="BCE Loss", linewidth=2, marker="^", markersize=12)
sns.lineplot(x=x_name, y="KLD_loss", data=metrics, label="KLD Loss", linewidth=2, marker="s", markersize=10)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.tight_layout()
plt.show()

训练过程

随机生成新样本

row, col = 4, 18
z = torch.randn(row * col, z_dim)
random_res = model.model.decode(z).view(-1, 1, 28, 28).detach().numpy()plt.figure(figsize=(col, row))
for i in range(row * col):plt.subplot(row, col, i + 1)plt.imshow(random_res[i].squeeze(), cmap="gray")plt.xticks([])plt.yticks([])plt.axis("off")
plt.show()

随机生成新样本

根据潜变量插值生成新样本

from scipy.stats import normn = 15
digit_size = 28grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))figure = np.zeros((digit_size * n, digit_size * n))
for i, yi in enumerate(grid_y):for j, xi in enumerate(grid_x):t = [xi, yi]z_sampled = torch.FloatTensor(t)with torch.no_grad():decode = model.decode(z_sampled)digit = decode.view((digit_size, digit_size))figure[i * digit_size : (i + 1) * digit_size,j * digit_size : (j + 1) * digit_size,] = digitplt.figure(figsize=(10, 10))
plt.imshow(figure, cmap="gray")
plt.xticks([])
plt.yticks([])
plt.axis("off")
plt.show()

根据潜变量插值生成新样本

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/260634.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

290. Word Pattern(单词规律)

题目描述 给定一种规律 pattern 和一个字符串 s &#xff0c;判断 s 是否遵循相同的规律。 这里的 遵循 指完全匹配&#xff0c;例如&#xff0c; pattern 里的每个字母和字符串 s 中的每个非空单词之间存在着双向连接的对应规律。 提示: 1 < pattern.length < 300 pa…

安装VMware+安装Linux

以上就是VMware在安装时的每一步操作&#xff0c;基本上就是点击 "下一步" 一直进行安装 安装Linux VMware虚拟机安装完毕之后&#xff0c;我们就可以打开VMware&#xff0c;并在上面来安装Linux操作系统。具体步骤如下&#xff1a; 1). 选择创建新的虚拟机 2). 选…

C. LR-remainders

思路&#xff1a;正着暴力会tle&#xff0c;所以我们可以逆着来。 代码&#xff1a; #include<bits/stdc.h> #define int long long #define x first #define y second #define endl \n #define pq priority_queue using namespace std; typedef pair<int,int> p…

如何确定分库还是 分表?

分库分表 分库分表使用的场景不一样&#xff1a; 分表因为数据量比较大&#xff0c;导致事务执行缓慢&#xff1b;分库是因为单库的性能无法满足要求。 分片策略 1、垂直拆分 水平拆分 3 范围分片&#xff08;range&#xff09; 垂直水平拆分 4 如何解决数据查询问题&a…

[计算机网络]---UDP协议

前言 作者&#xff1a;小蜗牛向前冲 名言&#xff1a;我可以接受失败&#xff0c;但我不能接受放弃 如果觉的博主的文章还不错的话&#xff0c;还请点赞&#xff0c;收藏&#xff0c;关注&#x1f440;支持博主。如果发现有问题的地方欢迎❀大家在评论区指正 目录 一、端口号…

老师不能有副业吗为什么

当老师走进教室&#xff0c;他们的每一句话、每一个动作都可能影响着几十上百个孩子的未来。这样的责任&#xff0c;难道是可以轻易分担的吗&#xff1f; 或许有人会说&#xff0c;老师为什么不能有副业&#xff1f;他们也有自己的生活&#xff0c;也需要经济支持。确实&#…

鸿蒙-基于ArkTS声明式开发的简易备忘录,适合新人学习,可用于大作业

本文地址&#xff1a;https://blog.csdn.net/qq_40785165/article/details/136161182?spm1001.2014.3001.5502&#xff0c;转载请附上此链接 大家好&#xff0c;我是小黑&#xff0c;一个还没秃头的程序员~~~ 不知不觉已经有很长一段时间没有分享过自己写的东西了&#xff0…

使用Postman拦截浏览器请求

项目上线之后&#xff0c;难免会有BUG。在出现问题的时候&#xff0c;我们可能需要获取前端页面发送请求的数据&#xff0c;然后在测试环境发送相同的数据将问题复现。手动构建数据是挺麻烦的一件事&#xff0c;所以我们可以借助Postman在浏览器上的插件帮助拦截请求&#xff0…

基于微信小程序的校园跑腿系统的研究与实现,附源码

博主介绍&#xff1a;✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;…

第一件事 什么是 Java 虚拟机 (JVM)

1、什么是虚拟机&#xff1f; - 这个其实是一个挺逗的事情&#xff0c;说白了&#xff0c;就是基于某个硬件架构&#xff0c;在这个硬件部署了一个操作系统&#xff0c;再构架一层虚拟的操作系统&#xff0c;这个新构架的操作系统就是虚拟机。 不知道的兄弟姐妹们&#xff0c;…

[word] 怎么把word表格里的字放在正中间? #职场发展#知识分享#知识分享

怎么把word表格里的字放在正中间&#xff1f; word表格中文字在中间的处理方式如下&#xff1a; 1、在表格中选择需要居中的文字的单元格&#xff0c;具体如下图。 2、全选后&#xff0c;鼠标在工具栏中找到&#xff1a;对齐方式&#xff0c;点击它后面的倒三角&#xff0c;如…

头部新势力新车型将全系标配!4D成像雷达元年真来了?

4D成像雷达赛道又热闹起来了。 自2023年2月&#xff0c;森思泰克2片级联4D成像雷达STA77-6全球首发量产车型——理想L7正式发布上市&#xff0c;立下了国产4D成像雷达产品在乘用车前装量产的重要里程碑事件&#xff0c;业界普遍认为2023年将迎来4D成像雷达规模化量产元年。 尽…

重复导航到当前位置引起的。Vue Router 提供了一种机制,阻止重复导航到相同的路由路径。

代码&#xff1a; <!-- 侧边栏 --><el-col :span"12" :style"{ width: 200px }"><el-menu default-active"first" class"el-menu-vertical-demo" select"handleMenuSelect"><el-menu-item index"…

江淮瑞风RF8强势出圈,瀚思通与华为联手打造智能MPV新标杆

1月31日&#xff0c;江淮瑞风RF8正式上市&#xff0c;新车定位为“新国潮智能电混MPV”&#xff0c;引发市场高度关注。 新车共推出4款配置车型&#xff0c;售价区间为16.99-23.99万元。该车型基于中国品牌首个MPV专属架构—江淮瑞风 MUSE 共创智电架构打造。智能化层面&#x…

SQL数据库基础语法-增删改

SQL数据库基础语法-增删改 数据库是 ​ “按照数据结构来组织、存储和管理数据的仓库”。是一个长期存储在计算机内的、有组织的、可共享的、统一管理的大量数据的集合。 GeekSec专注技能竞赛培训5年&#xff0c;包含网络建设与运维和信息安全管理与评估两大赛项&#xff0c;…

【MySQL】Navicat/SQLyog连接Ubuntu中的数据库(MySQL)

&#x1f3e1;浩泽学编程&#xff1a;个人主页 &#x1f525; 推荐专栏&#xff1a;《深入浅出SpringBoot》《java对AI的调用开发》 《RabbitMQ》《Spring》《SpringMVC》 &#x1f6f8;学无止境&#xff0c;不骄不躁&#xff0c;知行合一 文章目录 前言一、安装…

手撕C语言习题

定义一个表示公交线路的结构体&#xff0c;要求有线路名称(例如 616)&#xff0c;起始站&#xff0c;终点站&#xff0c;里程等成员&#xff0c; 定义结构体数组&#xff0c;用来存储多条条公交线路信息&#xff0c;要求能够输出从指定起始站发车的所以公交线路信息。 2、定义…

rpm安装gitlab

1.1 下载gitlab安装包 使用rpm包安装命令安装gitlab的rpm包&#xff0c;下载地址为https://packages.gitlab.com/gitlab/gitlab-ce社区版本&#xff1b; 推荐使用清华大学镜像&#xff1a;https://mirrors.tuna.tsinghua.edu.cn/gitlab-ce/yum/el7/gitlab安装包详见&#xff1…

[Android]Frida-hook环境配置

准备阶段 反编译工具:Jadx能够理解Java语言能编写小型的JavaScript代码连接工具:adb设备:Root的安卓机器&#xff0c;或者模拟器 Frida&#xff08;https://frida.re/&#xff09; 就像是你计算机或移动设备的妙妙工具。它帮助你查看其他程序或应用内部发生的事情&#xff0…

鸿蒙系统优缺点,能否作为开发者选择

凡是都有对立面&#xff0c;就直接说说鸿蒙的优缺点吧。 鸿蒙的缺点&#xff1a; 鸿蒙是从2019年开始做出来的&#xff0c;那时候是套壳Android大家都知晓。从而导致大家不看鸿蒙系统&#xff0c;套壳Android就是多次一举。现在鸿蒙星河版已经是纯血鸿蒙&#xff0c;但是它的…