简单搭建卷积神经网络实现手写数字10分类

搭建卷积神经网络实现手写数字10分类

1.思路流程

1.导入minest数据集

2.对数据进行预处理

3.构建卷积神经网络模型

4.训练模型,评估模型

5.用模型进行训练预测

一.导入minest数据集

 

MNIST--->raw--->test-->(0,1,2...) 10个文件夹

MNIST--->raw--->train-->(0,1,2...) 10个文件夹

共60000张图片.可自己去网上下载

二.对数据进行预处理

----读取图片,将图片先转为张量。

img=cv2.imread(path)

----将图片进行归一化,即将像素值标准化到0-1之间

img_tensor=util.transforms_train(img)

----裁剪,翻转等,实现数据增强。

数据增强:通过对原始图像进行旋转、翻转等操作,可以增加数据的多样性。这有助于模型学习到更具泛化性的特征,减少对特定方向或位置的依赖,从而提高模型的鲁棒性和准确性

transforms_train=transforms.Compose([# transforms.CenterCrop(10),# transforms.PILToTensor(),transforms.ToTensor(),#归一化,转tensortransforms.Resize((28,28)),transforms.RandomVerticalFlip()
])

ps:为什么要归一化

  1. 消除量纲影响:不同图像的像素值范围可能差异很大。归一化可以将像素值范围统一到一个特定的区间,例如 [0, 1] 或 [-1, 1],消除不同图像之间因像素值范围差异带来的影响,使模型更关注图像的特征和结构,而不是像素值的绝对大小。

  2. 提高训练稳定性:有助于优化算法的收敛性和稳定性。如果像素值范围较大且分布不均匀,可能导致梯度计算不稳定,从而影响模型的训练效率和效果。

  3. 缓解过拟合:一定程度上可以减少数据中的噪声和异常值对模型的影响,降低模型对某些特定像素值的过度依赖,从而提高模型的泛化能力,减少过拟合的风险。

三.构建卷积神经网络模型

常见卷积神经网络(CNN),主要由卷积,池化,全连接组成。卷积核在输入图像上滑动,通过卷积运算提取局部特征。卷积核在整个图像上重复使用,大大减少了模型的参数数量,降低了计算复杂度,同时也增强了模型对平移不变性的鲁棒性。池化层对特征进行压缩,提取主要特征,减少噪声和冗余信息。

x=torch.randn(2,3,28,28)

用x表示初始图形的信息。为了简单理解,简单表述。其中

2--->两张图片

3--->图片的通道数是3个,即 RGB

28,28--->图片的宽高是28px 28px

采用以上的神经网络conv为卷积操作,maxpool为池化。Linear为全连接。relu为激活函数。

进入全连接层时需要将展平。torch.Size([2, 16, 5, 5])--->torch.Size([2, 400])

x=torch.flatten(x,1)

因为全连接是只进行的线性的变化。所以要把每张图片的维数参数降为1。

使用print(summary(net, x))可查看网络的层次结构。其中-1就表示自己算,是多少张图片就是多少

输入的的是x=torch.randn(2,3,28,28),最终输出的是(2,10)

四.训练模型,评估模型

需要初始化之前的数据和网络,然后选择合适的优化器和损失函数,学习率和加载图片的批次去训练模型。使用loss_avg和accurary来评估模型的性能。对于pytorch来说优化器可以实现自动梯度清0,自动更新参数。我们需要主要的是就实现其中的维度的转化。loss越小越接近真实值。其中计算精度的方法使用one-hot编码。其中0表示[0,0,0,0,0,0,0,0,0,0],1表示[0,1,0,0,0,0,0,0,0,0],2表示[0,0,1,0,0,0,0,0,0,0].。。。其他依次类推。我们把用网络得出的参数,类似[0.1,0.2,0.1,0.5,0,0,0,0,0,0](数据我随便写的),然后用Python的argmax去处最大值的索引与one-hot真实值的索引相比,如果相等就是正确的结果。

----本次实验使用的是MSE损失函数

----lr(学习率)设为0.01

----使用的优化器Adam ,其实其他优化器你也可以随便试试。

