Pytorch Advanced(二) Variational Auto-Encoder

自编码说白了就是一个特征提取器,也可以看作是一个降维器。下面找了一张很丑的图来说明自编码的过程。

自编码分为压缩和解码两个过程。从图中可以看出来,压缩过程就是将一组数据特征进行提取, 得到更深层次的特征。解码的过程就是利用之前的深层次特征再还原成为原来的数据特征。那么如何保证从压缩到解码两部分,原数据和解码数据保持一致呢?这就是要训练的过程。

如何理解降维?如果压缩的过程是卷积,维度可以根据核的个数变化,特征维度因此而改变。


import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_imagedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')sample_dir = 'samples'
if not os.path.exists(sample_dir):os.makedirs(sample_dir)
image_size = 784
h_dim = 400
z_dim = 20
num_epochs = 15
batch_size = 128
learning_rate = 1e-3dataset = torchvision.datasets.MNIST(root='../../data',train=True,transform=transforms.ToTensor(),download=True)# Data loader
data_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=batch_size, shuffle=True)

模型搭建:这里搭建的是一个变分自编码,Variational Autoencoder

那么变分自编码是为了解决什么问题呢? ——- 其主要思想还是希望学习隐层变量,并将其用来表示原始数据,但是它加另一个条件, 即隐层变量能学习原始数据的分布, 并反过来生产一些和原始数据相似的数据(这有啥用?—-可用于图片修复,让图片按训练集的数据分布变化)。

变分自编码 (Variational Autoencoder) 为了让隐层抓住输入数据特性, 而不是简单的输出数据=输入数据,他在隐层中加入随机噪声(单位高斯噪声)(这个过程也叫reparametrize),以确保隐层能较好抽象输入数据特点。

代码中怎么做的呢?

1、编码过程中我们保存了第二层线性层的输出。其中第二层包含有fc2与fc3两部分,他们是并联的。

2、给隐藏层加入随机噪声,作为解码的输入

class VAE(nn.Module):def __init__(self, image_size=784, h_dim=400, z_dim=20):super(VAE, self).__init__()self.fc1 = nn.Linear(image_size, h_dim)self.fc2 = nn.Linear(h_dim, z_dim)self.fc3 = nn.Linear(h_dim, z_dim)self.fc4 = nn.Linear(z_dim, h_dim)self.fc5 = nn.Linear(h_dim, image_size)def encode(self, x):h = F.relu(self.fc1(x))return self.fc2(h), self.fc3(h)def reparameterize(self, mu, log_var):std = torch.exp(log_var/2)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):h = F.relu(self.fc4(z))return F.sigmoid(self.fc5(h))def forward(self, x):mu, log_var = self.encode(x)z = self.reparameterize(mu, log_var)x_reconst = self.decode(z)return x_reconst, mu, log_var

训练:由于训练中加入了噪声,所以损失值的结构也因此改变。一部分来源于解码内容核原内容的相似度,另一部分是kl_div,具体是什么意义需查看论文。

model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# Start training
for epoch in range(num_epochs):for i, (x, _) in enumerate(data_loader):# Forward passx = x.to(device).view(-1, image_size)x_reconst, mu, log_var = model(x)# Compute reconstruction loss and kl divergence# For KL divergence, see Appendix B in VAE paper or http://yunjey47.tistory.com/43reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())# Backprop and optimizeloss = reconst_loss + kl_divoptimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 10 == 0:print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}" .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))with torch.no_grad():# Save the sampled imagesz = torch.randn(batch_size, z_dim).to(device)out = model.decode(z).view(-1, 1, 28, 28)save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))# Save the reconstructed imagesout, _, _ = model(x)x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))

模型训练完成了之后该如何使用这个模型呢?

model.decode()是一个解码的过程,我们给他一个随机的中间特征z就可以输出一个数字图片了。

z = torch.randn(1,z_dim).to(device)
out = model.decode(z)
plt.imshow(out.cpu().data.numpy().reshape(28,28),cmap='gray')
plt.show()

有了随机的一张图片之后,我们把他完整的放入模型中,生成了和输入相似的一张图片,也没看出来是修复了图像......

