pytorch实现变分自编码器

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

变分自编码器(Variational Autoencoder, VAE)是一种生成模型,属于深度学习中的无监督学习方法。它通过学习输入数据的潜在分布(Latent Distribution),生成与输入数据相似的新样本。VAE 可以用于数据生成、降维、异常检测等任务。

VAE 的关键思想是在传统的自编码器(Autoencoder)的基础上,引入了变分推断(Variational Inference)和概率模型,使得网络能够学习到数据的潜在分布,而不仅仅是数据的映射。

VAE 的结构:

  1. 编码器(Encoder):将输入数据映射到潜在空间的分布。不同于传统的自编码器直接将数据映射到一个固定的潜在向量,VAE 通过输出潜在变量的均值和方差来描述一个概率分布,这样潜在空间中的每个点都有一个概率分布。
  2. 潜在空间(Latent Space):表示数据的潜在特征。在 VAE 中,潜在空间的表示是一个分布而不是固定的值。通常,采用正态分布来作为潜在空间的先验分布。
  3. 解码器(Decoder):从潜在空间的样本中重构输入数据。解码器通过将潜在空间的点映射回数据空间来生成样本。

VAE 的目标函数:

VAE 的目标是最大化变分下界(Variational Lower Bound,简称 ELBO),即通过优化以下两部分的加权和:

  • 重构误差(Reconstruction Loss):衡量生成的数据和输入数据之间的差异,通常使用均方误差(MSE)或交叉熵(Cross-Entropy)。
  • KL 散度(KL Divergence):衡量潜在空间的分布与先验分布(通常是标准正态分布)之间的差异。

其最终的目标是使生成的数据尽可能接近真实数据,同时使潜在空间的分布接近先验分布。

优点:

  • VAE 能够生成具有多样性的样本,尤其适用于图像、音频等数据的生成。
  • 潜在空间通常具有良好的结构,可以进行插值、样本生成等操作。

应用:

  • 生成任务:如图像生成、文本生成等。
  • 数据重构:如去噪、自编码等。
  • 半监督学习:VAE 可以结合有标签和无标签的数据进行训练,提升模型的泛化能力。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt# 生成圆形图像的函数(使用PyTorch)
def generate_circle_image(size=64):image = torch.zeros((1, size, size))  # 使用 PyTorch 创建空白图像center = size // 2radius = size // 4for y in range(size):for x in range(size):if (x - center) ** 2 + (y - center) ** 2 <= radius ** 2:image[0, y, x] = 1  # 在圆内的点设置为白色return image# 生成方形图像的函数(使用PyTorch)
def generate_square_image(size=64):image = torch.zeros((1, size, size))  # 使用 PyTorch 创建空白图像padding = size // 4image[0, padding:size - padding, padding:size - padding] = 1  # 设置方形区域为白色return image# 自定义数据集:圆形和方形图像
class ShapeDataset(Dataset):def __init__(self, num_samples=1000, size=64):self.num_samples = num_samplesself.size = sizeself.data = []# 生成数据:一半是圆形图像,一半是方形图像for i in range(num_samples // 2):self.data.append(generate_circle_image(size))self.data.append(generate_square_image(size))def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx].float()  # 直接返回 PyTorch Tensor 格式的数据# VAE模型定义
class VAE(nn.Module):def __init__(self, latent_dim=2):super(VAE, self).__init__()self.latent_dim = latent_dim# 编码器self.fc1 = nn.Linear(64 * 64, 400)self.fc21 = nn.Linear(400, latent_dim)  # 均值self.fc22 = nn.Linear(400, latent_dim)  # 方差# 解码器self.fc3 = nn.Linear(latent_dim, 400)self.fc4 = nn.Linear(400, 64 * 64)def encode(self, x):h1 = torch.relu(self.fc1(x.view(-1, 64 * 64)))return self.fc21(h1), self.fc22(h1)  # 返回均值和方差def reparameterize(self, mu, logvar):std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):h3 = torch.relu(self.fc3(z))return torch.sigmoid(self.fc4(h3)).view(-1, 1, 64, 64)  # 重构图像def forward(self, x):mu, logvar = self.encode(x)z = self.reparameterize(mu, logvar)return self.decode(z), mu, logvar# 损失函数:重构误差 + KL 散度
def loss_function(recon_x, x, mu, logvar):BCE = nn.functional.binary_cross_entropy(recon_x.view(-1, 64 * 64), x.view(-1, 64 * 64), reduction='sum')# KL 散度return BCE + 0.5 * torch.sum(torch.exp(logvar) + mu ** 2 - 1 - logvar)# 设置超参数
batch_size = 128
epochs = 10
latent_dim = 2
learning_rate = 1e-3# 数据加载
train_loader = DataLoader(ShapeDataset(num_samples=2000), batch_size=batch_size, shuffle=True)# 创建模型和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
def train(epoch):model.train()train_loss = 0for batch_idx, data in enumerate(train_loader):data = data.to(device)optimizer.zero_grad()recon_batch, mu, logvar = model(data)loss = loss_function(recon_batch, data, mu, logvar)loss.backward()train_loss += loss.item()optimizer.step()if batch_idx % 100 == 0:print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item() / len(data):.6f}')print(f'Train Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')# 测试并显示一些真实图像和生成的图像
def test():model.eval()with torch.no_grad():# 获取一批真实的图像(原始图像)real_images = next(iter(train_loader))[:64]  # 只取前64个图像real_images = real_images.cpu().numpy()# 从潜在空间随机生成一些样本sample = torch.randn(64, latent_dim).to(device)generated_images = model.decode(sample).cpu().numpy()# 显示真实图像和生成的图像,分别标明fig, axes = plt.subplots(8, 8, figsize=(8, 8))axes = axes.flatten()for i in range(64):if i < 32:  # 前32个显示真实图像axes[i].imshow(real_images[i].squeeze(), cmap='gray')axes[i].set_title('Real', fontsize=8)else:  # 后32个显示生成图像axes[i].imshow(generated_images[i - 32].squeeze(), cmap='gray')axes[i].set_title('Generated', fontsize=8)axes[i].axis('off')plt.tight_layout()plt.show()# 训练模型
for epoch in range(1, epochs + 1):train(epoch)# 训练完成后,显示生成的图像
test()

