01——LenNet网络结构,图片识别

目录

1、model.py文件 (预训练的模型)

2、train.py文件(会产生训练好的.th文件)

3、predict.py文件(预测文件)

4、结果展示:


1、model.py文件 (预训练的模型)

import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()# RGB图像;  这里用了16个卷积核;卷积核的尺寸为5x5的self.conv1 = nn.Conv2d(3, 16, 5)  # 输入的是RBG图片,所以in_channel为3; out_channels=卷积核个数;kernel_size:5x5的self.pool1 = nn.MaxPool2d(2, 2)  # kernal_size:2x2   stride:2self.conv2 = nn.Conv2d(16, 32, 5)  # 这里使用32个卷积核;kernal_size:5x5self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)  # 全连接层的输入,是一个一维向量,所以我们要把输入的特征向量展平。# 将得到的self.poolx(x) 的output(32,5,5)展开;  图片上给的全连接层是120self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)  # 这里的10,是需要根据训练集修改的def forward(self, x):   # 正向传播# Pytorch Tensor的通道排序:[channel,height,width]'''卷积后的尺寸大小计算:N = (W-F+2P)/S + 1其中,默认的padding:0   stride:1①输入图片大小:WxW②Filter大小 FxF  (卷积核大小)③步长S④padding的像素数P'''x = F.relu(self.conv1(x))   # 输入特征图为32x32大小的RGB图片;  input(3,32,32)  output(16,28,28)x = self.pool1(x)           # 经过最大下采样会将图片的高度和宽度:缩小为原来的一半  output(16,14,14)   池化层,只改变特征矩阵的高和宽;x = F.relu(self.conv2(x))   # output(32, 10, 10)  因为第二个卷积层的卷积核大小是32个,这里就是32x = self.pool2(x)           # 经过最大下采样会将图片的高度和宽度:缩小为原来的一半output(32, 5, 5)x = x.view(-1, 32*5*5)   # x.view()  将其展开成一维向量,-1表示第一个维度batch需要自动推理x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x
# 测试下
# import torch
# input1 = torch.rand([32,3,32,32])
# model = LeNet()
# print(model)
# output = model(input1)

2、train.py文件(会产生训练好的.th文件)

