LeNet-5

目录

一、知识点

二、代码

三、查看卷积层的feature map

1. 查看每层信息

​2. show_featureMap.py


背景:LeNet-5是一个经典的CNN,由Yann LeCun在1998年提出,旨在解决手写数字识别问题。

一、知识点

1. iter()+next()

iter():返回迭代器

next():使用next()来获取下一条数据

data = [1, 2, 3]
data_iter = iter(data)
print(next(data_iter))  # 1
print(next(data_iter))  # 2
print(next(data_iter))  # 3

2. enumerate

enumerate(sequence,[start=0]) 函数用于将一个可遍历的数据对象组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。

start--下标起始位置的值。 

data = ['zs', 'ls', 'ww']
print(list(enumerate(data)))
# [(0, 'zs'), (1, 'ls'), (2, 'ww')]

3. torch.no_grad()

在该模块下,所有计算得出的tensor的requires_grad都自动设置为False。

当requires_grad设置为False时,在反向传播时就不会自动求导了,可以节约存储空间。

4. torch.max(input,dim)

input -- tensor类型

dim=0 -- 行比较

dim=1 -- 列比较

import torchdata = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
x = torch.max(data, dim=0)
print(x)
# values=tensor([7., 8., 9.]),
# indices=tensor([2, 2, 2])
x = torch.max(data, dim=1)
print(x)
# values=tensor([3., 6., 9.]),
# indices=tensor([2, 2, 2])

5. torch.eq:对两个张量Tensor进行逐个元素的比较,若相同位置的两个元素相同,则返回True;若不同,返回False。

注意:item返回一个数。

import torchdata1 = torch.tensor([1, 2, 3, 4, 5])
data2 = torch.tensor([2, 3, 3, 9, 5])
x = torch.eq(data1, data2)
print(x)  # tensor([False, False,  True, False,  True])
sum = torch.eq(data1, data2).sum()
print(sum)  # tensor(2)
sum_item = torch.eq(data1, data2).sum().item()
print(sum_item)  # 2

6. squeeze(input,dim)函数

squeeze(0):若第一维度值为1,则去除第一维度

squeeze(1):若第二维度值为2,则去除第二维度

squeeze(-1):去除最后维度值为1的维度

7. unsqueeze(input,dim)

增加大小为1的维度,即返回一个新的张量,对输入的指定位置插入维度 1且必须指明维度。

二、代码

model.py

import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 16, 5)  # output(16,28,28)self.pool1 = nn.MaxPool2d(2, 2)  # output(16,14,14)self.conv2 = nn.Conv2d(16, 32, 5)  # output(32,10,10)self.pool2 = nn.MaxPool2d(2, 2)  # output(32,5,5)self.fc1 = nn.Linear(32 * 5 * 5, 120)  # output:120self.fc2 = nn.Linear(120, 84)  # output:84self.fc3 = nn.Linear(84, 10)  # output:10def forward(self, x):x = F.relu(self.conv1(x))x = self.pool1(x)x = F.relu(self.conv2(x))x = self.pool2(x)x = x.view(-1, 32 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x

train.py

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transformsfrom model import LeNetdef main():# preprocess datatransform = transforms.Compose([# Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]transforms.ToTensor(),# (mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 训练集 如果数据集已经下载了,则download=Falsetrain_data = torchvision.datasets.CIFAR10('./data', train=True, transform=transform, download=False)train_loader = torch.utils.data.DataLoader(train_data, batch_size=36, shuffle=True, num_workers=0)# 验证集val_data = torchvision.datasets.CIFAR10('./data', train=False, download=False, transform=transform)val_loader = torch.utils.data.DataLoader(val_data, batch_size=10000, shuffle=False, num_workers=0)# 返回迭代器val_data_iter = iter(val_loader)val_image, val_label = next(val_data_iter)net = LeNet()loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)# loop over the dataset multiple timesfor epoch in range(5):epoch_loss = 0for step, data in enumerate(train_loader, start=0):# get the inputs from train_loader;data is a list of[inputs,labels]inputs, labels = data# 在处理每一个batch时并不需要与其他batch的梯度混合起来累积计算,因此需要对每个batch调用一遍zero_grad()将参数梯度设置为0optimizer.zero_grad()# 1.forwardoutputs = net(inputs)# 2.lossloss = loss_function(outputs, labels)# 3.backpropagationloss.backward()# 4.update x by optimizeroptimizer.step()# print statistics# 使用item()取出的元素值的精度更高epoch_loss += loss.item()# print every 500 mini-batchesif step % 500 == 499:with torch.no_grad():outputs = net(val_image)predict_y = torch.max(outputs, dim=1)[1]  # [0]取每行最大值,[1]取每行最大值的索引val_accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)print('[epoch:%d step:%5d] train_loss:%.3f test_accuracy:%.3f' % (epoch + 1, step + 1, epoch_loss / 500, val_accuracy))epoch_loss = 0print('Train finished!')sava_path = './model/LeNet.pth'torch.save(net.state_dict(), sava_path)if __name__ == '__main__':main()

