[PyTorch][chapter 53][Auto Encoder 实战]

前言:

     结合手写数字识别的例子,实现以下AutoEncoder

     ae.py:  实现autoEncoder 网络

     main.py: 加载手写数字数据集,以及训练,验证,测试网络。

左图:原图像

右图:重构图像

 ----main-----

 每轮训练时间 : 91
0 loss: 0.02758789248764515

 每轮训练时间 : 95
1 loss: 0.024654878303408623

 每轮训练时间 : 149
2 loss: 0.018874473869800568

目录:

     1: AE 实现

     2: main 实现


一  ae(AutoEncoder) 实现

  文件名: ae.py

               模型的搭建

   注意点:

            手写数字数据集 提供了 标签y,但是AutoEncoder 网络不需要,

它的标签就是输入的x, 需要重构本身

自编码器(autoencoder, AE)是一类在半监督学习和非监督学习中使用的人工神经网络(Artificial Neural Networks, ANNs),其功能是通过将输入信息作为学习目标,对输入信息进行表征学习(representation learning) [1-2]  。

自编码器包含编码器(encoder)和解码器(decoder)两部分 [2]  。按学习范式,自编码器可以被分为收缩自编码器(contractive autoencoder)、正则自编码器(regularized autoencoder)和变分自编码器(Variational AutoEncoder, VAE),其中前两者是判别模型、后者是生成模型 [2]  。按构筑类型,自编码器可以是前馈结构或递归结构的神经网络。

自编码器具有一般意义上表征学习算法的功能,被应用于降维(dimensionality reduction)和异常值检测(anomaly detection) [2]  。包含卷积层构筑的自编码器可被应用于计算机视觉问题,包括图像降噪(image denoising) [3]  、神经风格迁移(neural style transfer)等 [4]  。

   

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:19:19 2023@author: chengxf2
"""import torch
from torch import nn#ae: AutoEncoderclass AE(nn.Module):def __init__(self,hidden_size=10):super(AE, self).__init__()self.encoder = nn.Sequential(nn.Linear(in_features=784, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=64),nn.ReLU(),nn.Linear(in_features=64, out_features=hidden_size),nn.ReLU())# hidden [batch_size, 10]self.decoder = nn.Sequential(nn.Linear(in_features=hidden_size, out_features=64),nn.ReLU(),nn.Linear(in_features=64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=784),nn.Sigmoid())def forward(self, x):'''param x:[batch, 1,28,28]return '''m= x.size(0)x = x.view(m, 784)hidden= self.encoder(x)x =   self.decoder(hidden)#reshapex = x.view(m,1,28,28)return x

二 main 实现

  文件名: main.py

  作用:

      加载数据集

     训练模型

     测试模型泛化能力

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:24:10 2023@author: chengxf2
"""import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import time
from torch import optim,nn
from ae import AE
import visdomdef main():batchNum = 32lr = 1e-3epochs = 20device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")torch.manual_seed(1234)viz = visdom.Visdom()viz.line([0],[-1],win='train_loss',opts =dict(title='train acc'))tf= transforms.Compose([ transforms.ToTensor()])mnist_train = datasets.MNIST('mnist',True,transform= tf,download=True)train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)mnist_test = datasets.MNIST('mnist',False,transform= tf,download=True)test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)global_step =0model =AE().to(device)criteon = nn.MSELoss().to(device) #损失函数optimizer = optim.Adam(model.parameters(),lr=lr) #梯度更新规则print("\n ----main-----")for epoch in range(epochs):start = time.perf_counter()for step ,(x,y) in enumerate(train_data):#[b,1,28,28]x = x.to(device)x_hat = model(x)loss = criteon(x_hat, x)#backpropoptimizer.zero_grad()loss.backward()optimizer.step()viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')global_step +=1end = time.perf_counter()    interval = end - startprint("\n 每轮训练时间 :",int(interval))print(epoch, 'loss:',loss.item())x,target = iter(test_data).next()x = x.to(device)with torch.no_grad():x_hat = model(x)tip = 'hat'+str(epoch)viz.images(x,nrow=8, win='x',opts=dict(title='x'))viz.images(x_hat,nrow=8, win='x_hat',opts=dict(title=tip))if __name__ == '__main__':main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/117386.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