Adam 算法的主要优点包括:

  1. 自适应学习率:能够为每个参数自适应地调整学习率。

  2. 偏差校正:在初始阶段对梯度估计进行校正,加速初期的学习速率。

  3. 适应性强:在很多不同的模型和数据集上都表现出良好的性能。

  4. 实现简单,计算高效,对内存需求少。

使用tensorboard进行可视化

五.用模型进行训练预测

需要读取之前训练好的模型,然后用这个模型来实现预测一个自己手写的图片

    # 加载整个模型loaded_model = torch.load('whole_model.pth')
​# 保存模型参数torch.save(loaded_model.state_dict(),'model_params.pth')

代码附上:

dataset.py

import glob
import os.path
​
import cv2
import torch
import util
​
class DataAndLabel:def __init__(self,path='D:\\0MNIST\\raw',is_train=True):super().__init__()# 拼接路径#data里面是path,labelclas='train' if is_train==True else 'test'path=os.path.join(path,clas)paths=glob.glob(os.path.join(path,'*','*'))# print(paths)# print(path)self.data=[]for path in paths:label=int(path.split('\\')[-2])self.data.append((path,label))def __getitem__(self, idx):#返回一个tensor,one-hotpath,label =self.data[idx]img=cv2.imread(path)# cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)img_tensor=util.transforms_train(img)one_hot=torch.zeros(10)one_hot[label]=1return img_tensor,one_hotdef __len__(self):return len(self.data)
# if __name__ == '__main__':
#     data=DataAndLabel()
#     print(data[0])
#     print()

lenet5.py

import torch
import torch.nn as nn
from torchkeras import summary
class Net(nn.Module):def __init__(self):super().__init__()self.conv1=nn.Conv2d(3,6,5,1)self.maxpool1=nn.MaxPool2d(2)self.conv2=nn.Conv2d(6,16,3,1)self.maxpool2=nn.MaxPool2d(2)self.layer1=nn.Linear(16*5*5,10)self.layer2=nn.Linear(10,10)self.relu=nn.Softmax()def forward(self,x):x=self.conv1(x)x=self.relu(x)x=self.maxpool1(x)x=self.conv2(x)x=self.relu(x)x=self.maxpool2(x)# print(x.shape)x=torch.flatten(x,1)# print(x.shape)
​x=self.layer1(x)x=self.layer2(x)return x
if __name__ == '__main__':x=torch.randn(2,3,28,28)net=Net()out=net(x)# print(out.shape)# print(summary(net, x))

train_and_test

import torch
import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from lenet5 import Net
import torch.nn as nn
from dataset import DataAndLabel
class TrainAndTest(Dataset):def __init__(self):super().__init__()# self.writer=SummaryWriter("logs")net=Net()self.net=netself.loss=nn.MSELoss()self.opt = torch.optim.Adam(net.parameters(), lr=0.1)self.train_data=DataAndLabel(is_train=True)self.test_data=DataAndLabel(is_train=False)self.train_loader=DataLoader(self.train_data,batch_size=100,shuffle=False)self.test_loader=DataLoader(self.test_data,batch_size=100,shuffle=False)# 拿到数据,网络def train(self,epoch):loss_sum = 0accurary_sum = 0for img_tensor, label in tqdm.tqdm(self.train_loader, desc='train...', total=len(self.train_loader)):out = self.net(img_tensor)loss = self.loss(out, label)self.opt.zero_grad()loss.backward()self.opt.step()loss_sum += loss.item()accurary_sum += torch.mean(torch.eq(torch.argmax(label, dim=1), torch.argmax(out, dim=1)).to(torch.float32)).item()loss_avg = loss_sum / len(self.train_loader)accurary_avg = accurary_sum / len(self.train_loader)print(f'train---->loss_avg={round(loss_avg, 3)},accurary_avg={round(accurary_avg, 3)}')# self.writer.add_scalars('loss',{'loss_avg':loss_avg},epoch)def train1(self):sum_loss = 0sum_acc = 0for img_tensors, targets in tqdm.tqdm(self.train_loader, desc="train...", total=len(self.train_loader)):out = self.net(img_tensors)loss = self.loss(out, targets)self.opt.zero_grad()loss.backward()self.opt.step()sum_loss += loss.item()pred_cls = torch.argmax(out, dim=1)target_cls = torch.argmax(targets, dim=1)accuracy =torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))sum_acc += accuracy.item()avg_loss = sum_loss / len(self.train_loader)avg_acc = sum_acc / len(self.train_loader)print(f'train:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')
​
​def run(self):for epoch in range(10):self.train1()# self.test(epoch)
if __name__ == '__main__':tt=TrainAndTest()tt.run()

