keras深度学习框架构建LeNet5神经网络模型实现手写数字识别

    之前两篇文章分别通过keras深度学习框架构建简单神经网络和卷积神经网络实现过手写数字识别实验。这篇文章分享我根据LeNet5模型构建的卷积神经网络来实现手写数字识别。

     这个实验是根据LeNet5模型构建卷积神经网络,LeNet5模型的原理图如下所示:

     相信大家在很多地方见过这个模型,但是实际上,很多代码里面并没有严格按照这个模型来做实验。为什么?主要是手写数字识别模型数据mnist规格是28*28,而这个LeNet5模型需要的数据输入是32*32。大家为了凑合,就私自改了这个模型的参数,最后输入变为了28*28,然后经历第一次卷积之后,特征图就变成了6@24*24 ,下采样再缩小一半:6@12*12。到了全连接层的时候,特征图就变为了16@4*4。

    其实这样修改,没有什么毛病,最后也能达到实验的目的,而且训练模型之后,测试准确率可以达到98%以上。

     真正要按照这个LeNet5模型就需要32*32输入形状,而现如今好像只有KNN手写数字识别那个实验好像有32*32=1024规格的训练数据集trainingDigits和测试数据集testDigits,而且那个数据量比较少,训练数据集2000个,测试数据集900多个,另一个就是它的数据是文本,要得到那个数据,还需要读文本转换。

    有一个折中的办法,就是我们可以通过opencv提供的resize方法把mnist数据集的28*28形状转为32*32形状,可能就是会丢失一些精度,而现在测试数据和用来预测图片数据也一样会修改,所以精度问题可以认为抵消。

    其实本身这个模型也有一些模糊的地方,就是这图里面并没有给出下采样层的具体方式,是使用最大值采样,还是平均值采样,还有卷积层的激活函数,只有细看论文才能找到答案。

     所以在实现中,卷积层的激活函数一般使用tanh,但是使用relu也没有问题,同理,下采样层SubSampling这里采用MaxPool2D,AveragePooling2D都行。

    简单说一下这个模型:

    输入图片矩阵形状 (32,32)

    第一层卷积层:6个卷积核,卷积核大小5*5,所以输出特征图(feature maps)大小就是6@(32-5+1) * (32-5+1) = 6@28*28

    第二层下采样(SubSampling):2*2大小,结果就是特征图大小继续减半6@14*14

    第三层卷积层:16个卷积核,卷积核大小5*5,输出特征图大小:16@(14-5+1)*(14-5+1)=16@10*10

   第四层下采样:2*2大小,特征图减小一半16@5*5

    第五层展平层:120个神经元

    第六层全连接层:84个神经元

    第七层输出层:10个神经元

    下面根据mnist数据集以及LeNet5数据模型搭建神经网络并训练测试,代码如下:

from keras.models import Sequential
from keras.layers import Conv2D, MaxPool2D, AveragePooling2D
from keras.layers import Dense, Flatten
import keras
from keras.datasets import mnist
from keras.utils import np_utils
from keras.utils.vis_utils import plot_model
import numpy as np
import cv2# 加载数据
(X_train, y_train), (X_test, y_test) = mnist.load_data()
input_shape = (32, 32, 1)
train_x = []test_x = []for val in X_train:img = cv2.resize(val, (input_shape[0], input_shape[1]))train_x.append(img.reshape(input_shape))for val_ in X_test:img = cv2.resize(val_, (input_shape[0], input_shape[1]))test_x.append(img.reshape(input_shape))# 数据预处理
X_train = np.array(train_x) / 255.0
X_test = np.array(test_x) / 255.0# to_categorical()将类别向量转换为二进制(只有0和1)的矩阵类型表示
y_train = np_utils.to_categorical(y_train, num_classes=10)
y_test = np_utils.to_categorical(y_test, num_classes=10)model = Sequential()
model.add(Conv2D(6, kernel_size=(5, 5), activation='tanh', input_shape=input_shape))
model.add(AveragePooling2D(pool_size=(2, 2)))
model.add(Conv2D(16, kernel_size=(5, 5), activation='tanh'))
model.add(AveragePooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(120, activation='tanh'))
model.add(Dense(84, activation='tanh'))
model.add(Dense(10, activation='softmax'))# 模型编译
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
# 训练
model.fit(X_train, y_train, batch_size=128, epochs=10)# 评估模型
score = model.evaluate(X_test, y_test)
print('acc', score[1])
plot_model(model, to_file='model.png', show_shapes=True)
model.save("lenet5.h5")

    运行代码,打印模型参数信息如下所示:

