动手学深度学习:CNN和LeNet

前言

该篇文章记述从零如何实现CNN,以及LeNet对于之前数据集分类的提升效果。

从零实现卷积核

import torch
def conv2d(X,k):h,w=k.shapeY=torch.zeros((X.shape[0]-h+1,X.shape[1]-w+1))for i in range(Y.shape[0]):for j in range(Y.shape[1]):Y[i,j]=(X[i:i+h,j:j+w]*k).sum()return Y
X=torch.tensor([[0.,1.,2.],[3.,4.,5.],[6.,7.,8.]])
k=torch.tensor([[0.,1.],[2.,3.]])
conv2d(X,k)

在这里插入图片描述

卷积层

from torch import nn
class Conv2D(nn.Module):def __init__(self,kernel_size):super.__init__()self.weight=nn.Parameter(torch.rand(kernel_size))self.bias=nn.Parameter(torch.zeros(1))def forward(self,x):return conv2d(x,self.weight)+self.bias

验证卷积层对于图像的检测作用

x=torch.ones((6,8))
x[:,2:6]=0
x

在这里插入图片描述

k=torch.tensor([[1.0,-1.0]])
y=conv2d(x,k)
y

在这里插入图片描述
很明显这个卷积核提取到了垂直上的特征

conv2d(x.t(),k)

在这里插入图片描述
并没有学习到水平特征

学习卷积核

我们可以让卷积核自己学习里面的参数以达到对不同图像提取的作用

conv2d=nn.Conv2d(1,1,kernel_size=(1,2),bias=False)x=x.reshape((1,1,x.shape[0],x.shape[1]))
y=y.reshape((1,1,6,7))
lr=3e-2for i in range(10):y_hat=conv2d(x)l=(y_hat-y)**2conv2d.zero_grad()l.sum().backward()conv2d.weight.data[:]-=lr*conv2d.weight.gradprint(f"第{i}轮,loss为{l.sum()}")

在这里插入图片描述

conv2d.weight.data

在这里插入图片描述

填充

def comp_conv2d(conv2d,x):#(1,1)添加batch大小和通道数x=x.reshape((1,1)+x.shape)y=conv2d(x)return y.reshape(y.shape[2:])
conv2d=nn.Conv2d(1,1,kernel_size=3,padding=1)
x=torch.rand(size=(8,8))
comp_conv2d(conv2d,x).shape

在这里插入图片描述

conv2d=nn.Conv2d(1,1,kernel_size=(5,3),padding=(2,1))
x=torch.rand(size=(8,8))
comp_conv2d(conv2d,x).shape

在这里插入图片描述

步幅

conv2d=nn.Conv2d(1,1,kernel_size=(3,3),padding=1,stride=2)
x=torch.rand(size=(8,8))
comp_conv2d(conv2d,x).shape

在这里插入图片描述

多通道

from d2l import torch as d2l
def corr2d_multi_in(X,K):return sum(d2l.corr2d(x,k) for x,k in zip(X,K))
x=torch.randn(size=(4,2,3))
k=torch.randn(size=(4,1,3))
corr2d_multi_in(x,k)

在这里插入图片描述

多输出通道

def corr2d_multi_in_out(X,K):return torch.stack([corr2d_multi_in(X,k)for k in K],0)
K=torch.stack((k,k+1,k+2),0)
K.shape

在这里插入图片描述

corr2d_multi_in_out(x,K)

在这里插入图片描述

1x1卷积

def corr2d_multi_in_out_1x1(X,K):c_i,h,w=X.shapec_o=K.shape[0]X=X.reshape((c_i,h*w))K=K.reshape((c_o,c_i))Y=torch.matmul(K,X)return Y.reshape((c_o,h,w))
X=torch.normal(0,1,(3,3,3))
K=torch.normal(0,1,(2,3,1,1))
Y1=corr2d_multi_in_out_1x1(X,K)
Y2=corr2d_multi_in_out(X,K)
Y1==Y2

在这里插入图片描述

汇聚层