out,_,_ = model(out) 
plt.imshow(out.cpu().data.numpy().reshape(28,28),cmap='gray')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/135656.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Unity】ShaderGraph应用(浮动气泡)

【Unity】ShaderGraph应用(浮动气泡) 实现效果 一、实现的方法 1.使用节点介绍 Position:获取模型的顶点坐标 Simple Noise:简单的噪声,用于计算顶点抖动 Fresnel Effect:菲涅耳效应,用于实现气泡效果 计算用节点 Add&…

小程序键盘没有【小数点】输入

<input v-model"formData.number" :auto-height"true" placeholder"请输入" confirm-type"done" type"digit" maxlength"11" input"inputNumber" />number&#xff1a;数字键盘&#xff08;没有小…

【kafka】mac环境安装kafka

生产中使用到的中间件使用心得记录&#xff0c;感兴趣可以关注我一起学习&#xff5e; 环境&#xff1a; 硬件&#xff1a;mac 软件&#xff1a;kafka v3.0.0 安装步骤&#xff1a; 终端安装&#xff1a; 通过brew安装&#xff0c;会默认直接安装最新的版本 步骤1&#xf…

华为云云耀云服务器L实例评测|用Python的Flask框架加Nginx实现一个通用的爬虫项目

&#x1f3c6;作者简介&#xff0c;黑夜开发者&#xff0c;CSDN领军人物&#xff0c;全栈领域优质创作者✌&#xff0c;CSDN博客专家&#xff0c;阿里云社区专家博主&#xff0c;2023年6月CSDN上海赛道top4。 &#x1f3c6;数年电商行业从业经验&#xff0c;AWS/阿里云资深使用…

计算机专业毕业设计项目推荐06-工作室管理系统(Java+Vue+Mysql)

工作室管理系统&#xff08;JavaSpringVueMysql&#xff09; **介绍****系统总体开发情况-功能模块****各部分模块实现****最后想说的****联系方式** 介绍 本系列(后期可能博主会统一为专栏)博文献给即将毕业的计算机专业同学们,因为博主自身本科和硕士也是科班出生,所以也比较…

FPGA----VCU128的DDR4无法使用问题(全网唯一)

1、在Vivado 2019.1版本中使用DDR4的IP核会遇到如下图所示的错误&#xff0c;即便过了implementation生成了bit&#xff0c;DDR4也无法正常启动。 2、解决办法&#xff0c;上xilinx社区搜一下就知道了 AMD Customer Communityhttps://support.xilinx.com/s/article/69035?lan…

(JavaEE) 多线程基础3——多线程的代码案例 (单例模式, 阻塞队列,定时器)详解!!!

​​​​​​​ 目录 单例模式 什么是单例模式&#xff1f; —— “饿汉模式” —— “懒汉模式” ——懒汉模式-多线程版 ——懒汉模式-多线程版&#xff08;改进版&#xff09; 总结“懒汉模式”—— 多线程&#xff08;线程安全版&#xff09; 的要点 阻塞队列 什么…

K8s(Kubernetes)学习(六)——Ingress

第六章 Ingress 什么是 IngressIngress 和 Service 区别Ingress 控制器 Traefik 使用Ingress Route的定义 1 简介 https://kubernetes.io/zh-cn/docs/concepts/services-networking/ingress/ Ingress 是一种 Kubernetes 资源类型&#xff0c;它允许在 Kubernetes 集群中暴露…

【TCP】滑动窗口、流量控制 以及拥塞控制

滑动窗口、流量控制 以及拥塞控制 1. 滑动窗口&#xff08;效率机制&#xff09;2. 流量控制&#xff08;安全机制&#xff09;3. 拥塞控制&#xff08;安全机制&#xff09; 1. 滑动窗口&#xff08;效率机制&#xff09; TCP 使用 确认应答 策略&#xff0c;对每一个发送的数…

深入解读什么是期权的内在价值和时间价值?

期权品种越来越丰富&#xff0c;对于大家套利对冲都有很多的选择。而有些初学者对时间价值一直不理解&#xff0c;今天呢&#xff0c;就给大家讲一讲深入解读什么是期权的内在价值和时间价值&#xff1f;本文来自&#xff1a;期权酱 01在期权交易过程中&#xff0c;想必大家都会…