Model: "sequential"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================conv2d (Conv2D)             (None, 28, 28, 6)         156       average_pooling2d (AverageP  (None, 14, 14, 6)        0         ooling2D)                                                       conv2d_1 (Conv2D)           (None, 10, 10, 16)        2416      average_pooling2d_1 (Averag  (None, 5, 5, 16)         0         ePooling2D)                                                     flatten (Flatten)           (None, 400)               0         dense (Dense)               (None, 120)               48120     dense_1 (Dense)             (None, 84)                10164     dense_2 (Dense)             (None, 10)                850       =================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0

    训练和测试过程如下:

Epoch 1/10
2023-08-30 22:04:32.166768: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8800
469/469 [==============================] - 8s 12ms/step - loss: 0.3042 - accuracy: 0.9110
Epoch 2/10
469/469 [==============================] - 5s 10ms/step - loss: 0.1113 - accuracy: 0.9664
Epoch 3/10
469/469 [==============================] - 5s 10ms/step - loss: 0.0709 - accuracy: 0.9784
Epoch 4/10
469/469 [==============================] - 5s 11ms/step - loss: 0.0530 - accuracy: 0.9843
Epoch 5/10
469/469 [==============================] - 5s 11ms/step - loss: 0.0410 - accuracy: 0.9875
Epoch 6/10
469/469 [==============================] - 5s 11ms/step - loss: 0.0329 - accuracy: 0.9898
Epoch 7/10
469/469 [==============================] - 6s 12ms/step - loss: 0.0283 - accuracy: 0.9910
Epoch 8/10
469/469 [==============================] - 6s 12ms/step - loss: 0.0228 - accuracy: 0.9928
Epoch 9/10
469/469 [==============================] - 6s 12ms/step - loss: 0.0193 - accuracy: 0.9939
Epoch 10/10
469/469 [==============================] - 6s 13ms/step - loss: 0.0159 - accuracy: 0.9949
313/313 [==============================] - 2s 4ms/step - loss: 0.0406 - accuracy: 0.9871
acc 0.9871000051498413

   测试准确率高达98.7%。

    代码最后,我们还通过plot_model保存了模型图片:

 

    另外,为了用来预测,我们还保存了模型文件lenet5.h5。 

    预测,还是祖传代码,改动了图片形状32*32像素,因为预测图片还是28*28像素,黑底白字:

import keras
import numpy as np
import cv2
from keras.models import load_modelmodel = load_model("lenet5.h5")def predict(img_path):img = cv2.imread(img_path, 0)img = cv2.resize(img, (32, 32))img = img.astype("float32") / 255  # 0 1img = img.reshape(1, 32, 32, 1)  # 32 * 32 -> (1,32,32,1)label = model.predict(img)label = np.argmax(label, axis=1)print('{} -> {}'.format(img_path, label[0]))if __name__ == '__main__':for _ in range(10):predict("number_images/b_{}.png".format(_))

    实验结果:

 

    这个结果,其实并不意外,也不是因为这个模型非常牛逼,预测高达100%,其实这里预测的图片只有10个,数量不多,很多图片都是自己测试过的,所以看着好像很厉害。 

    这里面有很多可变的地方,第一个就是前面说过的输入大小的问题,如果我们改动32*32到28*28,那么我们在使用mnist数据集的时候,也不用修改形状。但是这个就与这个模型有出入,虽然最后也能运行出很高的测试准确率和预测准确率。再一个可变的地方就是卷积层的激活函数,这里使用的是tanh,其实使用relu也没问题,再一个就是使用MaxPool2D作为下采样也是可以的。在模型编译的时候,我们使用的是rmsprop优化器,其实adam也行。

     LeNet5模型其实非常适合用来做数字识别,但是数字识别没有合适的训练数据集,而cifar10这个数据集就是32*32的,所以,从编码角度来说,这个模型最适合的还是图像分类。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/112450.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

