pytorch-AutoEncoders实战

目录

  • 1. AutoEncoders回顾
  • 2. 实现网络结构
  • 3. 实现main函数

1. AutoEncoders回顾

如下图:AutoEncoders实际上就是重建自己的过程
在这里插入图片描述

2. 实现网络结构

创建类继承自nn.Model,并实现init和forward函数,init中实现encoder、decoder
直接上代码,其中encoder维度变化784->256->64->20,decoder是相反过程

import  torch
from    torch import nnclass AE(nn.Module):def __init__(self):super(AE, self).__init__()# [b, 784] => [b, 20]self.encoder = nn.Sequential(nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 64),nn.ReLU(),nn.Linear(64, 20),nn.ReLU())# [b, 20] => [b, 784]self.decoder = nn.Sequential(nn.Linear(20, 64),nn.ReLU(),nn.Linear(64, 256),nn.ReLU(),nn.Linear(256, 784),nn.Sigmoid())def forward(self, x):""":param x: [b, 1, 28, 28]:return:"""batchsz = x.size(0)# flattenx = x.view(batchsz, 784)# encoderx = self.encoder(x)# decoderx = self.decoder(x)# reshapex = x.view(batchsz, 1, 28, 28)return x, None

3. 实现main函数

  • 加载minist数据集
  • 初始化网络
  • 初始化loss函数
  • 初始化优化器
  • 循环迭代数据输入模型、计算loss、优化器优化参数
    完整代码
import  torch
from    torch.utils.data import DataLoader
from    torch import nn, optim
from    torchvision import transforms, datasetsfrom    ae import AE
#from    vae import VAEimport  visdomdef main():mnist_train = datasets.MNIST('mnist', True, transform=transforms.Compose([transforms.ToTensor()]), download=True)mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)mnist_test = datasets.MNIST('mnist', False, transform=transforms.Compose([transforms.ToTensor()]), download=True)mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)x, _ = iter(mnist_train).next()print('x:', x.shape)device = torch.device('cuda')model = AE().to(device)#model = VAE().to(device)criteon = nn.MSELoss()optimizer = optim.Adam(model.parameters(), lr=1e-3)print(model)viz = visdom.Visdom()for epoch in range(1000):for batchidx, (x, _) in enumerate(mnist_train):# [b, 1, 28, 28]x = x.to(device)x_hat, _ = model(x)loss = criteon(x_hat, x)# backpropoptimizer.zero_grad()loss.backward()optimizer.step()print(epoch, 'loss:', loss.item(), 'kld:', kld.item())x, _ = iter(mnist_test).next()x = x.to(device)with torch.no_grad():x_hat, kld = model(x)viz.images(x, nrow=8, win='x', opts=dict(title='x'))viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))if __name__ == '__main__':main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/423311.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

代码随想录算法训练营第13天|二叉树基础知识、递归遍历、迭代遍历、层序遍历、116. 填充每个节点的下一个右侧节点指针

目录 二叉树基础深度和高度满二叉树和完全二叉树二叉搜索树和平衡二叉搜索树二叉树节点定义前中后序遍历 递归遍历前序递归遍历—144. 二叉树的前序遍历 迭代遍历层序遍历116. 填充每个节点的下一个右侧节点指针1、题目描述2、思路3、code 二叉树基础 深度和高度 满二叉树和完…

XSS跨站脚本攻击及防护

什么是XSS攻击? XSS(Cross-Site Scripting,跨站脚本攻击)是一种代码注入攻击。攻击者在目标网站上注入恶意代码,当用户(被攻击者)登录网站时就会执行这些恶意代码,通过这些脚本可以读取cookie,session tokens,或者网站其他敏感的网…

Ubuntu WSL使用技巧

0 Preface/Foreword 1 默认为root用户 当下载完成Ubuntu之后,首次登录,当完成初始化后,提示输入新的用户名时候,直接点击右上角的X按钮,再重新登陆,系统会默认使用root权限登录。 2 root用户和普通用户切换…

阿里云社区领积分自动打卡Selenium IDE脚本