util.py

from torchvision import transforms
​
transforms_train=transforms.Compose([# transforms.CenterCrop(10),# transforms.PILToTensor(),transforms.ToTensor(),#归一化,转tensortransforms.Resize((28,28)),transforms.RandomVerticalFlip()
])
transforms_test=transforms.Compose([transforms.ToTensor(),  # 归一化,转tensortransforms.Resize((28, 28)),
])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/378904.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在pycharm 2023.2.1中运行由R语言编写的ipynb文件

在pycharm 2023.2.1中运行由R语言编写的ipynb文件 背景与目标: 项目中包含由R语言编写的ipynb文件,希望能在pycharm中运行该ipynb文件。 最终实现情况: 未能直接在pycharm中运行该ipynb文件,但是替代的实现方法有:…

百度网盘Android一二面凉经(2024)

百度网盘Android一二面凉经(2024) 笔者作为一名双非二本毕业7年老Android, 最近面试了不少公司, 目前已告一段落, 整理一下各家的面试问题, 打算陆续发布出来, 供有缘人参考。今天给大家带来的是《百度网盘Android一二面凉经(2024)》。 面试职位: 网盘主端研发组_Android高级研…

MyPostMan 迭代文档管理、自动化接口闭环测试工具(自动化测试篇)

MyPostMan 是一款类似 PostMan 的接口请求软件,按照 项目(微服务)、目录来管理我们的接口,基于迭代来管理我们的接口文档,文档可以导出和通过 url 实时分享,按照迭代编写自动化测试用例,在不同环…

CSS3实现提示工具的渐入渐出效果及CSS3动画简介

上一篇文章用CSS3实现了一个提示工具,本文介绍如何利用CSS3实现提示工具以渐入的方式呈现,以渐出的方式消失。 CSS3主要可以通过两个样式来实现动画效果:animation和transition。 其中,animation需要自己定义一组关键帧从而实现…

在 Navicat BI 创建自定义字段:类型更改字段

早在 Navicat 17 的预览版中,我们就已经介绍了一些新的商业智能(BI)功能,即图表互动和计算字段。需要说明的是,计算字段不是 Navicat BI 中唯一可用的自定义字段类型。事实上,有五种:类型改变、…

@google/model-viewer 导入 改纹理 (http-serve)

导入模型 改纹理 效果图 <template><div><h1>鞋模型</h1><model-viewerstyle"width: 300px; height: 300px"id"my-replace-people"src"/imgApi/Astronaut.glb"auto-rotatecamera-controls></model-viewer>&…

C++——函数模板和类模板

目录 一、函数模板 二、类模板 一、函数模板 当我们没有使用到模板的时候&#xff0c;我们如果要交换两个数据&#xff0c;那么我们就要根据交换的数据的类型&#xff0c;写出例如以下的函数&#xff1a; void Swap(int& a, int& b) {int tmp a;a b;b tmp; }void S…

HardeningMeter:一款针对二进制文件和系统安全强度的开源工具

关于HardeningMeter HardeningMeter是一款针对二进制文件和系统安全强度的开源工具&#xff0c;该工具基于纯Python开发&#xff0c;经过了开发人员的精心设计&#xff0c;可以帮助广大研究人员全面评估二进制文件和系统的安全强化程度。 功能特性 其强大的功能包括全面检查各…

appium2.0 执行脚本遇到的问题

遇到的问题&#xff1a; appium 上的日志信息&#xff1a; 配置信息 方法一 之前用1.0的时候 地址默认加的 /wd/hub 在appium2.0上&#xff0c; 服务器默认路径是 / 如果要用/wd/hub 需要通过启动服务时设置基本路径 appium --base-path/wd/hub 这样就能正常执行了 方法二…

react基础样式控制

