Pytorch-CNN-Mnist

文章目录

  • model.py
  • main.py
  • 网络设置
  • 注意事项及改进
  • 运行截图

model.py

import torch.nn as nn
class CNN_cls(nn.Module):def __init__(self,in_dim=28*28):super(CNN_cls,self).__init__()self.conv1 = nn.Conv2d(1,32,1,1)self.pool1 = nn.MaxPool2d(2,2)self.conv2 = nn.Conv2d(32,64,1,1)self.pool2 = nn.MaxPool2d(2,2)self.conv3 = nn.Conv2d(64,128,1,1)self.lin1 = nn.Linear(128*7*7,512)self.lin2 = nn.Linear(512,64)self.lin3 = nn.Linear(64,10)self.relu = nn.ReLU()def forward(self,x):x = self.conv1(x)x = self.relu(x)x = self.pool1(x)x = self.conv2(x)x = self.relu(x)x = self.pool2(x)x = self.conv3(x)x = self.relu(x)x = x.view(-1,128*7*7)x = self.lin1(x)x = self.relu(x)x = self.lin2(x)x = self.relu(x)x = self.lin3(x)return x

main.py

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
from model import CNN_clsseed = 42
torch.manual_seed(seed)
batch_size_train = 64
batch_size_test  = 64
epochs = 10
learning_rate = 0.01
momentum = 0.5
net = CNN_cls()train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./data/', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.5,), (0.5,))])),batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.5,), (0.5,))])),batch_size=batch_size_test, shuffle=True)optimizer = optim.SGD(net.parameters(), lr=learning_rate,momentum=momentum)
criterion = nn.CrossEntropyLoss()print("****************Begin Training****************")
net.train()
for epoch in range(epochs):run_loss = 0correct_num = 0for batch_idx, (data, target) in enumerate(train_loader):out = net(data)_,pred = torch.max(out,dim=1)optimizer.zero_grad()loss = criterion(out,target)loss.backward()run_loss += lossoptimizer.step()correct_num  += torch.sum(pred==target)print('epoch',epoch,'loss {:.2f}'.format(run_loss.item()/len(train_loader)),'accuracy {:.2f}'.format(correct_num.item()/(len(train_loader)*batch_size_train)))print("****************Begin Testing****************")
net.eval()
test_loss = 0
test_correct_num = 0
for batch_idx, (data, target) in enumerate(test_loader):out = net(data)_,pred = torch.max(out,dim=1)test_loss += criterion(out,target)test_correct_num  += torch.sum(pred==target)
print('loss {:.2f}'.format(test_loss.item()/len(test_loader)),'accuracy {:.2f}'.format(test_correct_num.item()/(len(test_loader)*batch_size_test)))

网络设置

在CNN_cls里面查看。

注意事项及改进

1.注意第一个输入通道是1,因为是灰度图像。
2.可以考虑加入GPU

运行截图

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/135134.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Web 第一步:HTTP 协议(基础)

这里是JavaWeb的开头部分!那么先解释一下吧: Web:全球广域网,也称为万维网(www),能够通过浏览器访问的网站。 JavaWeb:是用Java技术来解决相关 Web 互联网领域的技术栈。 &#xf…

vite和webpack的区别

vite和webpack的区别 1、前言2、Webpack2.1 Webpack简述2.2 Webpack常用插件 3、Vite3.1 Vite简述3.2 Vite插件推荐 4、区别4.1 开发模式不同4.2 打包效率不同4.3 插件生态不同4.4 配置复杂度不同4.5 热更新机制不同 5、总结 1、前言 Webpack和Vite是现代前端开发中非常重要的…

深度解析shell脚本的命令的原理之ls

ls是一个常用的Unix和Linux命令,它的功能是列出目录内容。当运行ls命令时,操作系统会执行一系列步骤,以获取和显示指定目录中的文件和子目录。以下是对这个命令的深度解析: 解析参数和选项:首先,ls命令会解…

Linux:centos9的本地yum仓库配置

其实9和7的配置方法是差不多一样的,只不过你使用7的本地yum仓库里面直接挂载就可以直接把仓库位置指向挂载点 具体可以看我往期文章,但是先看完我下面的描述再去看我链接的文章才能看懂如何配置centos9的yum仓库 Linux:YUM仓库服务_鲍海超-…

Android Media3 ExoPlayer 开启缓存功能

ExoPlayer 开启播放缓存功能,在下次加载已经播放过的网络资源的时候,可以直接从本地缓存加载,实现为用户节省流量和提升加载效率的作用。 方法一:采用 ExoPlayer 缓存策略 第 1 步:实现 Exoplayer 参考 Exoplayer 官…

利用Python将dataframe格式的所有列的数据类型转换为分类数据类型

一、样例理解 import pandas as pd import numpy as np# 创建测试数据 feature_names [col1 , col2, col3, col4, col5, col6] values np.random.randint(20, size(10,6))dataset pd.DataFrame(data values, columns feature_names)print("转换前的数据为\n",d…

NetSuite知识会汇编-管理员篇顾问篇2023