脚本 感觉打卡比较麻烦,要点开点按钮这种机械化的操作。 所以就自己整了个脚本: { “id”: “f9999777-9ad6-40e0-9435-4f105919c982”, “version”: “2.0”, “name”: “aliyun”, “url”: “https://developer.aliyun.com”, “tests”: [{ “id”…

ubuntu22安装docker

1、查看服务器系统信息 uname -a:显示内核名称、主机名、内核版本、处理器类型等信息。 lsb_release -a:显示有关 Ubuntu 发行版的详细信息,包括版本号、代号等。 free -h:查看系统内存使用情况。 df -h:查看磁盘空间使…

Vue2时间轴组件(TimeLine/分页、自动顺序播放、暂停、换肤功能、时间选择,鼠标快速滑动)

目录 1介绍背景 2实现原理 3组件介绍 4代码 5其他说明 1介绍背景 项目背景是 一天的时间轴 10分钟为一间隔 一天被划分成144个节点 一页面12个节点 代码介绍的很详细 可参考或者借鉴 2实现原理 对Element-plus滑块组件的二次封装 基于Vue2(2.6.14&#x…

Vue3 : ref 与 reactive

目录 一.ref 二.reactive 三.ref与reactive的区别 四.总结 一.ref 在 Vue 3 中,ref 是一个用于创建可读写且支持数据跟踪的响应式引用对象。它主要用于在组件内部创建响应式数据,这些数据可以是基本类型(如 number、string、boolean&…

深入理解全连接层:从线性代数到 PyTorch 中的 nn.Linear 和 nn.Parameter

文章目录 数学概念(全连接层,线性层)nn.Linear()nn.Parameter()Q1. 为什么 self.weight 的权重矩阵 shape 使用 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)而不是 ( in_featur…

Vue的缓存组件 | 详解KeepAlive

引言 在Vue开发中,我们经常需要处理大量的组件渲染和销毁操作,这可能会影响应用的性能和用户体验。而Vue的KeepAlive组件提供了一种简便的方式来优化组件的渲染和销毁流程,通过缓存已经渲染的组件来提升应用的性能。 本文将详细介绍Vue的Ke…

即插即用篇 | YOLOv10 引入矩形自校准模块RCM | ECCV 2024

本改进已同步到YOLO-Magic框架! 语义分割是许多应用的重要任务,但要在有限的计算成本下实现先进性能仍然非常具有挑战性。在本文中,我们提出了CGRSeg,一个基于上下文引导的空间特征重建的高效且具有竞争力的分割框架。我们精心设计了一个矩形自校准模块,用于空间特征重建和…

经典RNA-seq分析流程1

RNA-seq分析有很多流程, 一般都是上游linux工具获取表达矩阵数据,然后就可以使用下游R包进行处理了,要么是差异DEG表达gene等分析; 因为下游分析其实R包是明确的,毕竟有很多生信分析教程,但是上游的linux…

无人机之处理器篇

无人机的处理器是无人机系统的核心部件之一,它负责控制无人机的飞行、数据处理、任务执行等多个关键功能。以下是对无人机处理器的详细解析: 一、处理器类型 无人机中使用的处理器主要包括以下几种类型: CPU处理器:CPU是无人机的…

神经网络多层感知器异或问题求解-学习篇

多层感知器可以解决单层感知器无法解决的异或问题 首先给了四个输入样本,输入样本和位置信息如下所示,现在要学习一个模型,在二维空间中把两个样本分开,输入数据是个矩阵,矩阵中有四个样本,样本的维度是三维…

Unity全面取消Runtime费用 安装游戏不再收版费

Unity宣布他们已经废除了争议性的Runtime费用,该费用于2023年9月引入,定于1月1日开始收取。Runtime费用起初是打算根据使用Unity引擎安装游戏的次数收取版权费。2023年9月晚些时候,该公司部分收回了计划,称Runtime费用只适用于订阅…

[数据集][目标检测]车窗状态检测车窗开关检测数据集VOC+YOLO格式299张3类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):299 标注数量(xml文件个数):299 标注数量(txt文件个数):299 标注类别…

应用程序已被 Java 安全阻止:Java 安全中的添加的例外站点如何对所有用户生效

如题:应用程序已被 Java 安全阻止,如下图所示: 在寻找全局配置的时候花了一个上午的时间,到处搜解决方法,都不可行。最后还是参考官方的文档配置好了。如果你碰到了同样的问题,这篇文章一定可以帮到你。 环…

论文阅读:AutoDIR Automatic All-in-One Image Restoration with Latent Diffusion

论文阅读:AutoDIR: Automatic All-in-One Image Restoration with Latent Diffusion 这是 ECCV 2024 的一篇文章,利用扩散模型实现图像恢复的任务。 Abstract 这篇文章提出了一个创新的 all-in-one 的图像恢复框架,融合了隐扩散技术&#x…

【重学 MySQL】二十八、SQL99语法新特性之自然连接和 using 连接

【重学 MySQL】二十八、SQL99语法新特性之自然连接和 using 连接 自然连接(NATURAL JOIN)USING连接总结 SQL99语法在SQL92的基础上引入了一些新特性,其中自然连接(NATURAL JOIN)和USING连接是较为显著的两个特性。 自…

《数字图像处理(面向新工科的电工电子信息基础课程系列教材)》P84

更正卷积与相关微课中互相关运算动画中的索引。 1-D correlation rectwave 禹晶、肖创柏、廖庆敏《数字图像处理(面向新工科的电工电子信息基础课程系列教材)》 禹晶、肖创柏、廖庆敏《数字图像处理》资源二维码

性能测试【Locust】基本使用介绍

一.前言 Locust是一款易于使用的分布式负载测试工具,基于事件驱动,使用轻量级执行单元(如协程)来实现高并发。 二.基本使用 以下是Locust性能测试使用的一个基础Demo示例,该示例有安装Locust、编写测试脚本、启动测…