23/76-LeNet

LeNet
早期成功的神经网络。
先使用卷积层来学习图片空间信息。
然后使用全连接层转换到类别空间。

在这里插入图片描述

#In[]
'''
LeNet,上世纪80年代的产物,最初为了手写识别设计
'''
from d2l import torch as d2l
import torch 
from torch import nn
from torch.nn.modules.loss import CrossEntropyLossfrom torch.utils import data
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import Common_functions'''
LeNet:
两个卷积层,两个池化层,三个线性层
假定为MNIST设计,输入为(batch_size,1,28,28)
'''class Reshape(torch.nn.Module):def forward(self,x):return x.view(-1,1,28,28)net = nn.Sequential(nn.Conv2d(in_channels=1,out_channels=6,kernel_size=(5,5),padding=2),nn.Sigmoid(), #输出:(6,28,28)nn.AvgPool2d(kernel_size=(2,2)), #不指定stride默认不重叠 输出(6,14,14)nn.Conv2d(6,16,kernel_size=(5,5)),nn.Sigmoid(),#输出(16,10,10)nn.AvgPool2d(kernel_size=(2,2)),#输出(16,5,5)nn.Flatten(),nn.Linear(16*5*5,120),nn.Sigmoid(),#nn.Linear(120,84),nn.Sigmoid(),nn.Linear(84,10)
)X=torch.rand(size=(1,1,28,28),dtype=torch.float32)
for layer in net:X=layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)#In[]batch_size = 256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size=batch_size)#对evaluate_accuracy函数进行轻微修改
#使用GPU计算模型在数据集上的精度
#计算网络在测试数据集上面的准确率
#由于完整的测试数据集位于内存中,因此在模型使用GPU预测测试数据集之前,我们需要将其复制到显存中。
def evaluate_accuracy_gpu(net,data_iter,device=None):if isinstance(net,nn.Module):net.eval() #网络用于测试数据if not device:device = next(iter(net.parameters())).device #如果没有指定device设备,device设备则使用第一层网络参数的设备accumulator = d2l.Accumulator(2) #累加器里面包含两个元素for X,y in data_iter:if isinstance(X,list):X = [x.to(device) for x in X] #X为list类型时,需要加X里面每个元素都复制到device设备上面来else:X = X.to(device)y = y.to(device)accumulator.add(d2l.accuracy(net(X),y),y.numel()) #累加器第一个元素为在每一个batch_size中预测准确的个数,第二个元素为每一个batch_size中样本总数目,然后依次循环累加,得到测试数据集上面预测准确的总数目,以及数据集总数目return accumulator[0]/accumulator[1] #算出模型预测准确率def train_ch6(net,train_iter,test_iter,num_epochs,lr,device):def init_weights(m):#手动初始化模型参数if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight) #使用xavier_uniform分布初始化参数net.apply(init_weights)net.to(device)#将模型复制到gpu上面print('training on',device)loss = nn.CrossEntropyLoss() #定义lossoptim = torch.optim.SGD(net.parameters(),lr=lr) #定义优化器animator = d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],legend=['train_loss','train_acc','test_acc'])timer = d2l.Timer()num_batches = len(train_iter)for epoch in range(num_epochs):net.train()#模型开始训练,需要放在第一层循环里面,因为后面evaluate_accuracy_gpu()函数里面有net.eval(),将模型改变为测试状态,因此需要在每一个循环epoch后面手动再加上模型开始处于训练状态accumulator = d2l.Accumulator(3) #累加器for i,(X,y) in enumerate(train_iter):timer.start()optim.zero_grad()X = X.to(device)#将X复制到gpu上面y = y.to(device) #将y复制到gpu上面y_hat = net(X) #得到模型训练后的输出标签y_hatl = loss(y_hat,y)#计算每一个batch_size的lossl.backward() #计算梯度optim.step() #使用优化器更新模型参数with torch.no_grad():#不需要模型梯度accumulator.add(l*X.shape[0],d2l.accuracy(y_hat,y),X.shape[0])timer.stop()train_loss = accumulator[0]/accumulator[2] #从累加器里面获得所有训练集的loss之和train_acc = accumulator[1]/accumulator[2] #从累加器里面获得所有训练集的准确数之和if (i+1) % (num_batches // 5) == 0 or i == num_batches-1:animator.add(epoch+(i+1)/num_batches,(train_loss,train_acc,None))test_accuracy = evaluate_accuracy_gpu(net,test_iter) #每次训练完一个epoch后的模型用于测试数据集上面计算测试精确度animator.add(epoch+1,(None,None,test_accuracy))print(f'模型训练完最后一轮时 train_loss:{train_loss},train_acc:{train_acc},test_acc:{test_accuracy}')print(f'{num_epochs*accumulator[2]/timer.sum()}examples/second on {str(device)}')#打印出模型每秒能处理多少个样本数lr,num_epochs= 0.9,10
train_ch6(net,train_iter=train_iter,test_iter=test_iter,lr=lr,num_epochs=num_epochs,device=d2l.try_gpu())
'''
输出结果:
模型训练完最后一轮时 train_loss:0.4322478462855021,train_acc:0.8396666666666667,test_acc:0.8163
55954.65804440994examples/second on cuda:0
'''#训练
if torch.cuda.is_available():device = "cuda:0"
else:device = "cpu"
device = torch.device(device)Common_functions.train_device(net,train_iter,test_iter,lr=0.9,device=device)
# %%plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/240031.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue实战:两种方式创建Vue项目