解释:

  1. 真实图像 (real_images):我们通过 next(iter(train_loader)) 获取一批真实图像,并将其转换为 NumPy 数组,以便 matplotlib 显示。
  2. 生成图像 (generated_images):通过模型生成的图像,使用 decode() 方法生成潜在空间的样本。
  3. 图像展示:前 32 张图像展示真实图像,后 32 张图像展示生成的图像。每个图像上方都有 RealGenerated 标注。

结果:

  • 前32个图像:显示真实图像,并标注为 Real
  • 后32个图像:显示通过训练后的 VAE 生成的图像,并标注为 Generated

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/12068.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在线知识库的构建策略提升组织信息管理效率与决策能力

内容概要 在线知识库作为现代企业信息管理的重要组成部分&#xff0c;具有显著的定义与重要性。它不仅为组织提供了一个集中存储与管理知识的平台&#xff0c;还能够有效提升信息检索的效率&#xff0c;促进知识的创新和利用。通过这样的知识库&#xff0c;企业可以更好地应对…

【汽车电子软件架构】AutoSAR从放弃到入门专栏导读

本文是汽车电子软件架构&#xff1a;AutoSAR从放弃到入门专栏的导读篇。文章延续专栏文章的一贯作风&#xff0c;从概念与定义入手&#xff0c;希望读者能对AutoSAR架构有一个整体的认识&#xff0c;然后对专栏涉及的文章进行分类与链接。本文首先从AutoSAR汽车软件架构的概念&…

DeepSeek-R1:通过强化学习激励大型语言模型(LLMs)的推理能力

摘要 我们推出了第一代推理模型&#xff1a;DeepSeek-R1-Zero和DeepSeek-R1。DeepSeek-R1-Zero是一个未经监督微调&#xff08;SFT&#xff09;作为初步步骤&#xff0c;而是通过大规模强化学习&#xff08;RL&#xff09;训练的模型&#xff0c;展现出卓越的推理能力。通过强…

响应式编程与协程

响应式编程与协程的比较 响应式编程的弊端虚拟线程Java线程内核线程的局限性传统线程池的demo虚拟线程的demo 响应式编程的弊端 前面用了几篇文章介绍了响应式编程&#xff0c;它更多的使用少量线程实现线程间解耦和异步的作用&#xff0c;如线程的Reactor模型&#xff0c;主要…

本地部署DeepSeek-R1模型(新手保姆教程)

背景 最近deepseek太火了&#xff0c;无数的媒体都在报道&#xff0c;很多人争相着想本地部署试验一下。本文就简单教学一下&#xff0c;怎么本地部署。 首先大家要知道&#xff0c;使用deepseek有三种方式&#xff1a; 1.网页端或者是手机app直接使用 2.使用代码调用API …

当WebGIS遇到智慧文旅-以长沙市不绕路旅游攻略为例

目录 前言 一、旅游数据组织 1、旅游景点信息 2、路线时间推荐 二、WebGIS可视化实现 1、态势标绘实现 2、相关位置展示 三、成果展示 1、第一天旅游路线 2、第二天旅游路线 3、第三天旅游路线 4、交通、订票、住宿指南 四、总结 前言 随着信息技术的飞速发展&…

93,【1】buuctf web [网鼎杯 2020 朱雀组]phpweb

进入靶场 页面一直在刷新 在 PHP 中&#xff0c;date() 函数是一个非常常用的处理日期和时间的函数&#xff0c;所以应该用到了 再看看警告的那句话 Warning: date(): It is not safe to rely on the systems timezone settings. You are *required* to use the date.timez…

如何在电脑上部署deepseek

