[PyTorch][chapter 54][Variational Auto-Encoder 实战]

前言:

   
 

这里主要实现: Variational Autoencoders (VAEs) 变分自动编码器
其训练效果如下

 

训练的过程中要注意调节forward 中的kle ,调参。

整个工程两个文件:

    vae.py

   main.py

目录:

  1.      vae
  2.       main

一  vae

  文件名: vae.py

   作用:   Variational Autoencoders (VAE)

 训练的过程中加入一些限制,使它的latent space规则一点呢。于是就引入了variational autoencoder(VAE),它被定义为一个有规律地训练以避免过度拟合的Autoencoder,可以确保潜在空间具有良好的属性从而实现内容的生成。
variational autoencoder的架构和Autoencoder差不多,区别在于不再是把输入当作一个点,而是把输入当成一个分布。

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:19:19 2023@author: chengxf2
"""import torch
from torch import nn#ae: AutoEncoderclass VAE(nn.Module):def __init__(self,hidden_size=20):super(VAE, self).__init__()self.encoder = nn.Sequential(nn.Linear(in_features=784, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=64),nn.ReLU(),nn.Linear(in_features=64, out_features=hidden_size),nn.ReLU())# hidden [batch_size, 10]h_dim = int(hidden_size/2)self.hDim = h_dimself.decoder = nn.Sequential(nn.Linear(in_features=h_dim, out_features=64),nn.ReLU(),nn.Linear(in_features=64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=784),nn.Sigmoid())def forward(self, x):'''param x:[batch, 1,28,28]return '''batchSz= x.size(0)#flattenx = x.view(batchSz, 784)#encoderh= self.encoder(x)#在给定维度上对所给张量进行分块,前一半的神经元看作u, 后一般的神经元看作sigmau, sigma = h.chunk(2,dim=1)#Reparameterize trick:#randn_like:产生一个正太分布 ~ N(0,1)#h.shape [batchSize,self.hDim]h = u+sigma* torch.randn_like(sigma)#kld :1e-8 防止sigma 平方为0kld = 0.5*torch.sum(torch.pow(u,2)+torch.pow(sigma,2)-torch.log(1e-8+torch.pow(sigma,2))-1)#MSE loss 是平均loss, 所以kld 也要算一个平均值kld = kld/(batchSz*32*32)xHat =   self.decoder(h)#reshapexHat = xHat.view(batchSz,1,28,28)return xHat,kld

二 main

文件名: main.py

作用: 训练,测试数据集

 

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:24:10 2023@author: chengxf2
"""import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import time
from torch import optim,nn
from vae import VAE
import visdomdef main():batchNum = 32lr = 1e-3epochs = 20device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")torch.manual_seed(1234)viz = visdom.Visdom()viz.line([0],[-1],win='train_loss',opts =dict(title='train acc'))tf= transforms.Compose([ transforms.ToTensor()])mnist_train = datasets.MNIST('mnist',True,transform= tf,download=True)train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)mnist_test = datasets.MNIST('mnist',False,transform= tf,download=True)test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)global_step =0model =VAE().to(device)criteon = nn.MSELoss().to(device) #损失函数optimizer = optim.Adam(model.parameters(),lr=lr) #梯度更新规则print("\n ----main-----")for epoch in range(epochs):start = time.perf_counter()for step ,(x,y) in enumerate(train_data):#[b,1,28,28]x = x.to(device)x_hat,kld = model(x)loss = criteon(x_hat, x)if kld is not None:elbo = -loss -1.0*kldloss = -elbo#backpropoptimizer.zero_grad()loss.backward()optimizer.step()viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')global_step +=1end = time.perf_counter()    interval = int(end - start)print("epoch: %d"%epoch, "\t 训练时间 %d"%interval, '\t 总loss: %4.7f'%loss.item(),"\t KL divergence: %4.7f"%kld.item())x,target = iter(test_data).next()x = x.to(device)with torch.no_grad():x_hat,kld = model(x)tip = 'hat'+str(epoch)viz.images(x,nrow=8, win='x',opts=dict(title='x'))viz.images(x_hat,nrow=8, win='x_hat',opts=dict(title=tip))if __name__ == '__main__':main()

 参考:

 课时118 变分Auto-Encoder实战-2_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/119162.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java-API简析_java.net.Inet4Address类(基于 Latest JDK)(浅析源码)