行内样式 <div style{{width:500px, height:300px,background:#ccc,margin:200px auto}}>文本</div> class类名 注意&#xff1a;在react中使用class类名必须使用className 在外部src下新建index.css文件写入你的样式 .fontcolor{color:red } 在用到的页面引入…

C#学习-刘铁猛

文章目录 1.委托委托的具体使用-魔板方法回调方法【好莱坞方法】&#xff1a;通过委托类型的参数&#xff0c;传入主调方法的被调用方法&#xff0c;主调方法可以根据自己的逻辑决定调用这个方法还是不调用这个方法。【演员只用接听电话&#xff0c;如果通过&#xff0c;导演会…

STM32使用Wifi连接阿里云

目录 1 实现功能 2 器件 3 AT指令 4 阿里云配置 4.1 打开阿里云 4.2 创建产品 4.3 添加设备 5 STM32配置 5.1 基础参数 5.2 功能定义 6 STM32代码 本文主要是记述一下&#xff0c;如何使用阿里云物联网平台&#xff0c;创建一个简单的远程控制小灯示例。 完整工程&a…

vue、js截取视频任意一帧图片

html有本地上传替换部分&#xff0c;可以不看 原理&#xff1a;通过video标签对视频进行加载&#xff0c;随后使用canvas对截取的视频帧生成需要的图片 <template> <el-row :gutter"18" class"preview-video"><h4>视频预览<span&…

【概率论三】参数估计:点估计(矩估计、极大似然法)、区间估计

文章目录 一. 点估计1. 矩估计法2. 极大似然法2.1. 似然函数2.2. 极大似然估计法 3. 评价估计量的标准3.1. 无偏性3.2. 有效性3.3. 一致性 二. 区间估计1. 区间估计的概念2. 正态总体参数的区间估计 参数估计讲什么 由样本来确定未知参数参数估计分为点估计与区间估计 一. 点估…

[iOS]浅析isa指针

[iOS]浅析isa指针 文章目录 [iOS]浅析isa指针isa指针isa的结构isa的初始化注意事项 上一篇留的悬念不止分类的实现 还有isa指针到底是什么 它是怎么工作的 class方法又是怎么运作的 class_data_bits_t bits; // class_rw_t * plus custom rr/alloc flags 这里面的class又是何方…

新华三H3CNE网络工程师认证—VLAN使用场景与原理

通过华三的技术原理与VLAN配置来学习&#xff0c;首先介绍VLAN&#xff0c;然后介绍VLAN的基本原理&#xff0c;最后介绍VLAN的基本配置。 一、传统以太网问题 在传统网络中&#xff0c;交换机的数量足够多就会出现问题&#xff0c;广播域变得很大&#xff0c;分割广播域需要…

R语言优雅的把数据基线表(表一)导出到word

基线表&#xff08;Baseline Table&#xff09;是医学研究中常用的一种数据表格&#xff0c;用于在研究开始时呈现参与者的初始特征和状态。这些特征通常包括人口统计学数据、健康状况和疾病史、临床指标、实验室检测、生活方式、社会经济等。 本人在既往文章《scitb包1.6版本发…

C++客户端Qt开发——QT初识

二、QT初识 1.helloworld示例 ①图形化的方式&#xff0c;在界面上创建出一个控件&#xff0c;显示helloworld 右侧通过树形结构&#xff0c;就会显示出当前界面上有哪些控件 此时.ui文件已发生变化 qmake就会在编译项目的时候&#xff0c;基于这个内容&#xff0c;生成一段C…

35.UART(通用异步收发传输器)-RS232(2)

&#xff08;1&#xff09;RS232接收模块visio框图&#xff1a; &#xff08;2&#xff09;接收模块Verilog代码编写: /* 常见波特率&#xff1a; 4800、9600、14400、115200 在系统时钟为50MHz时&#xff0c;对应计数为&#xff1a; (1/4800) * 10^9 /20 -1 10416 …

链接追踪系列-10.mall-swarm微服务运行并整合elk-上一篇的番外

因为上一篇没对微服务代码很详细地说明&#xff0c;所以在此借花献佛&#xff0c;使用开源的微服务代码去说明如何去做链路追踪。 项目是开源项目&#xff0c;fork到github以及gitee中&#xff0c;然后拉取到本地 后端代码&#xff1a; https://gitee.com/jelex/mall-swarm.gi…