中欧财富:分布式数据库的应用历程和 TiDB 7.1 新特性探索

作者:张政俊 中欧财富数据库负责人 中欧财富是中欧基金控股的销售子公司,旗下 APP 实现业内基金品种全覆盖,提供基金交易、大数据选基、智慧定投、理财师咨询等投资工具及服务。中欧财富致力为投资者及合作伙伴提供一站式互联网财富管理解决方…

js优雅的统计字符串字符出现次数

题目如下 统计一串字符串中每个字符出现的频率 示例字符串 let str asdfasqwerqwrdfafafasdfopasdfopckpasdfassfd小白写法 let str asdfasqwerqwrdfafafasdfopasdfopckpasdfassfdlet result {}; for (let i 0; i < str.length; i) {if (result[str[i]]) {result[str[…

论文笔记: One Fits All:Power General Time Series Analysis by Pretrained LM

1 intro 时间序列领域预训练模型/foundation 模型的研究还不是很多 主要挑战是缺乏大量的数据来训练用于时间序列分析的基础模型——>论文利用预训练的语言模型进行通用的时间序列分析 为各种时间序列任务提供了一个统一的框架 论文还调查了为什么从语言领域预训练的Transf…

【ag-grid-vue】column

网格中的每一列都使用列定义(ColDef)来定义。列根据在网格选项中指定的列定义的顺序在网格中定位。 列定义 下面的例子展示了一个定义了3列的简单网格: <template><ag-grid-vuestyle"height: 300px; width: 1000px"class"ag-theme-balham":colum…

get√接口自动化核心知识点浓缩,为面试加分

日常接触到的接口自动化从实际目标可以划分为两大类&#xff1a; 1、为模拟测试数据而开展的接口自动化 这种接口自动化大多是单次执行&#xff0c;目的很明确是为了功能测试创造测试数据&#xff0c;节约人工造数据的时间和人工成本&#xff0c;提高功能测试人员的测试效率。…

chain of thought (思维链, cot)

定义 思维链 (Chain-of-thought&#xff0c;CoT) 的概念是在 Google 的论文 "Chain-of-Thought Prompting Elicits Reasoning in Large Language Models" 中被首次提出。思维链&#xff08;CoT&#xff09;是一种改进的提示策略&#xff0c;用于提高 LLM 在复杂推理…

【UE5】给模型指定面添加自定义材质

实现步骤 1. 首先我们向UE中导入一个简单的模型&#xff0c;可以看到目前该模型的材质插槽只有一个&#xff0c;当我们修改材质时会使得模型整体的材质全部改变&#xff0c;如果我们只想改变模型的某些面的材质就需要继续做后续操作。 2. 选择建模模式 3. 在模式工具栏中点击…

Linux学习之Ubuntu 20使用systemd管理OpenResty服务

sudo cat /etc/issue可以看到操作系统的版本是Ubuntu 20.04.4 LTS&#xff0c;sudo lsb_release -r可以看到版本是20.04&#xff0c;sudo uname -r可以看到内核版本是5.5.19&#xff0c;sudo make -v可以看到版本是GNU Make 4.2.1。 需要先参考我的博客《Linux学习之Ubuntu 2…

SpringBoot Mybatis 多数据源 MySQL+Oracle

一、背景 在SpringBoot Mybatis 项目中&#xff0c;需要连接 多个数据源&#xff0c;连接多个数据库&#xff0c;需要连接一个MySQL数据库和一个Oracle数据库 二、依赖 pom.xml <dependencies><dependency><groupId>org.springframework.boot</groupId&…

【Golang】go条件编译

交叉编译只是为了能在一个平台上编译出其他平台可运行的程序&#xff0c;Go 作为一个跨平台的语言&#xff0c;它提供的类库势必也是跨平台的&#xff0c;比如说程序的系统调用相关的功能&#xff0c;能根据所处环境选择对应的源码进行编译。让编译器只对满足条件的代码进行编译…

【Linux】centos8安装cmake3.27.4

第一步&#xff0c;去官网下安装包&#xff0c;一定不要下错了 下好了之后&#xff0c;用ftp软件传到云服务器或者虚拟机上&#xff0c;我用的是centos8系统&#xff0c;安装之前先准备好这些依赖项 yum install -y gcc gcc-c make automake yum install -y openssl openssl-…

多线程应用——单例模式

单例模式 文章目录 单例模式一.什么是单例模式二.如何实现1.口头实现2.利用语法特性 三.实现方式&#xff08;饿汉式懒汉式&#xff09;1.饿汉式2.懒汉式3.线程安全的单例模式4.双重检查锁5.禁止指令重排序 一.什么是单例模式 单例模式&#xff08;Singleton Pattern&#xff…

LLM本地知识库问答系统(二):如何正确使用LlamaIndex索引

推荐阅读列表&#xff1a; LLM本地知识库问答系统&#xff08;一&#xff09;&#xff1a;使用LangChain和LlamaIndex从零构建PDF聊天机器人指南 上一篇文章我们介绍了使用LlamaIndex构建PDF聊天机器人&#xff0c;本文将介绍一下LlamaIndex的基本概念和原理。 LlamaIndex简介…

视频分割合并工具说明

使用说明书&#xff1a;视频分割合并工具 欢迎使用视频生成工具&#xff01;本工具旨在帮助您将视频文件按照指定的规则分割并合并&#xff0c;以生成您所需的视频。 本程序还自带提高分辨率1920:1080&#xff0c;以及增加10db声音的功能 软件下载地址 https://github.com/c…

FPGA原理与结构——时钟IP核原理学习

一、前言 在之前的文章中&#xff0c;我们介绍了FPGA的时钟结构 FPGA原理与结构——时钟资源https://blog.csdn.net/apple_53311083/article/details/132307564?spm1001.2014.3001.5502 在本文中我们将学习xilinx系列的FPGA所提供的时钟IP核&#xff0c;来帮助我们进一…

TCP/IP五层模型、封装和分用

1.网络通信基础2.协议分层OSI七层协议模型TCP/IP五层/四层协议模型【重点】 3. 封装&分用 1.网络通信基础 IP地址&#xff1a;表示计算机的位置&#xff0c;分源IP和目标IP&#xff1b;举个例子&#xff1a;买快递&#xff0c;商家从上海发货&#xff0c;上海就是源IP&…

理虚实一体化全栈全场景云计算应用实训室解决方案

一、 云计算应用统概述 云计算应用系统是指基于云计算技术构建的应用系统&#xff0c;它将软件、数据、计算和存储资源部署在云服务器上&#xff0c;通过网络根据应用按照一定策略为用户提供相关服务。云计算应用系统广泛应用于各个领域&#xff0c;包括但不限于金融、教育、政…

Windows 系统彻底卸载 SQL Server 通用方法

Windows 系统彻底卸载 SQL Server 通用方法 无论什么时候&#xff0c;SQL Server 的安装和卸载都是一件让我们头疼的事情。因为不管是 SQL Server 还是 MySQL 的数据库&#xff0c;当我们在使用数据库时因为未知原因出现问题&#xff0c;想要卸载重装时&#xff0c;如果数据库…

零基础如何使用IDEA启动前后端分离中的前端项目(Vue)?

一、在IDEA中配置vue插件 点击File-->Settings-->Plugins-->搜索vue.js插件进行安装&#xff0c;下面的图中我已经安装好了 二、搭建node.js环境 安装node.js 可以去官网下载&#xff1a;安装过程就很简单&#xff0c;直接下一步就行 测试是否安装成功&#xff1a;要…

[JDK8下的HashMap类应用及源码分析] 数据结构、哈希碰撞、链表变红黑树

系列文章目录 [Java基础] StringBuffer 和 StringBuilder 类应用及源码分析 [Java基础] 数组应用及源码分析 [Java基础] String&#xff0c;分析内存地址&#xff0c;源码 [JDK8环境下的HashMap类应用及源码分析] 第一篇 空构造函数初始化 [JDK8环境下的HashMap类应用及源码分…