【版权声明】未经博主同意,谢绝转载!(请尊重原创,博主保留追究权) https://blog.csdn.net/m0_69908381/article/details/132643590 出自【进步*于辰的博客】 因为我发现目前,我对Java-API的学习意识比较薄弱…

百度百科词条怎么更新?怎么能顺利更新百科词条?

企业和个人百度百科词条的更新对于他们来说都具有重要的意义,具体如下: 对企业来说: 塑造品牌形象:百度百科是一个常被用户信任并参考的知识平台,通过更新企业词条可以提供准确、全面的企业信息,帮助企业塑…

SQL-basics

SQL 一些常用的查询语句用法 SQL 中的聚合函数 SQL 中的子查询 SQL 使用实例 SELECT F_NAME , L_NAME FROM EMPLOYEES WHERE ADDRESS LIKE ‘%Elgin,IL%’; SELECT F_NAME , L_NAME FROM EMPLOYEES WHERE B_DATE LIKE ‘197%’; SELECT * FROM EMPLOYEES WHERE (SALARY BET…

C#获取屏幕缩放比例

现在1920x1080以上分辨率的高分屏电脑渐渐普及了。我们会在Windows的显示设置里看到缩放比例的设置。在Windows桌面客户端的开发中,有时会想要精确计算窗口的面积或位置。然而在默认情况下,无论WinForms的Screen.Bounds.Width属性还是WPF中SystemParamet…

QTday4

实现闹钟功能 1》 头文件 #ifndef BURGER_H #define BURGER_H#include <QWidget> #include <QLabel> #include <QLineEdit> #include <QPushButton> #include <QTextEdit> #include <QTimerEvent> //定时器事件类 #include <QDateTim…

JavaScript基础语法01——初识JavaScript

哈喽&#xff0c;大家好&#xff0c;我是雷工&#xff01; 最近有项目用到KingFusion软件&#xff0c;由于KingFusion是B/S架构的客户端组态软件&#xff0c;因此在学习KingFusion产品时会涉及许多前端的知识。 像JavaScript语言就是需要用的&#xff0c;俗话说&#xff1a;活到…

合宙Air724UG LuatOS-Air LVGL API控件--下拉框 (Dropdown)

下拉框 (Dropdown) 在显示选项过多时&#xff0c;可以通过下拉框收起多余选项。只为用户展示列表中的一项。 示例代码 -- 回调函数 event_handler function(obj, event)if (event lvgl.EVENT_VALUE_CHANGED) thenprint("Option:", lvgl.dropdown_get_symbol(obj)…

类ChatGPT大模型LLaMA及其微调模型

1.LLaMA LLaMA的模型架构:RMSNorm/SwiGLU/RoPE/Transfor mer/1-1.4T tokens 1.1对transformer子层的输入归一化 对每个transformer子层的输入使用RMSNorm进行归一化&#xff0c;计算如下&#xff1a; 1.2使用SwiGLU替换ReLU 【Relu激活函数】Relu(x) max(0,x) 。 【GLU激…

项目:智慧教室(cubemx+webserver)

第一章&#xff1a;需求与配置 一。项目需求 二。实现外设控制 注意&#xff1a; 先配置引脚&#xff0c;再配置外设。否则会出现一些不可预料的问题 1.时钟&#xff0c;串口&#xff0c;灯&#xff0c;蜂鸣器配置 &#xff08;1&#xff09;RCC配置为外部时钟&#xff0c;修…

性能可靠it监控系统,性能监控软件的获得来源有哪些

性能可靠的IT监控系统是企业IT运维的重要保障之一。以下是一个性能可靠的IT监控系统应该具备的特点&#xff1a; 高可用性 高可用性是IT监控系统的一个重要特点&#xff0c;它可以保证系统在24小时不间断监控的同时&#xff0c;保证系统服务的可用性和稳定性。为了实现高可用性…

JVM ZGC垃圾收集器

ZGC垃圾收集器 ZGC&#xff08;“Z”并非什么专业名词的缩写&#xff0c;这款收集器的名字就叫作Z Garbage Collector&#xff09;是一款在JDK 11中新加入的具有实验性质[1]的低延迟垃圾收集器&#xff0c;是由Oracle公司研发的。 ZGC收集器是一款基于Region内存布局的&#…

时间语义与窗口

时间语义 在Flink中&#xff0c;时间语义分为两种 &#xff1a; 处理时间和事件时间。时间语义与窗口函数是密不可分的。以窗口为单位进行某一段时间内指标统计&#xff0c;例如想要统计8点-9点的某个页面的访问量&#xff0c;此时就需要用到了窗口函数&#xff0c;这里的关键…

多目标应用:基于多目标向日葵优化算法(MOSFO)的微电网多目标优化调度MATLAB

一、微网系统运行优化模型 参考文献&#xff1a; [1]李兴莘,张靖,何宇,等.基于改进粒子群算法的微电网多目标优化调度[J].电力科学与工程, 2021, 37(3):7 二、多目标向日葵优化算法 多目标向日葵优化算法&#xff08;Multi-objective sunflower optimization&#xff0c;MOS…

kubesphere中部署grafana实现dashboard以PDF方式导出

1&#xff0c;部署grafana-image-renderer 2&#xff0c;部署grafana GF_RENDERING_SERVER_URL http://ip:30323/render #grafana-image-renderer地址 GF_RENDERING_CALLBACK_URL http://ip:32403/ #grafana地址 GF_LOG_FILTERS rend…

【CSS】简记CSS效果:通过transition(动画过渡属性)实现侧边栏目滑入滑出

需求 在资金明细的页面中&#xff0c;点击按钮时筛选区域从左侧滑出&#xff0c;完成筛选点击确认后调用接口完成数据查询&#xff0c;筛选区域滑入左侧&#xff1b; 基于微信小程序页面实现 wxml代码 <view><!-- 操作按钮 --><button type"primary&qu…

docker笔记8:Docker网络

1.是什么 1.1 docker不启动&#xff0c;默认网络情况 ens33 lo virbr0 在CentOS7的安装过程中如果有选择相关虚拟化的的服务安装系统后&#xff0c;启动网卡时会发现有一个以网桥连接的私网地址的virbr0网卡(virbr0网卡&#xff1a;它还有一个固定的默认IP地址192.168.122…

Thymeleaf常见属性

参考文档 thymeleaf 语法——th:text默认值、字符串连接、th:attr、th:href 传参、th:include传参、th:inline 内联、th:each循环、th:with、th:if_猎人在吃肉的博客-CSDN博客 代码演示 Controller public class TestController {AutowiredMenuService menuService;GetMapp…

【洛谷】P3853 路标设置

原题链接&#xff1a;https://www.luogu.com.cn/problem/P3853 目录 1. 题目描述 2. 思路分析 3. 代码实现 1. 题目描述 2. 思路分析 整体思路&#xff1a;二分答案 由题意知&#xff0c;公路上相邻路标的最大距离定义为该公路的“空旷指数”。在公路上增设一些路标&…

MySQL的备份与恢复以及日志管理

目录 一、数据备份的重要性 二、数据库备份的分类 1、物理备份 2、逻辑备份 &#xff08;1&#xff09;完全备份&#xff1a;每次对数据进行完整的备份 &#xff08;2&#xff09;差异备份&#xff1a;备份自从上次完全备份之后被修改的过文件 &#xff08;3&#xff09…

yolov5手机版移植

感谢阅读 运行export.py然后百度一个onnx转化工具下载yolov5移动版文件和ncnn修改代码CMakeLists.txt修改修改param的参数![在这里插入图片描述](https://img-blog.csdnimg.cn/7c929414761840db8a2556843abcb2b3.jpeg)yolov5ncnn_jni.cpp修改修改stride16和stride32完工 运行ex…