LeNet对MNIST 数据集中的图像进行分类--keras实现

我们将训练一个卷积神经网络来对 MNIST 数据库中的图像进行分类,可以与前面所提到的CNN实现对比CNN对 MNIST 数据库中的图像进行分类-CSDN博客

加载 MNIST 数据库

MNIST 是机器学习领域最著名的数据集之一。

  • 它有 70,000 张手写数字图像 - 下载非常简单 - 图像尺寸为 28x28 - 灰度图像
from keras.datasets import mnist# 使用 Keras 导入预洗牌 MNIST 数据库
(X_train, y_train), (X_test, y_test) = mnist.load_data()print("The MNIST database has a training set of %d examples." % len(X_train))
print("The MNIST database has a test set of %d examples." % len(X_test))

将前六个训练图像可视化 

import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.cm as cm
import numpy as np# 绘制前六幅训练图像
fig = plt.figure(figsize=(20,20))
for i in range(6):ax = fig.add_subplot(1, 6, i+1, xticks=[], yticks=[])ax.imshow(X_train[i], cmap='gray')ax.set_title(str(y_train[i]))

 查看图像的更多细节

def visualize_input(img, ax):ax.imshow(img, cmap='gray')width, height = img.shapethresh = img.max()/2.5for x in range(width):for y in range(height):ax.annotate(str(round(img[x][y],2)), xy=(y,x),horizontalalignment='center',verticalalignment='center',color='white' if img[x][y]<thresh else 'black')fig = plt.figure(figsize = (12,12)) 
ax = fig.add_subplot(111)
visualize_input(X_train[0], ax)

预处理输入图像:通过将每幅图像中的每个像素除以 255 来调整图像比例

# normalize the data to accelerate learning
mean = np.mean(X_train)
std = np.std(X_train)
X_train = (X_train-mean)/(std+1e-7)
X_test = (X_test-mean)/(std+1e-7)print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

对标签进行预处理:使用单热方案对分类整数标签进行编码

from keras.utils import np_utilsnum_classes = 10 
# print first ten (integer-valued) training labels
print('Integer-valued labels:')
print(y_train[:10])# one-hot encode the labels
# convert class vectors to binary class matrices
y_train = np_utils.to_categorical(y_train, num_classes)
y_test = np_utils.to_categorical(y_test, num_classes)# print first ten (one-hot) training labels
print('One-hot labels:')
print(y_train[:10])

重塑数据以适应我们的 CNN(和 input_shape)

# input image dimensions 28x28 pixel images. 
img_rows, img_cols = 28, 28X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)print('image input shape: ', input_shape)
print('x_train shape:', X_train.shape)

定义模型架构

论文地址:lecun-01a.pdf

要在 Keras 中实现 LeNet-5,请阅读原始论文并从第 6、7 和 8 页中提取架构信息。以下是构建 LeNet-5 网络的主要启示:

  • 每个卷积层的滤波器数量:从图中(以及论文中的定义)可以看出,每个卷积层的深度(滤波器数量)如下:C1 = 6、C3 = 16、C5 = 120 层。
  • 每个 CONV 层的内核大小:根据论文,内核大小 = 5 x 5
  • 每个卷积层之后都会添加一个子采样层(POOL)。每个单元的感受野是一个 2 x 2 的区域(即 pool_size = 2)。请注意,LeNet-5 创建者使用的是平均池化,它计算的是输入的平均值,而不是我们在早期项目中使用的最大池化层,后者传递的是输入的最大值。如果您有兴趣了解两者的区别,可以同时尝试。在本实验中,我们将采用论文架构。
  • 激活函数:LeNet-5 的创建者为隐藏层使用了 tanh 激活函数,因为对称函数被认为比 sigmoid 函数收敛更快。一般来说,我们强烈建议您为网络中的每个卷积层添加 ReLU 激活函数。