predict.py

import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNetdef main():transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),  # CHW格式transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']net = LeNet()net.load_state_dict(torch.load('./model/LeNet.pth'))image = Image.open('./predict/2.png')  # HWC格式image = transform(image)image = torch.unsqueeze(image, dim=0)  # 在第0维加一个维度 #[N,C,H,W] N:Batch批处理大小with torch.no_grad():outputs = net(image)predict = torch.max(outputs, dim=1)[1]print(classes[predict])if __name__ == '__main__':main()

2.png

 

三、查看卷积层的feature map

1. 查看每层信息

    for i in net.children():print(i)

2. show_featureMap.py

import torch
import torch.nn as nn
from model import LeNet
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as pltdef main():transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),  # CHW格式transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])image = Image.open('./predict/2.png')  # HWC格式image = transform(image)image = torch.unsqueeze(image, dim=0)  # 在第0维加一个维度 #[N,C,H,W] N:Batch批处理大小net = LeNet()net.load_state_dict(torch.load('./model/LeNet.pth'))conv_weights = []  # 模型权重conv_layers = []  # 模型卷积层counter = 0  # 模型里有多少个卷积层# 1.将卷积层以及对应权重放入列表中model_children = list(net.children())for i in range(len(model_children)):if type(model_children[i]) == nn.Conv2d:counter += 1conv_weights.append(model_children[i].weight)conv_layers.append(model_children[i])outputs = []names = []for layer in conv_layers[0:]:# 2.每个卷积层对image进行计算image = layer(image)outputs.append(image)names.append(str(layer))# 3.进行维度转换print(outputs[0].shape)  # torch.Size([1, 16, 28, 28]) 1-batch 16-channel 28-H 28-Wprint(outputs[0].squeeze(0).shape)  # torch.Size([16, 28, 28]) 去除第0维# 将16颜色通道的feature map加起来,变为一张28×28的feature map,sum将所有灰度图映射到一张print(torch.sum(outputs[0].squeeze(0), 0).shape)  # torch.Size([28, 28])processed_data = []for feature_map in outputs:feature_map = feature_map.squeeze(0)  # torch.Size([16, 28, 28])gray_scale = torch.sum(feature_map, 0)  # torch.Size([28, 28])# 取所有灰度图的平均值gray_scale = gray_scale / feature_map.shape[0]processed_data.append(gray_scale.data.numpy())# 4.可视化特征图figure = plt.figure()for i in range(len(processed_data)):x = figure.add_subplot(1, 2, i + 1)x.imshow(processed_data[i])x.set_title(names[i].split('(')[0])plt.show()if __name__ == '__main__':main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/134914.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

阿里云无影云电脑介绍_云办公_使用_价格和优势说明

什么是阿里云无影云电脑?无影云电脑(原云桌面)是一种快速构建、高效管理桌面办公环境,无影云电脑可用于远程办公、多分支机构、安全OA、短期使用、专业制图等使用场景,阿里云百科分享无影云桌面的详细介绍、租用价格、…

【查缺补漏 女娲补天】2023平安

秋招了,只根据自己的情况记录,大概率不会很全。标题是我觉得的重点。既搬砖也搬博客。 Telnet协议 远程登录和管理网路设备的标准协议TCP传输层之上:应用层工作模型:C/S模式(client/server)服务端端口号默…

Python 数独求解器

文章目录 使用回溯算法在Python中解决数独总结 Sudoku(数独)是一种基于逻辑的数字填充谜题游戏,最受喜爱的是那些热爱逻辑和推理的人。解决数独谜题有助于提高集中注意力和逻辑思维能力。 本文介绍了如何使用Python解决数独谜题。 使用回溯算…

车联网远程监控管理提升车辆调度效率,实现高效运营

随着智慧城市建设与物联网技术发展,车辆使用4G工业路由器网络实现车联网,并对车上视频监控、GPS定位以及温湿度传感器等信息进行数据采集和实时传输。这些数据的采集和监测将通过4G网络上传到管理平台,为车辆调度和运行效率的优化提供了有力的…

计算机组成原理——基础入门总结(一)

本帖更新一些关于计算机组成原理的重点内容。由于博主考研时并不会考这门课,但是考虑到操作系统中又很多重要晦涩的概念涉及很多诸如内存、存储器、磁盘、cpu乃至各种寄存器的知识,此处挑选一些核心的内容总结复盘一遍——实现声明:本帖的内容…

【微信小程序】外卖点餐效果展示

概述 外卖点餐效果展示,左右布局,快速点餐,商家信息展示等...程序是模仿人家的,所以界面没做什么调整,功能是没啥问题,可以正常使用... 详细 直接看效果图: 可以把这个点餐这个功能分为5部分…

第13篇:ESP32 idf wifi联网使用SNTP同步网络时间LCD ST7920液晶屏显示

