pytorch,用lenet5识别cifar10数据集(训练+测试+单张图片识别)

目录

LeNet-5

LeNet-5 结构

CIFAR-10

pytorch实现

lenet模型

训练模型

1.导入数据

2.训练模型

3.测试模型

测试单张图片

代码

运行结果


LeNet-5

LeNet-5 是由 Yann LeCun 等人在 1998 年提出的一种经典卷积神经网络(CNN)模型,主要用于手写数字识别任务。它在 MNIST 数据集上表现出色,并且是深度学习历史上的一个重要里程碑。

LeNet-5 结构

LeNet-5 的结构包括以下几个层次:

  1. 输入层: 32x32 的灰度图像。
  2. 卷积层 C1: 包含 6 个 5x5 的滤波器,输出尺寸为 28x28x6。
  3. 池化层 S2: 平均池化层,输出尺寸为 14x14x6。
  4. 卷积层 C3: 包含 16 个 5x5 的滤波器,输出尺寸为 10x10x16。
  5. 池化层 S4: 平均池化层,输出尺寸为 5x5x16。
  6. 卷积层 C5: 包含 120 个 5x5 的滤波器,输出尺寸为 1x1x120。
  7. 全连接层 F6: 包含 84 个神经元。
  8. 输出层: 包含 10 个神经元,对应于 10 个类别。

CIFAR-10

CIFAR-10 是一个常用的图像分类数据集,包含 10 个类别的 60,000 张 32x32 彩色图像。每个类别有 6,000 张图像,其中 50,000 张用于训练,10,000 张用于测试。

1. 标注数据量训练集:50000张图像测试集:10000张图像

2. 标注类别数据集共有10个类别。具体分类见图1。

3. 可视化

pytorch实现

lenet模型

  • 平均池化(Average Pooling):对池化窗口内所有像素的值取平均,适合保留图像的背景信息。
  • 最大池化(Max Pooling):对池化窗口内的最大值进行选择,适合提取显著特征并具有降噪效果。

在实际应用中,最大池化更常用,因为它通常能更好地保留重要特征并提高模型的性能。