def pool2d(x,pool_size,mode='max'):p_h,p_w=pool_sizeY=torch.zeros((X.shape[0]-p_h+1,X.shape[1]-p_w+1))for i in range(Y.shape[0]):for j in range(Y.shape[1]):if mode=='max':Y[i,j]=X[i:i+p_h,j:j+p_w].max()elif mode=='avg':Y[i,j]=X[i:i+p_h,j:j+p_w].mean()return Y
X=torch.tensor([[0.0,1.,2.],[3.,4.,5.],[6.,7.,8.]])
pool2d(X,(2,2))

在这里插入图片描述

pool2d(X,(2,2),'avg')

在这里插入图片描述

LeNet

这是最早的神经网络,根据我的测试,这个模型在我的数据集上的效果比MLP要提高了1%以上,在这段时间里面,我页发现了原有数据集在分类上存在问题,所以重新制作了一份,在这份数据集上,随着我数据量的提升以及模型的修改,准确率达到了99.7%,且无过拟合现象。

原始的LeNet

from torch import nn
net=nn.Sequential(
nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2,stride=2),
nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2,stride=2),
nn.Flatten(),
nn.Linear(16*3*5,120),nn.Sigmoid(),
nn.Linear(120,84),nn.Sigmoid(),
nn.Linear(84,9))
def init_weight(m):if type(m)==nn.Linear or type(m)==nn.Conv2d:nn.init.xavier_uniform_(m.weight)
net.apply(init_weight)

在这里插入图片描述

测试结果

我忘记截图了,效果达到了99%以上,同样的数据集在MLP上是98%

改进后的LeNet

我将平均池化层改成了最大池化层

net=nn.Sequential(
nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
nn.MaxPool2d(kernel_size=2,stride=2),
nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
nn.MaxPool2d(kernel_size=2,stride=2),
nn.Flatten(),
nn.Linear(16*3*5,120),nn.Sigmoid(),
nn.Linear(120,84),nn.Sigmoid(),
nn.Linear(84,9))
def init_weight(m):if type(m)==nn.Linear or type(m)==nn.Conv2d:nn.init.xavier_uniform_(m.weight)
net.apply(init_weight)

在这里插入图片描述

训练修改

我在训练过程中添加了记录test的loss最低时,保存pt和onnx,用于后续推理。

epochs_num=100
train_len=len(train_iter.dataset)
all_acc=[]
all_loss=[]
test_all_acc=[]
shape=None
for epoch in range(epochs_num):acc=0loss=0for x,y in train_iter:hat_y=net(x)l=loss_fn(hat_y,y)loss+=loptimer.zero_grad()l.backward()optimer.step()acc+=(hat_y.argmax(1)==y).sum()all_acc.append(acc/train_len)all_loss.append(loss.detach().numpy())test_acc=0test_loss=0test_len=len(test_iter.dataset)with torch.no_grad():for x,y in test_iter:shape=x.shapehat_y=net(x)test_loss+=loss_fn(hat_y,y)test_acc+=(hat_y.argmax(1)==y).sum()test_all_acc.append(test_acc/test_len)print(f'{epoch}的test的acc{test_acc/test_len}')# 保存测试损失最小的模型if test_loss < best_test_loss:best_test_loss = test_losstorch.save(net, best_model_path)dummy_input = torch.randn(shape)  torch.onnx.export(net, dummy_input, "./models/LeNet5.onnx", opset_version=11)print(f'Saved better model with Test Loss: {best_test_loss:.4f}')

在这里插入图片描述

损失函数可视化

plt.plot(range(1,epochs_num+1),all_loss,'.-',label='train_loss')
plt.text(epochs_num, all_loss[-1], f'{all_loss[-1]:.4f}', fontsize=12, verticalalignment='bottom')

在这里插入图片描述

准确率可视化

plt.plot(range(1,epochs_num+1),all_acc,'-',label='train_acc')
plt.text(epochs_num, all_acc[-1], f'{all_acc[-1]:.4f}', fontsize=12, verticalalignment='bottom')
plt.plot(range(1,epochs_num+1),test_all_acc,'-.',label='test_acc')
plt.legend()

在这里插入图片描述

预测结果