由于免费的网页版经常显示服务器异常&#xff0c;并且每次打开网页麻烦&#xff0c;我们可以采用电脑部署的方法&#xff0c;V3和V2现在都很便宜&#xff0c;试了一下问了一下午问题也才0.1&#xff0c;而且现在注册就送14元&#xff0c;心动不如行动&#xff0c;快来薅羊毛&am…

SmartPipe完成新一轮核心算法升级

1. 增加对低质量轴段的修正 由于三维图纸导出造成某些轴段精度较差&#xff0c;部分管路段的轴线段不满足G1连续&#xff0c;SmartPipe采用算法对这种情况进行了修正&#xff0c;保证轴段在一定精度范围内光滑连续。 2. 优化对中文路径的处理 SmartPipeBatch批处理版本优化…

2.3学习总结

今天做了下上次测试没做出来的题目&#xff0c;作业中做了一题&#xff0c;看了下二叉树&#xff08;一脸懵B&#xff09; P2240&#xff1a;部分背包问题 先求每堆金币的性价比&#xff08;价值除以重量&#xff09;&#xff0c;将这些金币由性价比从高到低排序。 对于排好…

四川正熠法律咨询有限公司正规吗可信吗?

在纷繁复杂的法律环境中&#xff0c;寻找一家值得信赖的法律服务机构是每一个企业和个人不可或缺的需求。四川正熠法律咨询有限公司&#xff0c;作为西南地区备受瞩目的法律服务提供者&#xff0c;以其专注、专业和高效的法律服务&#xff0c;成为众多客户心中的首选。 正熠法…

【leetcode练习·二叉树拓展】快速排序详解及应用

本文参考labuladong算法笔记[拓展&#xff1a;快速排序详解及应用 | labuladong 的算法笔记] 1、算法思路 首先我们看一下快速排序的代码框架&#xff1a; def sort(nums: List[int], lo: int, hi: int):if lo > hi:return# 对 nums[lo..hi] 进行切分# 使得 nums[lo..p-1]…

FPGA学习篇——开篇之作

今天正式开始学FPGA啦&#xff0c;接下来将会编写FPGA学习篇来记录自己学习FPGA 的过程&#xff01; 今天是大年初六&#xff0c;简单学一下FPGA的相关概念叭叭叭&#xff01; 一&#xff1a;数字系统设计流程 一个数字系统的设计分为前端设计和后端设计。在我看来&#xff0…

DeepSeek R1 简易指南:架构、本地部署和硬件要求

DeepSeek 团队近期发布的DeepSeek-R1技术论文展示了其在增强大语言模型推理能力方面的创新实践。该研究突破性地采用强化学习&#xff08;Reinforcement Learning&#xff09;作为核心训练范式&#xff0c;在不依赖大规模监督微调的前提下显著提升了模型的复杂问题求解能力。 技…

Vue3学习笔记-模板语法和属性绑定-2

一、文本插值 使用{ {val}}放入变量&#xff0c;在JS代码中可以设置变量的值 <template><p>{{msg}}</p> </template> <script> export default {data(){return {msg: 文本插值}} } </script> 文本值可以是字符串&#xff0c;可以是布尔…

Android学习19 -- 手搓App

1 前言 之前工作中&#xff0c;很多时候要搞一个简单的app去验证底层功能&#xff0c;Android studio又过于重型&#xff0c;之前用gradle&#xff0c;被版本匹配和下载外网包折腾的堪称噩梦。所以搞app都只有找应用的同事帮忙。一直想知道一些简单的app怎么能手搓一下&#x…

深度解读 Docker Swarm

一、引言 随着业务规模的不断扩大和应用复杂度的增加,容器集群管理的需求应运而生。如何有效地管理和调度大量的容器,确保应用的高可用性、弹性伸缩和资源的合理分配,成为了亟待解决的问题。Docker Swarm 作为 Docker 官方推出的容器集群管理工具,正是在这样的背景下崭露头…

centos stream 9 安装 libstdc++-static静态库

yum仓库中相应的镜像源没有打开&#xff0c;libstdc-static在CRB这个仓库下&#xff0c;但是查看/etc/yum.repos.d/centos.repo&#xff0c;发现CRB镜像没有开启。 解决办法 如下图开启CRB镜像&#xff0c; 然后执行 yum makecache yum install glibc-static libstdc-static…

玉米苗和杂草识别分割数据集labelme格式1997张3类别

数据集格式&#xff1a;labelme格式(不包含mask文件&#xff0c;仅仅包含jpg图片和对应的json文件) 图片数量(jpg文件个数)&#xff1a;1997 标注数量(json文件个数)&#xff1a;1997 标注类别数&#xff1a;3 标注类别名称:["corn","weed","Bean…

Docker入门篇(Docker基础概念与Linux安装教程)

目录 一、什么是Docker、有什么作用 二、Docker与虚拟机(对比) 三、Docker基础概念 四、CentOS安装Docker 一、从零认识Docker、有什么作用 1.项目部署可能的问题&#xff1a; 大型项目组件较多&#xff0c;运行环境也较为复杂&#xff0c;部署时会碰到一些问题&#xff1…