文章目录 一、实战概述二、实战步骤(一)安装Vue CLI脚手架1、从Node.js官网下载LTS版本2、安装Node.js到指定目录3、配置Node.js环境变量4、查看node版本5、查看npm版本6、安装Vue Cli脚手架7、查看Vue Cli版本 (二)命令行方式构建…

Python多线程爬虫——数据分析项目实战详解

前言 「作者主页」:雪碧有白泡泡 「个人网站」:雪碧的个人网站 ChatGPT体验地址 文章目录 前言爬虫获取cookie网站爬取与启动CSDN爬虫爬虫启动将爬取内容存到文件中 多线程爬虫选择要爬取的用户 线程池 爬虫 爬虫是指一种自动化程序,能够模…

Jmeter 测试脚本录制器-HTTP 代理服务器

Jmeter 测试脚本录制器-HTTP 代理服务器 Jmeter 配置代理服务器代理服务器获取请求地址示例图配置步骤 浏览器配置代理Google 浏览器插件配置代理windows 本地网络配置代理 启动录制,生成证书生成证书导入证书Jmeter 配置证书 浏览器点击页面,录制请求地…

【深度学习】RTX2060 2080如何安装CUDA,如何使用onnx runtime

文章目录 如何在Python环境下配置RTX 2060与CUDA 101. 安装最新的NVIDIA显卡驱动2. 使用conda安装CUDA Toolkit3. 验证onnxruntime与CUDA版本4. 验证ONNX需求版本5. 安装ONNX与onnxruntime6. 编写ONNX推理代码 如何在Python环境下配置RTX 2060与CUDA 10 RTX 2060虽然是一款较早…

AI嵌入式K210项目(11)-SPI Flash读写

文章目录 前言一、K210的SPI二、Flash介绍三、实验过程总结 前言 这一章我们来学习下SPI及其应用,SPI 是一种高速的,全双工,同步的通信总线,由于其高速、同步和简单的特性,被广泛应用于各种微控制器和外围设备之间的通…

微信小程序防止截屏录屏

一、使用css添加水印 使用微信小程序原生的view和css给屏幕添加水印这样可以防止用户将小程序内的隐私数据进行截图或者录屏分享导致信息泄露,给小程序添加一个水印浮层。这样即使被截图或者拍照,也能轻松地确定泄露的源头。效果图如下: 代码…

MongoDB认证考试小题库

Free MongoDB C100DBA Exam Actual Questions 关于MongoDB C100 DBA 考试真题知识点零散整理 分片架构 应用程序 --> mongos --> 多个mongod对于应用来说,连接分片集群跟连接一台单机mongod服务器一样分片好处, 增加可用RAM、增加可用磁盘空间、…

初识 Elasticsearch 应用知识,一文读懂 Elasticsearch 知识文集(3)

🏆作者简介,普修罗双战士,一直追求不断学习和成长,在技术的道路上持续探索和实践。 🏆多年互联网行业从业经验,历任核心研发工程师,项目技术负责人。 🎉欢迎 👍点赞✍评论…