需要记住的事项

  • 始终为 CNN 中的 Conv2D 层添加 ReLU 激活函数。除了网络中的最后一层,密集层也应具有 ReLU 激活函数。
  • 在构建分类网络时,网络的最后一层应该是具有软最大激活函数的密集(FC)层。最终层的节点数应等于数据集中的类别总数。
from keras.models import Sequential
from keras.layers import Conv2D, AveragePooling2D, Flatten, Dense
#Instantiate an empty model
model = Sequential()# C1 Convolutional Layer
model.add(Conv2D(6, kernel_size=(5, 5), strides=(1, 1), activation='tanh', input_shape=input_shape, padding='same'))# S2 Pooling Layer
model.add(AveragePooling2D(pool_size=(2, 2), strides=2, padding='valid'))# C3 Convolutional Layer
model.add(Conv2D(16, kernel_size=(5, 5), strides=(1, 1), activation='tanh', padding='valid'))# S4 Pooling Layer
model.add(AveragePooling2D(pool_size=(2, 2), strides=2, padding='valid'))# C5 Fully Connected Convolutional Layer
model.add(Conv2D(120, kernel_size=(5, 5), strides=(1, 1), activation='tanh', padding='valid'))#Flatten the CNN output so that we can connect it with fully connected layers
model.add(Flatten())# FC6 Fully Connected Layer
model.add(Dense(84, activation='tanh'))# Output Layer with softmax activation
model.add(Dense(10, activation='softmax'))# print the model summary
model.summary()

编译模型

我们将使用亚当优化器

# the loss function is categorical cross entropy since we have multiple classes (10) # compile the model by defining the loss function, optimizer, and performance metric
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

训练模型

LeCun 和他的团队采用了计划衰减学习法,学习率的值按照以下时间表递减:前两个历元为 0.0005,接下来的三个历元为 0.0002,接下来的四个历元为 0.00005,之后为 0.00001。在论文中,作者对其网络进行了 20 个历元的训练。

from keras.callbacks import ModelCheckpoint, LearningRateScheduler# set the learning rate schedule as created in the original paper
def lr_schedule(epoch):if epoch <= 2:     lr = 5e-4elif epoch > 2 and epoch <= 5:lr = 2e-4elif epoch > 5 and epoch <= 9:lr = 5e-5else: lr = 1e-5return lrlr_scheduler = LearningRateScheduler(lr_schedule)# set the checkpointer
checkpointer = ModelCheckpoint(filepath='model.weights.best.hdf5', verbose=1, save_best_only=True)# train the model
hist = model.fit(X_train, y_train, batch_size=32, epochs=20,validation_data=(X_test, y_test), callbacks=[checkpointer, lr_scheduler], verbose=2, shuffle=True)

在验证集上加载分类准确率最高的模型

# load the weights that yielded the best validation accuracy
model.load_weights('model.weights.best.hdf5')

计算测试集的分类准确率

# evaluate test accuracy
score = model.evaluate(X_test, y_test, verbose=0)
accuracy = 100*score[1]# print test accuracy
print('Test accuracy: %.4f%%' % accuracy)

评估模型