import matplotlib.pyplot as plt
import numpy as np
import torch.utils.data
import torchvision
from torch import nn, optim
from torchvision import transformsfrom pilipala_pytorch.pytorch_learning.Test1_pytorch_demo.model import LeNet# 1、下载数据集
# 图形预处理 ;其中transforms.Compose()是用来组合多个图像转换操作的,使得这些操作可以顺序地应用于图像。
transform = transforms.Compose([transforms.ToTensor(),   # 将PIL图像或ndarray转换为torch.Tensor,并将像素值的范围从[0,255]缩放到[0.0, 1.0]transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]   # 对图像进行标准化;标准化通常用于使模型的训练更加稳定。
)
# 50000张训练图片
train_ds = torchvision.datasets.CIFAR10('data',train=True,transform=transform,download=False)
# 10000张测试图片
test_ds = torchvision.datasets.CIFAR10('data',train=False,transform=transform,download=False)
# 2、加载数据集
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=36, shuffle=True, num_workers=0)    # shuffle数据是否是随机提取的,一般设置为True
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=10000, shuffle=True, num_workers=0)test_image,test_label = next(iter(test_dl))  # 将test_dl 转换为一个可迭代的迭代器,通过next()方法获取数据classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')'''标准化处理:output = (input - 0.5) / 0.5反标准化处理: input = output * 0.5 + 0.5 = output / 2 + 0.5
'''
# 测试下展示图片
# def imshow(img):
#     img = img / 2 + 0.5   # unnormalize  反标准化处理
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1,2,0)))
#     plt.show()
#
# # 打印标签
# print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
# imshow(torchvision.utils.make_grid(test_image))# 实例化网络模型
net = LeNet()
# 定义相关参数
loss_function = nn.CrossEntropyLoss()  # 定义损失函数
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器, 这里使用的是Adam优化器
# 训练过程
for epoch in range(5):  # 定义循环,将训练集迭代多少轮running_loss = 0.0  # 叠加,训练过程中的损失for step,data in enumerate(train_dl,start=0):  # 遍历训练集样本inputs, labels = data   # 获取图像及其对应的标签optimizer.zero_grad()  # 将历史梯度清零;如果不清除历史梯度,就会对计算的历史梯度进行累加outputs = net(inputs)   # 将输入的图片输入到网络,进行正向传播loss = loss_function(outputs, labels)  # outputs网络预测的值, labels真实标签loss.backward()optimizer.step()running_loss += loss.item()if step % 500 == 499:with torch.no_grad():  # with 是一个上下文管理器outputs = net(test_image)  # [batch,10]predict_y = torch.max(outputs, dim=1)[1]   # 网络预测最大的那个accuracy = (predict_y == test_label).sum().item() / test_label.size(0)  # 得到的是tensor  (predict_y == test_label).sum()  要通过item()拿到数值print("[%d, %5d] train_loss: %.3f test_accuracy:%.3f" % (epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0
print('Finished Training')save_path = './Lenet.pth'  # 保存模型
torch.save(net.state_dict(), save_path)  # net.state_dict() 模型字典;save_path 模型路径

测试下展示图片:

运行下,train.py文件,看下正确率、损失率:

3、predict.py文件(预测文件)

import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNettransform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship' , 'truck')net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))  # 加载train里面的训练好 产生的模型。im = Image.open('2.jpg')  # 载入准备好的图片
im = transform(im)  # 如果要将图片放入网络,进行正向传播,就得转换下格式   得到的结果为:[C,H,W]
im = torch.unsqueeze(im, dim=0)    # 增加一个维度;得到 [N,C,H,W],从而模拟一个批量大小为1的输入。with torch.no_grad():  # 不需要计算损失梯度outputs = net(im)predict = torch.max(outputs, dim=1)[1].data.numpy()   # outputs是一个张量;torch.max()用于找到张量在指定维度上的最大值;# torch.max()函数返回两个张量,一个包含最大值,另一个包含最大值的作用。# .data()属性用于从变量中提取底层的张量数据。直接使用.data()已经被认为是不安全的,推荐使用.detach()# .numpy() 表示将pytorch转换成numpy数组,从而使用numpy库的各种功能来操作数据。
print(classes[int(predict)])#     predict = torch.softmax(outputs,dim=1)  # 可以返回概率
# print(predict)

4、结果展示:

返回结果:预测是猫的概率为 86%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/277233.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Navicat破解 Navicat下载安装 附教程 免费

百度网盘:https://pan.baidu.com/s/1wRRN_18_uXxPiIWCS4l43A 麻烦各位师傅帮忙填写一下问卷,提取码在问卷填写结束后显示~ 【https://www.wjx.cn/vm/mBBTTKm.aspx# 】 (资料来源于网络,侵告删)

ThreeJs 可视化大屏地图

效果图: 今天给各位提供一个可视化地图的案例和源码,关注下吧各位!回复"大屏地图" 获取工程 1、案例分析:主要使用了UI 美工图片,然后获取地图的json 数据绘制图形,贴图使用该区域的地图纹理进行…

git push解决办法:! [remote rejected] prod -> prod (pre-receive hook declined)

今天想把最近改的东西上传到Gogs上发版一下子的,但是发现有冲突合并不了,于是我切回自己的分支合并了prod,把冲突处理了一下子,还又增加了一点修改,push后.......又回到prod进行git push,哦豁~这就出了问题…

编写测试用例的方法,这个是真的很好用

大家测试过程中经常用的等价类划分、边界值分析、场景法等,并不能覆盖所有的需求,我们之前讲过很少用到的因果图法,下面就来讲另一种不经常用到但又非常重要的测试用例编写方法——测试大纲法。 测试大纲法适用于有多个窗口,每个…

mysql主从复制/主从备份搭建

mysql主从复制/主从备份搭建 前言一、主从复制1)为什么使用主从复制、读写分离?2)主从复制原理 二、如何实现主从复制?1)主库配置1、修改配置文件2、登录mysql: 2)从库配置1、修改配置文件2、登…

前端之用HTML弄一个古诗词

将进酒 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>将进酒</title><h1><big>将进酒</big> 君不见黄河之水天上来</h1><table><tr><td ><img…

复现文件上传漏洞

一、搭建upload-labs环境 将下载好的upload-labs的压缩包&#xff0c;将此压缩包解压到WWW中&#xff0c;并将名称修改为upload&#xff0c;同时也要在upload文件中建立一个upload的文件。 然后在浏览器网址栏输入&#xff1a;127.0.0.1/upload进入靶场。 第一关 选择上传文件…

浅易理解:卷积神经网络(CNN)

浅易理解卷积神经网络流程 本文的目录&#xff1a; 1 什么卷积神经网络 2 输入层 3 卷积层 4 池化层 5 全连接层 传统的多层神经网络只有 输入层、隐藏层、输出层 卷积神经网络&#xff08;CNN)&#xff1a; 在多层神经网络的基础上&#xff0c;加入了更加有效的特征学习部分…

spacy进行简单的自然语言处理的学习

自然语言处理基本概念 概念&#xff1a;自然语言处理&#xff0c;是让机器理解人的语言的过程。 作用&#xff1a;通过使用自然语言处理&#xff0c;机器可以理解人的语言&#xff0c;从而进行语义分析&#xff0c;例如&#xff1a;从一句话中判断喜怒哀乐&#xff1b;从一段文…

汽车电子零部件(3):ADAS前视感知系统FLC

前言: 比如车道保持和车道改变这种场景,如何进行车道的识别,如何进行周围车辆的识别,这算是ADAS中的一个场景,其中就会用到FLC前视感知系统。还有比如前向物体识别,前向车辆识别等。 再往大里说那就是车联网了: 除了前向也可能有其他部位

opencv安装(C++)并配置vs

准备工作&#xff1a; 1.opencv安装包(此教程使用4.9) 2.visual studio(此教程使用vs2019) opencv安装&#xff1a; 1、下载opencv&#xff1a; 1.1 官网下载&#xff1a;Releases - OpenCV 1.2 百度网盘&#xff1a;链接&#xff1a;https://pan.baidu.com/s/1NpEoFjbbyQJtFD…

【Java】List, Set, Queue, Map 区别?

目录 List, Set, Queue, Map 区别&#xff1f; Collection和Collections List ArrayList 和 Array区别&#xff1f; ArrayList与LinkedList区别? ArrayList 能添加null吗&#xff1f; ArrayList 插入和删除时间复杂度&#xff1f; LinkedList 插入和删除时间复杂度&…

【C++】string进一步介绍

个人主页 &#xff1a; zxctscl 如有转载请先通知 文章目录 1. 前言2. 迭代器2.1 反向迭代器2.2 const对象迭代器 3. Capacity3.1 size和length3.2 max_size3.3 capacity3.4 clear3.5 shrink_to_fit &#xff08;了解即可&#xff09;3.6 reserve3.7 resize 4. Element access4…

Java面向对象案例之描述专业和学生(4)

类的方法图 学生类&#xff1a; 属性&#xff1a;学号&#xff0c;姓名&#xff0c;年龄&#xff0c;所学习的专业方法&#xff1a;学习的方法&#xff0c;描述学习状态。描述内容包括姓名、学号、年龄、所学习的专业信息 专业类&#xff1a; 属性&#xff1a;专业编号&#xf…

Vue 3中的reactive:响应式状态的全面管理

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

【计算机视觉】二、图像形成——实验:2D变换编辑(Pygame)

文章目录 一、向量和矩阵的基本运算二、几何基元和变换1、几何基元(Geometric Primitives)2、几何变换(Geometric Transformations)2D变换编辑器0. 程序简介环境说明程序流程 1. 各种变换平移变换旋转变换等比缩放变换缩放变换镜像变换剪切变换 2. 按钮按钮类创建按钮 3. Pygam…

基于PHP的数字化档案管理系统

有需要请加文章底部Q哦 可远程调试 基于PHP的数字化档案管理系统 一 介绍 此数字化档案管理系统基于原生PHP&#xff0c;MVC架构开发&#xff0c;数据库mysql&#xff0c;前端bootstrap。系统角色分为用户和管理员。 技术栈 php(mvc)mysqlbootstrapphpstudyvscode 二 功能 …

SwiftUI的Alert使用方式

SwiftUI的Alert使用方式 记录一下SwiftUI的Alert使用方式&#xff0c;比较简单直接上代码 import SwiftUIstruct AlertBootCamp: View {State var showAlert falsevar body: some View {Button {showAlert.toggle()} label: {Text("alert show")}/// 单按钮 // …

python知识点总结(一)

这里写目录标题 一、什么是WSGI,uwsgi,uWSGI1、WSGI2、uWSGI3、uwsgi 二、python中为什么没有函数重载&#xff1f;三、Python中如何跨模块共享全局变量?四、内存泄露是什么?如何避免?五、谈谈lambda函数作用?六、写一个函数实现字符串反转&#xff0c;尽可能写出你知道的所…

面试官:volatile如何保证可见性的,具体如何实现?

写在开头 在之前的几篇博文中&#xff0c;我们都提到了 volatile 关键字&#xff0c;这个单词中文释义为&#xff1a;不稳定的&#xff0c;易挥发的&#xff0c;在Java中代表变量修饰符&#xff0c;用来修饰会被不同线程访问和修改的变量&#xff0c;对于方法&#xff0c;代码…