抖音seo矩阵系统源代码分享

技术开发注意事项&#xff1a; 确定业务需求&#xff1a;在开发前&#xff0c;需要明确抖音矩阵系统的业务需求&#xff0c;了解用户的需求和使用习惯&#xff0c;明确系统的功能、性能和安全需求。 选择合适的技术方案&#xff1a;根据系统的需求和复杂度&#xff0c;选择合适…

@EventListener 监听事件 ,在同一个虚拟机中如何保证顺序执行

文章目录 前言EventListener 监听事件 &#xff0c;在同一个虚拟机中如何保证顺序执行1. 设计原理2. 具体编码2.1. 编码事件监听器2.2. 制作一个生成序号方法2.3. 制作测试代码2.4. 测试结果 前言 如果您觉得有用的话&#xff0c;记得给博主点个赞&#xff0c;评论&#xff0c;…

如何在 Excel 中进行加,减,乘,除

在本教程中&#xff0c;我们将执行基本的算术运算&#xff0c;即加法&#xff0c;减法&#xff0c;除法和乘法。 下表显示了我们将使用的数据以及预期的结果。 | **S / N** | **算术运算符** | **第一个号码** | **第二个号码** | **结果** | | 1 | 加法&#xff08;&#xff…

mysql报错You do not have the SUPER privilege and binary logging is enabled

创建触发器,显示You do not have the SUPER privilege and binary logging is enabled CREATE TRIGGER delete_lottery_template AFTER DELETE ON lottery_info FOR EACH ROW BEGINDELETE FROM lottery_template WHERE lottery_template_id old.lottery_template_id; END; …

C#,《小白学程序》第二十六课:大数乘法(BigInteger Multiply)的Toom-Cook 3算法及源程序

凑数的&#xff0c;仅供参考。 1 文本格式 /// <summary> /// 《小白学程序》第二十六课&#xff1a;大数&#xff08;BigInteger&#xff09;的Toom-Cook 3乘法 /// Toom-Cook 3-Way Multiplication /// </summary> /// <param name"a"></par…

【AI视野·今日NLP 自然语言处理论文速览 第三十五期】Mon, 18 Sep 2023

AI视野今日CS.NLP 自然语言处理论文速览 Mon, 18 Sep 2023 Totally 51 papers &#x1f449;上期速览✈更多精彩请移步主页 Daily Computation and Language Papers "Merge Conflicts!" Exploring the Impacts of External Distractors to Parametric Knowledge Gra…

6.1、Flink数据写入到文件

1、前言 Flink API 提供了FileSink连接器&#xff0c;来帮助我们将数据写出到文件系统中去 版本说明&#xff1a;java1.8、flink1.17 官网链接&#xff1a;官网 2、Format Types - 指定文件格式 FileSink 支持 Row-encoded 、Bulk-encoded 两种格式写入文件系统 Row-encode…

Vulnhub实战-DC9

前言 本次的实验靶场是Vulnhub上面的DC-9&#xff0c;其中的渗透测试过程比较多&#xff0c;最终的目的是要找到其中的flag。 一、信息收集 对目标网络进行扫描 arp-scan -l 对目标进行端口扫描 nmap -sC -sV -oA dc-9 192.168.1.131 扫描出目标开放了22和80两个端口&a…

GFS分布式文件系统

目录 1、GFS理论 1.1、概述 1.2、GFS的基本概念和架构&#xff1a; 1.2.1. 基本概念 1.2.2. 架构&#xff1a; 1.3、GFS的基本组成部分包括&#xff1a; 1.4、GFS具有以下几个关键特点&#xff1a; 1.5、GlusterFS特点 1.6、GFS的术语 1.7、GlusterFS 采用模块化、堆…

【C++】STL简介 | string类的常用接口

目录 STL简介 学string类前的铺垫 概念 为什么要学string类 string类的底层&#xff08;了解&#xff09; 编码表的故事 string类的常用接口与应用 3个必掌握的构造 赋值 访问字符operator[] 初识迭代器&#xff08;iterator&#xff09; 反向迭代器 用范围for遍历…