F-44 显示字段调整补充

F-44 显示字段调整补充 网上有段资料清账格式的设置与账号相关,通过此次设置后,下次F-51付款清账时,系统默认按此格式显示。如果在格式设置中找不到适合的格式,用户可以自定义格式,通过事务代码O7Z4S配置行格式&#…

AI对决:ChatGPT与文心一言的比较

. 个人主页:晓风飞 专栏:数据结构|Linux|C语言 路漫漫其修远兮,吾将上下而求索 文章目录 引言ChatGPT与文心一言的比较Chatgpt的看法文心一言的看法Copilot的观点chatgpt4.0的回答 模型的自我评价自我评价 ChatGPT的优势在这里插入图片描述 文…

利用c 原生头文件完成JPEG全流程编码

骄傲一下,经过一个多月的努力,终于完成jpeg的全套编码。经验证此程序可以把摄像头yuv信号转为JPG图片。现在的程序还不完美,只能对长和宽尺寸是16倍数的信号转码。而且转码速度太慢,一帧1280720的图片要2秒多。此程序只能对yuv420…

Java生成四位数随机验证码

引言: 我们生活中登录的时候都要输入验证码,这些验证码是为了增加注册或者登录难度,减少被人用脚本疯狂登录注册导致的一系列危害,减少数据库的一些压力。 毕竟那些用脚本生成的账号都是垃圾账号 本次实践:生成这样的…

Docker Consul详解与部署示例

目录 Consul构成 Docker Consul 概述 Raft算法 服务注册与发现 健康检查 Key/Value存储 多数据中心 部署模式 consul-template守护进程 registrator容器 consul服务部署(192.168.41.31) 环境准备 搭建Consul服务 查看集群信息 registrato…

开源协议概览

身为程序员,我们不可避免的要和开源项目打交道,不管是我们自己做了些开源项目,还是使用开源项目,对各种开源协议的了解是必要的。 OSI(Open Source Initiative) OSI,开发源代码组织,是一个旨在推动开源软件…

【河海大学论文LaTeX+VSCode全指南】

河海大学论文LaTeXVSCode全指南 前言一、 LaTeX \LaTeX{} LATE​X的安装二、VScode的安装三、VScode的配置四、验证五、优化 前言 LaTeX \LaTeX{} LATE​X在论文写作方面具有传统Word无法比拟的优点,VScode作为一个轻量化的全功能文本编辑器,由于其极强的…

ZYNQ 7020 PL feature 解读

1. 组成 CLB, RAM, DSP, IO block,XADC, PCI-E, etc 2. CLK Each device in the Zynq-7000 family has up to 8 clock management tiles (CMTs), each consisting of one mixed-mode clock manager (MMCM) and one phase-locked loop (PLL). See Table 5. 2.1, Clock Distri…

记录下载安装rabbitmq(Linux) 并整合springboot--详细版(全)

下载rabbitmq(Linux): erlang压缩包: https://share.weiyun.com/TGhfV8eZ rabbitMq-server压缩包: https://share.weiyun.com/ZXbUwWHD (因为RabbitMQ采用 Erlang 实现的工业级的消息队列(MQ)服务器&#…

网络安全技术新手入门:利用永恒之蓝获取靶机控制权限

目录 前言 一、搜索永恒之蓝可用模块 二、使用攻击模块 三、配置攻击模块 四、攻击 五、总结 前言 相关法律声明:《中华人民共和国网络安全法》第二十七条 任何个人和组织不得从事非法侵入他人网络、干扰他人网络正常功能、窃取网络数据等危害网络安全的活动&…

element-ui表单验证同时用change与blur一起验证

项目场景: 提示:这里简述项目相关背景: 当审批时不通过审批意见要必须输入, 1:如果用change验证的话删除所有内容时报错是massage的提示,但是在失去焦点的时候报错就成了英文,如下图&#xf…

Qt 国产嵌入式操作系统实现文字转语音功能(TTS)

1.简介 本示例使用的CPU:rk3588。 操作系统:kylin V10 架构:aarch64 在Windows端,我们很容易想到使用Qt自带的类QTextToSpeech来实现文字转语音功能,Qt版本得在5.11.0以上才支持。但是在嵌入式平台,尤其…