import numpy as np
with torch.no_grad():all_num=5index=1plt.figure(figsize=(12,5))for i,label in zip(test_data_path,test_labels):if index<=all_num:img=cv2.imread(i)input_img=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)img=cv2.cvtColor(input_img,cv2.COLOR_BGR2RGB)input_img = np.expand_dims(input_img, axis=2)  # 增加通道维度,形状变为 [1, H, W]input_img=transforms.ToTensor()(input_img)input_img = input_img.unsqueeze(0)  # 增加批量维度,形状变为 [1, 1, 28, 20]print(input_img.shape)result=net(input_img).argmax(1)plt.subplot(1,all_num,index)plt.imshow(img)plt.title(f'true{label},predict{result.detach().numpy()}')plt.axis("off")index+=1

在这里插入图片描述

总结

数据集收集过程中遇到了部分麻烦,数据集还不够完整。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/34248.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【开源代码解读】AI检索系统R1-Searcher通过强化学习RL激励大模型LLM的搜索能力

关于R1-Searcher的报告&#xff1a; 第一章&#xff1a;引言 - AI检索系统的技术演进与R1-Searcher的创新定位 1.1 信息检索技术的范式转移 在数字化时代爆发式增长的数据洪流中&#xff0c;信息检索系统正经历从传统关键词匹配到语义理解驱动的根本性变革。根据IDC的统计…

使用Node的http模块创建web服务,给客户端返回html页面时,css失效的根本原因(有助于理解http)

最近正在尝试使用node写后端&#xff0c;使用node创建http服务的时候&#xff0c;碰到了这样的一个问题&#xff1a; 这是我的源代码&#xff1a; import { createServer } from http import { join, dirname, extname } from path import { fileURLToPath } from url import…

JVM 2015/3/15

定义&#xff1a;Java Virtual Machine -java程序的运行环境&#xff08;java二进制字节码的运行环境&#xff09; 好处&#xff1a; 一次编写&#xff0c;到处运行 自动内存管理&#xff0c;垃圾回收 数组下标越界检测 多态 比较&#xff1a;jvm/jre/jdk 常见的JVM&…

IP风险度自检,互联网的安全“指南针”

IP地址就像我们的网络“身份证”&#xff0c;而IP风险度则是衡量这个“身份证”安全性的重要指标。它关乎着我们的隐私保护、账号安全以及网络体验&#xff0c;今天就让我们一起深入了解一下IP风险度。 什么是IP风险度 IP风险度是指一个IP地址可能暴露用户真实身份或被网络平台…

【鸿蒙】封装日志工具类 ohos.hilog打印日志

封装一个ohos.hilog打印日志 首先要了解hilog四大日志类型&#xff1a; info、debug、warm、error 方法中四个参数的作用 domain: number tag: string format: string ...args: any[ ] 实例&#xff1a; //普通的info日志&#xff0c;使用info方法来打印 //第一个参数 : 0x0…

走路碎步营养补充贴士

走路碎步&#xff0c;这种步伐不稳的现象&#xff0c;在日常生活中并不罕见&#xff0c;特别是对于一些老年人或身体较为虚弱的人来说&#xff0c;更是一种常见的行走状态。然而&#xff0c;这种现象可能不仅仅是肌肉或骨骼的问题&#xff0c;它还可能是身体在向我们发出营养缺…

Python软件和搭建运行环境

目录 一、Python安装全流程&#xff08;Windows/Mac/Linux&#xff09; 1. 下载官方安装包 2. 详细安装步骤&#xff08;以Windows为例&#xff09; 3. 环境变量配置&#xff08;Mac/Linux&#xff09; 二、虚拟环境管理&#xff08;关键&#xff01;&#xff09; 为什么需…

【蓝桥杯】省赛:神奇闹钟

思路 python做这题很简单&#xff0c;灵活用datetime库即可 code import os import sys# 请在此输入您的代码 import datetimestart datetime.datetime(1970,1,1,0,0,0) for _ in range(int(input())):ls input().split()end datetime.datetime.strptime(ls[0]ls[1],&quo…

RabbitMQ (Java)学习笔记

目录 一、概述 ①核心组件 ②工作原理 ③优势 ④应用场景 二、入门 1、docker 安装 MQ 2、Spring AMQP 3、代码实现 pom 依赖 配置RabbitMQ服务端信息 发送消息 接收消息 三、基础 work Queue 案例 消费者消息推送限制&#xff08;解决消息堆积方案之一&#…

HW基本的sql流量分析和wireshark 的基本使用