bazel远程缓存(Remote Cache)

原理 您可以将服务器设置为构建输出(即这些操作输出)的远程缓存。这些输出由输出文件名列表及其内容的哈希值组成。借助远程缓存,您可以重复使用其他用户的 build 中的构建输出,而不是在本地构建每个新输出。 增量构建极大的提升…

开启EMQX的SSL模式及SSL证书生成流程

生成证书 首先:需要安装Openssl 以下是openssl命令 生成CA证书 1.openssl genrsa -out rootCA.key 2048 2.openssl req -x509 -new -nodes -key rootCA.key -sha256 -days 3650 -subj "/CCN/STShandong/Ljinan/Oyunding/OUplatform/CNrootCA" -out ro…

unity 发布apk,在应用内下载安装apk(用于更大大版本)

*注意事项: 1,andriod 7.0 和 android 8.0是安卓系统的分水岭,需要分开来去实现相关内容2,注意自己的包名,在设置一些共享文件的时候需要放自己的包名3,以下是直接用arr包放入unity中直接使用的,不需要导入…

尚硅谷SpringMVC (5-8)

五、域对象共享数据 1、使用ServletAPI向request域对象共享数据 首页&#xff1a; Controller public class TestController {RequestMapping("/")public String index(){return "index";} } <!DOCTYPE html> <html lang"en" xmln…

安达发|APS软件排程规则及异常处理方案详解

随着科技的发展&#xff0c;工业生产逐渐向智能化、自动化方向发展。APS(高级计划与排程)软件作为一种集成了先进技术和理念的工业软件&#xff0c;可以帮助企业实现生产过程的优化和控制。其中&#xff0c;排程规则是APS软件的核心功能之一&#xff0c;它可以帮助企业合理安排…

时序预测 | MATLAB实现CNN-BiGRU卷积双向门控循环单元时间序列预测

时序预测 | MATLAB实现CNN-BiGRU卷积双向门控循环单元时间序列预测 目录 时序预测 | MATLAB实现CNN-BiGRU卷积双向门控循环单元时间序列预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.MATLAB实现CNN-BiGRU卷积双向门控循环单元时间序列预测&#xff1b; 2.运行环境…

Docker技术--Docker简介和架构

1.Docker简介 (1).引入 我们之前学习了EXSI&#xff0c;对于虚拟化技术有所了解&#xff0c;但是我们发现类似于EXSI这样比较传统的虚拟化技术是存在着一定的缺陷:所占用的资源比较多&#xff0c;简单的说&#xff0c;就是你需要给每一个用户提供一个操作平台&#xff0c;这一个…

一款windows的终端神奇,类似mac的iTem2

终于找到了一款windows的终端神奇。类似mac的iTem2 来&#xff0c;上神器 cmder cmder是一款windows的命令行工具&#xff0c;就是我们的linux的终端&#xff0c;用起来和linux的命令一样。所以我们今天要做的是安装并配置cmder ![在这里插入图片描述](https://img-blog.csdni…

Lnmp架构-Redis

网站&#xff1a;www.redis.cn redis 部署 make的时候需要gcc和make 如果在纯净的环境下需要执行此命令 [rootserver3 redis-6.2.4]# yum install make gcc -y 注释一下这几行 vim /etc/redis/6739.conf 2.Redis主从复制 设置 11 是master 12 13 是slave 在12 上 其他节…

python 深度学习 解决遇到的报错问题4

目录 一、DLL load failed while importing _imaging: 找不到指定的模块 二、Cartopy安装失败 三、simplejson.errors.JSONDecodeError: Expecting value: line 1 column 1 (char 0) 四、raise IndexError("single positional indexer is out-of-bounds") 五、T…

QT Creator工具介绍及使用