第1篇:Arduino与ESP32开发板的安装方法 第2篇:ESP32 helloword第一个程序示范点亮板载LED 第3篇:vscode搭建esp32 arduino开发环境 第4篇:vscodeplatformio搭建esp32 arduino开发环境 ​​​​​​第5篇:doit_esp32_devkit_v1使用pmw呼吸灯实验 第6篇:ESP32连接无源喇叭播…

看完这篇 教你玩转渗透测试靶机Vulnhub——Grotesque:2

Vulnhub靶机Grotesque:1.0.1渗透测试详解 Vulnhub靶机介绍:Vulnhub靶机下载:Vulnhub靶机安装:①:信息收集:②:暴力破解:③:SSH登入:④:提权&#…

day21算法

常见的七种查找算法: ​ 数据结构是数据存储的方式,算法是数据计算的方式。所以在开发中,算法和数据结构息息相关。今天的讲义中会涉及部分数据结构的专业名词,如果各位铁粉有疑惑,可以先看一下哥们后面录制的数据结构…

MapReduce YARN 的部署

1、部署说明 Hadoop HDFS分布式文件系统,我们会启动: NameNode进程作为管理节点DataNode进程作为工作节点SecondaryNamenode作为辅助 同理,Hadoop YARN分布式资源调度,会启动:ResourceManager进程作为管理节点NodeM…

uni-app 实现自定义按 A~Z 排序的通讯录(字母索引导航)

创建 convertPinyin.js 文件 convertPinyin.js 将下面的内容复制粘贴到其中 const pinyin (function() {let Pinyin function(ops) {this.initialize(ops);},options {checkPolyphone: false,charcase: "default"};Pinyin.fn Pinyin.prototype {init: functi…

JavaScript:二进制数组【笔记】

二进制数组【ArrayBuffer对象、Type的Array视图和DataView视图】JavaScript操作二进制数据的一个接口。 这些接口原本是和WebGL有关【WebGL是浏览器与显卡之间的通信接口】,为了满足JavaScript与显卡之间大量、实时数据交换,那么JavaScript和显卡之间的…

一款非常容易上手的报表工具,简单操作实现BI炫酷界面数据展示,驱动支持众多不同类型的数据库,可视化神器,免开源了

一款非常容易上手的报表工具,简单操作实现BI炫酷界面数据展示,驱动支持众多不同类型的数据库,可视化神器,免开源了。 在互联网数据大爆炸的这几年,各类数据处理、数据可视化的需求使得 GitHub 上诞生了一大批高质量的…

flink 端到端一致性

背景 我们经常会混淆flink提供的状态一致性保证和数据端到端一致性保证的关系,总以为他们表达的是同一个意思,事实上,他们不是一个含义,flink只能保证其维护的内部状态的一致性,而数据端到端的一致性需要数据源&#…

数据包络分析(DEA)

写在前面: 博主本人大学期间参加数学建模竞赛十多余次,获奖等级均在二等奖以上。为了让更多学生在数学建模这条路上少走弯路,故将数学建模常用数学模型算法汇聚于此专栏,希望能够对要参加数学建模比赛的同学们有所帮助。 目录 1. …

TypeScript 从入门到进阶之基础篇(一) ts类型篇

系列文章目录 文章目录 系列文章目录前言一、安装必要软件二、TypeScript 基础类型1.基础类型之 数字类型 number2.基础类型之 字符串类型 string3.基础类型之 布尔类型 boolean4.基础类型之 空值类型 void5.基础类型之 null 、undefined类型6.基础类型之 任意类型 any &#x…

第 363 场 LeetCode 周赛题解

A 计算 K 置位下标对应元素的和 模拟 class Solution { public:int pop_cnt(int x) {//求x的二进制表示中的1的位数int res 0;for (; x; x >> 1)if (x & 1)res;return res;}int sumIndicesWithKSetBits(vector<int> &nums, int k) {int res 0;for (int i…

CorelDRAW 2023怎么把图片转化为手绘素描风 简单几步轻松搞定

CorelDRAW 2023是一款非常好用的设计类软件&#xff0c;软件拥有非常多强大又好用的功能&#xff0c;可以帮助设计师快速创造出想要的效果&#xff0c;今天我们就来给大家介绍一下CDR的“素描”艺术笔触。它可以帮助用户快速将普通的图片快速转换成类似素描的风格&#xff0c;在…

接口测试——接口协议抓包分析与mock_L1

目录&#xff1a; 接口测试价值与体系常见的接口协议接口测试用例设计postman基础使用postman实战练习 1.接口测试价值与体系 接口测试概念 接口&#xff1a;不同的系统之间相互连接的部分&#xff0c;是一个传递数据的通道接口测试&#xff1a;检查数据的交换、传递和控制…

VUE build:gulp打包:测试、正式环境

目录 项目结构 Gulp VUE使用Gulp Vue安装Gulp Vue定义Gulp.js package.json build文件夹 config文件夹 static-config文件夹 项目结构 Gulp Gulp是一个自动化构建工具&#xff0c;可以帮助前端开发者通过自动化任务来管理工作流程。Gulp使用Node.js的代码编写&#xff…