import matplotlib.pyplot as pltf, ax = plt.subplots()
ax.plot([None] + hist.history['accuracy'], 'o-')
ax.plot([None] + hist.history['val_accuracy'], 'x-')
# 绘制图例并自动使用最佳位置: loc = 0。
ax.legend(['Train acc', 'Validation acc'], loc = 0)
ax.set_title('Training/Validation acc per Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('acc')
plt.show()

import matplotlib.pyplot as pltf, ax = plt.subplots()
ax.plot([None] + hist.history['loss'], 'o-')
ax.plot([None] + hist.history['val_loss'], 'x-')# Plot legend and use the best location automatically: loc = 0.
ax.legend(['Train loss', "Val loss"], loc = 0)
ax.set_title('Training/Validation Loss per Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/210433.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

QT 中 QDateTime::currentDateTime() 输出格式备查

基础 QDateTime::currentDateTime() //当前的日期和时间。 QDateTime::toString() //以特定的格式输出时间&#xff0c;格式 yyyy: 年份&#xff08;4位数&#xff09; MM: 月份&#xff08;两位数&#xff0c;07表示七月&#xff09; dd: 日期&#xff08;两位数&#xff0c…

【unity3D】unity中如何查找和获取游戏物体

&#x1f497; 未来的游戏开发程序媛&#xff0c;现在的努力学习菜鸡 &#x1f4a6;本专栏是我关于游戏开发的学习笔记 &#x1f236;本篇是unity中游戏物体的查找与获取 这里写自定义目录标题 获取当前物体的基本属性查找其它物体- 通过名称查找其它物体- 通过标签查找- 通过类…

计算机网络入侵检测技术研究

摘 要 随着网络技术的发展&#xff0c;全球信息化的步伐越来越快&#xff0c;网络信息系统己成为一个单位、一个部门、一个行业&#xff0c;甚至成为一个关乎国家国计民生的基础设施&#xff0c;团此&#xff0c;网络安全就成为国防安全的重要组成部分&#xff0c;入侵检测技术…

C++系统思维导图

自己在复盘C的时候做了 一些笔记&#xff0c;用思维导图形式记录下来的一些概念&#xff0c;多线程的内容较少&#xff0c;主要是派生和继承&#xff0c;以及虚函数和多态内容多一些&#xff0c;其他也有一些零碎的小知识点&#xff0c;和大家分享一下。有任何问题请留言。原图…

生鲜蔬果展示预约小程序作用是什么

线下生鲜蔬果店非常多&#xff0c;对商家来说主要以同城生意为主&#xff0c;而在互联网电商的发展下&#xff0c;更多的商家会选择搭建私域商城进行多渠道的销售卖货和拓展&#xff0c;当然除了直接卖货外&#xff0c;还有产品纯展示或预约订购等需求。 但无论哪种模式&#…

ubuntu离线安装包下载和安装

一、确认本机ubuntu的发行版本 方法1: rootac810:/home/ac810/alex# lsb_release -a No LSB modules are available. Distributor ID: Ubuntu Description: Ubuntu 20.04.6 LTS Release: 20.04 Codename: focal 方法2: rootac810:/home/ac810/alex# cat /…

uniapp小程序分包页面引入wxcomponents(vue.config.js、copy-webpack-plugin)

实例&#xff1a;小程序添加一个源生小程序插件&#xff0c;按照uniapp官方的说明&#xff0c;要放在wxcomponents。后来发现小程序超2m上传不了。 正常的编译情况 会被编译到主包下 思路&#xff1a;把wxcomponents给编译到分包sub_package下 用uniapp的vue.config.js自定义…

电脑如何录音?适合初学者的详细教程

“电脑怎么录音呀&#xff1f;参加了一个学校举办的短视频大赛&#xff0c;视频拍摄都很顺利&#xff0c;音乐却出了问题&#xff0c;朋友说可以用电脑录制一段音乐应付一下&#xff0c;可是我不会操作&#xff0c;有哪位大佬教教我&#xff01;” 声音是一种强大的媒介&#…

智能电表需要安装电池吗?

智能电表是一种新型的电力测量设备&#xff0c;它使用先进的技术和功能来监控、记录和报告能源消耗情况。对于智能电表是否需要安装电池这个问题&#xff0c;答案是有可能需要&#xff0c;但并不是所有智能电表都需要安装电池。 首先&#xff0c;我们需要了解智能电表是如何工作…

从0开始使用Maven

文章目录 一.Maven的介绍即相关概念1.为什么使用Maven/Maven的作用2.Maven的坐标 二.Maven的安装三.IDEA编译器配置Maven环境1.在IDEA的单个工程中配置Maven环境2.方式2&#xff1a;配置Maven全局参数 四.IDEA编译器创建Maven项目五.IDEA中的Maven项目结构六.IDEA编译器导入Mav…

C语言能判断一个变量是int还是float吗?

C语言能判断一个变量是int还是float吗&#xff1f; 在开始前我有一些资料&#xff0c;是我根据自己从业十年经验&#xff0c;熬夜搞了几个通宵&#xff0c;精心整理了一份「C语言从专业入门到高级教程工具包」&#xff0c;点个关注&#xff0c;全部无偿共享给大家&#xff01;&…

数据库表的管理

表的基本概念 表是包含数据库中所有数据的数据库对象。数据在表中的组织方式与在电子表格中相似&#xff0c;都是 按行和列的格式组织的。每行代表一条唯一的记录&#xff0c;每列代表记录中的一个字段。例如&#xff0c;在包含公 司员工信息的表中&#xff0c;每行代表一名员工…

探索什么样的导线,适合做433的天线

​​​​​​一、理论基础 (3 封私信 / 18 条消息) 为什么天线的材质会影响接收电磁波的效果&#xff1f; - 知乎 (zhihu.com) 电感基础3——RLC电路&#xff0c;帮助你轻松理解“阻抗”的概念 - 知乎 (zhihu.com) 一文读懂介电性能---介电常数 - 知乎 (zhihu.com) 433MHz 至…

SpringBoot 注入RedisTemplat 启动报错

需求 因为需要限制部门内多个人员同一时间操作同一批客户的需求&#xff0c;考虑下决定用Redis滑动窗口实现自过期以及并发校验。 问题 新建了个Redis工具类封装RedisTemplat 操作&#xff0c;到启动时却发现无法正常启动&#xff0c;报错注入错误。 The injection point has…

利用 EC2 和 S3 免费搭建私人网盘

网盘是一种在线存储服务&#xff0c;提供文件存储&#xff0c;访问&#xff0c;备份&#xff0c;贡献等功能&#xff0c;是我们日常中不可或缺的一种服务。 &#x1f4bb;创建实例 控制台搜索EC2 点击启动EC2 选择AMI 选择可免费试用的 g代表采用了Graviton2芯片。 配置存储 配…

【带头学C++】----- 九、类和对象 ---- 9.5 初始化列表

目录 9.5 初始化列表 9.5.1 对象成员 代码&#xff1a; 9.5.2 初始化列表 9.5 初始化列表 9.5.1 对象成员 在类中定义的数据成员一般都是基本的数据类型。但是类中的成员也可以是对象&#xff0c;叫做对象成员。 先调用对象成员的构造函数&#xff0c;再调用本身的构造函数…

NIO--07--Java lO模型详解

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 何为 IO?先从计算机结构的角度来解读一下I/o.再从应用程序的角度来解读一下I/O 阻塞/非阻塞/同步/异步IO阻塞IO非阻塞IO异步IO举例 Java中3种常见的IO模型BIO (Blo…

C++日常遇到的一些坑的总结

一、const 相关 C中const的不同位置的用法 const 修饰符用法总结 二、函数形参没有变量名 三、指针偏移问题 笔记&#xff1a; 包含来自C标准库的头文件&#xff0c;用#inlcude<xxx>&#xff0c;包含不来自C标准库的头文件&#xff0c;用#include"xxx"最…

Python setup.py 编写

from setuptools import setup, find_packages import atexit import shutilname "sdk"def rm_temp():"""删除打包过程中产生的中间文件:return:"""shutil.rmtree(build)shutil.rmtree({}.egg-info.format(name))# 注册 rm_temp atex…

Hadoop学习笔记(HDP)-Part.18 安装Flink

目录 Part.01 关于HDP Part.02 核心组件原理 Part.03 资源规划 Part.04 基础环境配置 Part.05 Yum源配置 Part.06 安装OracleJDK Part.07 安装MySQL Part.08 部署Ambari集群 Part.09 安装OpenLDAP Part.10 创建集群 Part.11 安装Kerberos Part.12 安装HDFS Part.13 安装Ranger …