一、QT的基本概念 QT主要用于图形化界面的开发&#xff0c; QT是基于C编写的一套界面相关的类库&#xff0c;如进程线程库&#xff0c;网络编程的库&#xff0c;数据库操作的库&#xff0c;文件操作的库等。 如何使用这个类库&#xff1a;类库实例化对象(构造函数) --> 学习…

简单了解ICMP协议

目录 一、什么是ICMP协议&#xff1f; 二、ICMP如何工作&#xff1f; 三、ICMP报文格式 四、ICMP的作用 五、ICMP的典型应用 5.1 Ping程序 5.2 Tracert(Traceroute)路径追踪程序 一、什么是ICMP协议&#xff1f; ICMP因特网控制报文协议是一个差错报告机制&#xff0c;…

【【萌新的STM32学习-18 中断的基本概念3】】

萌新的STM32学习-18 中断的基本概念3 EXTI和IO映射的关系 AFIO简介&#xff08;F1&#xff09; Alternate Function IO 复用功能IO 主要用于重映射和外部中断映射配置 1.调试IO配置 来自AFIO_MAPR[26:24] , 配置JTAG/SWD的开关状态 &#xff08;这个我们并不用太过深刻的关注&…

腾讯云网站备案详细流程_审核时间说明

腾讯云网站备案流程先填写基础信息、主体信息和网站信息&#xff0c;然后提交备案后等待腾讯云初审&#xff0c;初审通过后进行短信核验&#xff0c;最后等待各省管局审核&#xff0c;前面腾讯云初审时间1到2天左右&#xff0c;最长时间是等待管局审核时间&#xff0c;网站备案…

【易售小程序项目】修改“我的”界面前端实现;查看、重新编辑、下架自己发布的商品【后端基于若依管理系统开发】

文章目录 “我的”界面修改效果界面实现界面整体代码 查看已发布商品界面效果商品数据表后端上架、下架商品ControllerMapper 界面整体代码back方法 编辑商品、商品发布、保存草稿后端商品校验方法Controller 页面整体代码 “我的”界面修改 效果 界面实现 界面的实现使用了一…

DevOps理念:开发与运维的融合

在现代软件开发领域&#xff0c;DevOps 不仅仅是一个流行的词汇&#xff0c;更是一种文化、一种哲学和一种方法论。DevOps 的核心理念是通过开发和运维之间的紧密合作&#xff0c;实现快速交付、高质量和持续创新。本文将深入探讨 DevOps 文化的重要性、原则以及如何在团队中实…

Python入门教程 - 基本语法 (一)

目录 一、注释 二、Python的六种数据类型 三、字符串、数字 控制台输出练习 四、变量及基本运算 五、type()语句查看数据的类型 六、字符串的3种不同定义方式 七、数据类型之间的转换 八、标识符命名规则规范 九、算数运算符 十、赋值运算符 十一、字符串扩展 11.1…

Ubuntu20.04下安装搜狗输入法Linux版

Ubuntu20.04下安装搜狗输入法Linux版 参考搜狗输入法的官网安装指南&#xff1b; 第一步&#xff1a;打开搜狗输入法官网&#xff1b; https://shurufa.sogou.com/ 点击X86_64后将会自动跳转到搜狗输入法的安装指南中&#xff1b; 安装指南 Ubuntu搜狗输入法安装指南 搜狗…

iOS开发Swift-7-得分,问题序号,约束对象,提示框,类方法与静态方法-趣味问答App

1.根据用户回答计算得分 ViewController.swift: import UIKitclass ViewController: UIViewController {var questionIndex 0var score 0IBOutlet weak var questionLabel: UILabel!IBOutlet weak var scoreLabel: UILabel!override func viewDidLoad() {super.viewDidLoad()…

如何停止一个正在运行的线程

使用共享变量的方式 在这种方式中&#xff0c;之所以引入共享变量&#xff0c;是因为该变量可以被多个执行相同任务的线程用来作为是否中断的信号&#xff0c;通知中断线程的执行。 使用interrupt方法终止线程 如果一个线程由于等待某些事件的发生而被阻塞&#xff0c;又该怎样…