前言 HW初级的主要任务就是看监控&#xff08;流量&#xff09; 这个时候就需要我们 了解各种漏洞流量数据包的信息 还有就是我们守护的是内网环境 所以很多的攻击都是 sql注入 和 webshell上传 &#xff08;我们不管对面是怎么拿到网站的最高权限的 我们是需要指出它是…

camellia redis proxy v1.3.3对redis主从进行读写分离(非写死,自动识别故障转移)

1 概述 camellia-redis-proxy是一款高性能的redis代理&#xff08;https://github.com/netease-im/camellia&#xff09;&#xff0c;使用netty4开发&#xff0c;主要特性如下&#xff1a; 支持代理到redis-standalone、redis-sentinel、redis-cluster。支持其他proxy作为后端…

贪吃蛇小游戏-简单开发版

一、需求 本项目旨在开发一个经典的贪吃蛇游戏&#xff0c;用户可以通过键盘控制蛇的移动方向&#xff0c;让蛇吃掉随机出现在游戏区域内的食物&#xff0c;每吃掉一个食物&#xff0c;蛇的身体长度就会增加&#xff0c;同时得分也会相应提高。游戏结束的条件为蛇撞到游戏区域的…

使用 Docker 部署前端项目全攻略

文章目录 1. Docker 基础概念1.1 核心组件1.2 Docker 工作流程 2. 环境准备2.1 安装 Docker2.2 验证安装 3. 项目配置3.1 项目结构3.2 创建 Dockerfile 4. 构建与运行4.1 构建镜像4.2 运行容器4.3 访问应用 5. 使用 Docker Compose5.1 创建 docker-compose.yml5.2 启动服务5.3 …

接口自动化测试用例

Post接口自动化测试用例 Post方式的接口是上传接口&#xff0c;需要对接口头部进行封装&#xff0c;所以没有办法在浏览器下直接调用&#xff0c;但是可以用Curl命令的-d参数传递接口需要的参数。当然我们还以众筹网的登录接口为例&#xff0c;讲解post方式接口的自动化测试用…

使用WireShark解密https流量

概述 https协议是在http协议的基础上&#xff0c;使用TLS协议对http数据进行了加密&#xff0c;使得网络通信更加安全。一般情况下&#xff0c;使用WireShark抓取的https流量&#xff0c;数据都是加密的&#xff0c;无法直接查看。但是可以通过以下两种方法&#xff0c;解密抓…

阿里百炼Spring AI Alibaba

文章目录 学习链接阿里百炼创建api-key查看api调用示例示例pom.xmlAQuickStartMultiChatStreamChat Spring AI Alibaba简单示例pom.xmlapplication.ymlHelloworldControllerDashScopeChatModelController图解spring AI的结构 deepseekpom.xmlapplication.ymlDeepSeekChatClient…

【模拟算法】

目录 替换所有的问号 提莫攻击 Z 字形变换 外观数列 数青蛙&#xff08;较难&#xff09; 模拟算法&#xff1a;比葫芦画瓢。思路较简单&#xff0c;考察代码能力。 1. 模拟算法流程&#xff0c;一定要在演草纸上过一遍流程 2. 把流程转化为代码 替换所有的问号 1576. 替…

【Linux】进程(1)进程概念和进程状态

&#x1f31f;&#x1f31f;作者主页&#xff1a;ephemerals__ &#x1f31f;&#x1f31f;所属专栏&#xff1a;Linux 目录 前言 一、什么是进程 二、task_struct的内容 三、Linux下进程基本操作 四、父进程和子进程 1. 用fork函数创建子进程 五、进程状态 1. 三种重…

配置blender的python环境

在blender的脚本出输入&#xff1a; import sys print(sys.executable) 2. 通过上述命令我们得到blener的python版本&#xff0c;下面我们在conda配置一个同样版本的python环境。 conda create -n blenderpy python3.11.9找到blender安装路径下的python文件夹&#xff0c;将它…

【bug日记】 编译错误

在我使用vscode的时候&#xff0c;我想用一个头文件和两个cpp文件&#xff0c;头文件是用来声明一个类的&#xff0c;一个cpp是用来类的成员函数&#xff0c;一个cpp是主函数 但是我写完编译发现会弹出找不到这个类成员函数这个cpp文件&#xff0c;爆出这样的错误 提示我找不到…