本月初,开学之际,我们发布了《NetSuite知识会汇编-用户篇 2023》,这次发布《NetSuite知识会汇编-管理员篇&顾问篇2023》。本篇挑选了近两年NetSuite知识会中的一些文章,涉及开发、权限、系统管理等较深的内容,共19…

Python 基于PyCharm断点调试

视频版教程 Python3零基础7天入门实战视频教程 PyCharm Debug(断点调试)可以帮助开发者在代码运行时进行实时的调试和错误排查,提高代码开发效率和代码质量。 准备一段代码 def add(num1, num2):return num1 num2if __name__ __main__:f…

使用stelnet进行安全的远程管理

1. telnet有哪些不足? 2.ssh如何保证数据传输安全? 需求:远程telnet管理设备 用户定义需要在AAA模式下: 开启远程登录的服务:定义vty接口 然后从R2登录:是可以登录的 同理R3登录: 在R1也可以查…

全国职业技能大赛云计算--高职组赛题卷①(容器云)

全国职业技能大赛云计算--高职组赛题卷①(容器云) 第二场次题目:容器云平台部署与运维任务1 Docker CE及私有仓库安装任务(5分)任务2 基于容器的web应用系统部署任务(15分)任务3 基于容器的持续…

驱动开发,IO模型,信号驱动IO实现过程

1.信号驱动IO框架图 分析: 信号驱动IO是一种异步IO方式。linux预留了一个信号SIGIO用于进行信号驱动IO。进程主程序注册一个SIGIO信号的信号处理函数,当硬件数据准备就绪后会发起一个硬件中断,在中断的处理函数中向当前进程发送一个SIGIO信号…

opencv形状目标检测

1.圆形检测 OpenCV图像处理中“找圆技术”的使用-图像处理-双翌视觉OpenCV图像处理中“找圆技术”的使用,图像处理,双翌视觉https://www.shuangyi-tech.com/news_224.htmlopencv 找圆心得,模板匹配比霍夫圆心好用 - 知乎1 相比较霍夫找直线算法, 霍夫找…

TCP详解之流量控制

TCP详解之流量控制 发送方不能无脑的发数据给接收方,要考虑接收方处理能力。 如果一直无脑的发数据给对方,但对方处理不过来,那么就会导致触发重发机制,从而导致网络流量的无端的浪费。 为了解决这种现象发生,TCP 提…

大语言模型之十-Byte Pair Encoding

Tokenizer 诸如GPT-3/4以及LlaMA/LlaMA2大语言模型都采用了token的作为模型的输入输出,其输入是文本,然后将文本转为token(正整数),然后从一串token(对应于文本)预测下一个token。 进入OpenAI官…

MFC 如何启用/禁用菜单(返灰/不可点击状态)

1、为页面(窗口)添加一个菜单栏和子菜单 2、在XXDlg.h文件中定义一个菜单栏变量和bool变量 CMenu m_Menu; //菜单变量 bool m_EnableMenu;//菜单栏中某个子菜单禁用/启用(变灰)的控制变量3、在OnInitDialog函数中进行初始化&…

常见的数码管中的引脚分布情况

简单介绍 数码管,实际就是用了7段亮的线段表示常见的数字或字符。常见的像下面几种(图片是网络中的截图)。事件中使用到的知识还是单片机中最基础的矩阵扫描。记得其中重要的有“余晖效应”,好像是要把不用的亮段关闭&#xff0c…

SpringBoot整合Flowable

1. 配置 &#xff08;1&#xff09; 引入maven依赖 <dependency><groupId>org.flowable</groupId><artifactId>flowable-spring-boot-starter</artifactId><version>6.7.2</version></dependency><!-- MySQL连接 -->&l…

【深度学习】 Python 和 NumPy 系列教程(十七):Matplotlib详解:2、3d绘图类型(3)3D条形图(3D Bar Plot)

目录 一、前言 二、实验环境 三、Matplotlib详解 1、2d绘图类型 2、3d绘图类型 0. 设置中文字体 1. 线框图 2. 3D散点图 3. 3D条形图&#xff08;3D Bar Plot&#xff09; 一、前言 Python是一种高级编程语言&#xff0c;由Guido van Rossum于1991年创建。它以简洁、易读…

左神高级提升班2 约瑟夫环结构

目录 【案例1】 【题目描述】 【输入描述&#xff1a;】 【输出描述&#xff1a;】 【输入】 【输出】 【思路解析】 【代码实现】 【案例1】 【题目描述】 某公司招聘&#xff0c;有n个人入围&#xff0c;HR在黑板上依次写下m个正整数A1、A2、……、Am&#xff0c;然后…

【重新定义matlab强大系列十三】直方图 bin 计数和分 bin 散点图

&#x1f517; 运行环境&#xff1a;Matlab &#x1f6a9; 撰写作者&#xff1a;左手の明天 &#x1f947; 精选专栏&#xff1a;《python》 &#x1f525; 推荐专栏&#xff1a;《算法研究》 #### 防伪水印——左手の明天 #### &#x1f497; 大家好&#x1f917;&#x1f91…