import torch.nn as nn
import torch.nn.functional as funcclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, kernel_size=5)self.conv2 = nn.Conv2d(6, 16, kernel_size=5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = func.relu(self.conv1(x))x = func.max_pool2d(x, 2)x = func.relu(self.conv2(x))x = func.max_pool2d(x, 2)x = x.view(x.size(0), -1)x = func.relu(self.fc1(x))x = func.relu(self.fc2(x))x = self.fc3(x)return x

训练模型

1.导入数据

导入训练数据和测试数据

    def load_data(self):#transforms.RandomHorizontalFlip() 是 pytorch 中用来进行随机水平翻转的函数。它将以一定概率(默认为0.5)对输入的图像进行水平翻转,并返回翻转后的图像。这可以用于数据增强,使模型能够更好地泛化。train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])test_transform = transforms.Compose([transforms.ToTensor()])train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)self.train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=self.train_batch_size, shuffle=True)# shuffle=True 表示在每次迭代时,数据集都会被重新打乱。这可以防止模型在训练过程中过度拟合训练数据,并提高模型的泛化能力。test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)self.test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=self.test_batch_size, shuffle=False)

2.训练模型

    def train(self):print("train:")self.model.train()train_loss = 0train_correct = 0total = 0for batch_num, (data, target) in enumerate(self.train_loader):data, target = data.to(self.device), target.to(self.device)self.optimizer.zero_grad()output = self.model(data)loss = self.criterion(output, target)loss.backward()self.optimizer.step()train_loss += loss.item()prediction = torch.max(output, 1)  # second param "1" represents the dimension to be reducedtotal += target.size(0)# train_correct incremented by one if predicted righttrain_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())progress_bar(batch_num, len(self.train_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'% (train_loss / (batch_num + 1), 100. * train_correct / total, train_correct, total))return train_loss, train_correct / total

3.测试模型

    def test(self):print("test:")self.model.eval()test_loss = 0test_correct = 0total = 0with torch.no_grad():for batch_num, (data, target) in enumerate(self.test_loader):data, target = data.to(self.device), target.to(self.device)output = self.model(data)loss = self.criterion(output, target)test_loss += loss.item()prediction = torch.max(output, 1)total += target.size(0)test_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())progress_bar(batch_num, len(self.test_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'% (test_loss / (batch_num + 1), 100. * test_correct / total, test_correct, total))return test_loss, test_correct / total

测试单张图片

网上随便下载一个图片

然后使用图片编辑工具,把图片设置为32x32大小

通过导入模型,然后测试一下

代码

import torch
import cv2
import torch.nn.functional as F
#from model import Net  ##重要,虽然显示灰色(即在次代码中没用到),但若没有引入这个模型代码,加载模型时会找不到模型
from torch.autograd import Variable
from torchvision import datasets, transforms
import numpy as npclasses = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
if __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = torch.load('lenet.pth')  # 加载模型model = model.to(device)model.eval()  # 把模型转为test模式img = cv2.imread("bird1.png")  # 读取要预测的图片trans = transforms.Compose([transforms.ToTensor()])img = trans(img)img = img.to(device)img = img.unsqueeze(0)  # 图片扩展多一维,因为输入到保存的模型中是4维的[batch_size,通道,长,宽],而普通图片只有三维,[通道,长,宽]# 扩展后,为[1,1,28,28]output = model(img)prob = F.softmax(output,dim=1) #prob是10个分类的概率print(prob)value, predicted = torch.max(output.data, 1)print(predicted.item())print(value)pred_class = classes[predicted.item()]print(pred_class)

运行结果

tensor([[1.8428e-01, 1.3935e-06, 7.8295e-01, 8.5042e-04, 3.0219e-06, 1.6916e-04,5.8798e-06, 3.1647e-02, 1.7037e-08, 8.9128e-05]], device='cuda:0',grad_fn=<SoftmaxBackward0>)
2
tensor([4.0915], device='cuda:0')
bird

从结果看,效果还不错。记录一下

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/399582.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Solana真假繁荣调查:机器人横行占7成交易,Meme数据下滑严重

随着Solana链上数据全面赶超以太坊&#xff0c;关于Solana将超越以太坊的讨论逐渐升温。然而&#xff0c;与此同时&#xff0c;关于Solana真实繁荣度的争议也引发了广泛关注。本文将深入探讨Solana生态中机器人泛滥、交易量数据虚高、MEV问题、财务亏损问题&#xff0c;以及SOL…

从零搭建xxl-job(四):xxljob进行一些性能优化

之前的代码这部分并没有补充完毕&#xff0c;假如调度中心如果判断有定时任务要执行了&#xff0c;该怎么远程通知给执行定时任务的程序呢&#xff1f;当定时任务要把自己的信息发送给调度中心时&#xff0c;是通过一个RegistryParam对象发送的。该对象内部封装了定时任务相关的…

【C#】explicit、implicit与operator

字面解释 explicit&#xff1a;清楚明白的;易于理解的;(说话)清晰的&#xff0c;明确的;直言的;坦率的;直截了当的;不隐晦的;不含糊的。 implicit&#xff1a;含蓄的;不直接言明的;成为一部分的;内含的;完全的;无疑问的。 operator&#xff1a;操作人员;技工;电话员;接线员;…

HarmonyOS应用开发者高级认证(一)

1、依次点击A、B、C、D四个按钮&#xff0c;其中不会触发UI刷新的是&#xff1a; 答案&#xff1a; Button("C").onClick(() > {this.nameList[0].name "Jim"})分析&#xff1a;直接更新非一级数据不会触发UI刷新 2、如果要实现Row组件内的子元素均匀…

基于JSP的个性化影片推荐系统

你好呀&#xff0c;我是计算机学姐码农小野&#xff01;如果有相关需求&#xff0c;可以私信联系我。 开发语言&#xff1a;JSP 数据库&#xff1a;MySQL 技术&#xff1a;JSP技术 工具&#xff1a;MyEclipse、Tomcat、MySQL 系统展示 首页 管理员功能模块 用户功能模块 …

Rancher的RKE和RKE2部署K8s集群kube-proxy开启strictARP

kube-proxy配置strictARPtrue 1、非RKE部署的K8s集群&#xff1a;配置首先&#xff0c;需要为kube-proxy启动strictARP&#xff0c;以便Kubernetes集群中的所有网卡停止响应其他网卡的ARP请求&#xff0c;而由OpenELB来处理ARP请求。 $ kubectl edit configmap kube-proxy -n…

C# 在Word中插入或删除分节符

在Word中&#xff0c;分节符是一种强大的工具&#xff0c;用于将文档分成不同的部分&#xff0c;每个部分可以有独立的页面设置&#xff0c;如页边距、纸张方向、页眉和页脚等。正确使用分节符可以极大地提升文档的组织性和专业性&#xff0c;特别是在长文档中&#xff0c;需要…

【STM32】USART通用同步/异步收发器(串口数据的接收与发送)

本篇博客重点在于标准库函数的理解与使用&#xff0c;搭建一个框架便于快速开发 目录 USART简介 USART时钟使能 USART初始化 串口参数 串口数据时序 USART中断配置 USART使能 数据的接收与发送 Serial.h Serial.c main.c USART简介 USART&#xff08;Universal S…

leedCode - - - 栈和队列

目录 1.有效的括号&#xff08; LeetCode 20 &#xff09; 2.最小栈&#xff08; LeetCode 155 &#xff09; 3.接雨水&#xff08; LeetCode 42 &#xff09; 4.逆波兰表达式求值&#xff08;LeetCode 150&#xff09; 5.柱状图中最大的矩形&#xff08;LeetCode 84&…

计算机毕业设计选题推荐-大学生就业招聘管理系统-Java/Python项目实战

✨作者主页&#xff1a;IT毕设梦工厂✨ 个人简介&#xff1a;曾从事计算机专业培训教学&#xff0c;擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…

【Redis】Redis 初探:特性、应用场景与高并发架构演进之路

目录 初识 Redis关于 Redis服务端高并发分布式结构演进之路概述常⻅概念基本概念应⽤&#xff08;Application&#xff09;/ 系统&#xff08;System&#xff09;模块&#xff08;Module&#xff09;/ 组件&#xff08;Component&#xff09;分布式&#xff08;Distributed&…

SSM养老院信息管理系统—计算机毕业设计源码16963

目 录 摘要 1 绪论 1.1研究意义 1.2开发意义 1.3ssm框架介绍 1.4论文结构与章节安排 2 养老院信息管理系统系统分析 2.1 可行性分析 2.1.1 技术可行性分析 2.1.2 经济可行性分析 2.1.3 法律可行性分析 2.2 系统功能分析 2.2.1 功能性分析 2.2.2 非功能性分析 2.…

C++ STL初阶(9):list 中关于reverse_iterator的实现

在完成vector和list的iterator相关部分的实践后来完成反向迭代器的实现 1. list的反向迭代器 书接上回&#xff0c;反向迭代器应当重新封装一个类。 反向迭代器和正向迭代器最大的区别就是&#xff0c;反向迭代器是倒着走的&#xff0c;所以最核心的逻辑就是将封装成-- 注意&am…

Cadence Allegro 入门教程笔记:如何绘制原理图和原理图库?

文章目录 一、用 Capture CIS 17.4 绘制原理图库 Cadence Allegro QQ交流学习裙&#xff1a;173416628 1、凡亿教育的Cadence Allegro 17.4基础教程 2、小哥Cadence Allegro 132讲 技巧视频 3、小哥Cadence Allegro 两层板 基础视频 4、小哥Cadence Allegro 四层板 提高视频…

【NLP】文本处理的基本方法【jieba分词、命名实体、词性标注】

文章目录 1、本章目标2、什么是分词3、jieba的使用3.1、精确模式分词3.2、全模式分词3.3、搜索引擎模式分词3.4、中文繁体分词3.5、使用用户自定义词典 4、什么是命名实体识别5、什么是词性标注6、小结7、jieba词性对照表⭐ &#x1f343;作者介绍&#xff1a;双非本科大三网络…

opencv-python图像增强三:图像清晰度增强

文章目录 一、简介&#xff1a;二、图像清晰度增强方案&#xff1a;三、算法实现步骤3.1高反差保留实现3.2. usm锐化3.3 Overlay叠加 四&#xff1a;整体代码实现五&#xff1a;效果 一、简介&#xff1a; 你是否有过这样的烦恼&#xff0c;拍出来的照片总是不够清晰&#xff…

【Linux】网络编程套接字Scoket:UDP网络编程

目录 一、了解UDP协议 二、了解端口和IP地址 三、套接字概述与Socket的概念 四、Socket的类型 五、 Socket的信息数据结构 六、网络字节序与主机字节序的互相转换 七、地址转换函数 八、UDP网络编程流程及相关函数 socket函数 bind函数 recvfrom函数 sendto函数 …

UIAbility组件基础(一)

一、概述 UIAbility组件是一种包含UI的应用组件&#xff0c;主要用于和用户交互。UIAbility组件是系统调度的基本单元&#xff0c;为应用提供绘制界面的窗口。一个应用可以包含一个或多个UIAbility组件。每一个UIAbility组件实例都会在最近任务列表中显示一个对应的任务。 U…

C语言 ——— 学习、使用memmove函数 并模拟实现

目录 memmvoe函数的功能 学习memmove函数​编辑 模拟实现memmove函数 memmvoe函数的功能 memmvoe函数的功能类似于memcpy函数&#xff0c;都是内存拷贝&#xff0c;唯一的区别是memcpy函数不能成功拷贝原数据&#xff0c;而memmvoe函数可以 举例来说&#xff1a; [1, 2, 3…

单元测试注解:@ContextConfiguration

ContextConfiguration注解 ContextConfiguration注解主要用于在‌Spring框架中加载和配置Spring上下文&#xff0c;特别是在测试场景中。 它允许开发者指定要加载的配置文件或配置类的位置&#xff0c;以便在运行时或测试时能够正确地构建和初始化Spring上下文。 基本用途和工…