在 Python 中构建卷积神经网络; 从 0 到 9 的手绘数字的灰度图像预测数字

一、说明

        为了预测从0到9的数字,我选择了一个基于著名的Kaggle的MNIST数据集的数据集。数据集包含从 <0> 到 <9> 的手绘图数字的灰度图像。在本文中,我将根据像素数据(即数值数据)和卷积神经网络预测数字。

二、 卷积神经网络

        卷积神经网络,也称为 CNN 或 ConvNet,是一种人工神经网络,迄今为止最常用于分析计算机视觉任务的图像。

        尽管图像分析是CNNS最广泛的用途,但它们也可用于其他数据分析或分类。让我们开始吧!

        一般来说,我们可以将CNN视为一种人工神经网络,它具有某种类型的专业化,能够挑选或检测模式。这种模式检测使CNN在图像分析中如此有用。

        但是,如果CNN只是一个人工神经网络,那么它与标准的多层感知器或MLP有什么区别呢?

        CNN有称为卷积层的隐藏层,这些层是构成CNN的,嗯......一个美国有线电视新闻网!

CNN具有称为卷积层的层。

CNN可以,而且通常也有其他非卷积层,但CNN的基础是卷积层。

好的,那么这些卷积层是做什么的呢?

就像任何其他层一样,卷积层接收输入,以某种方式转换输入,然后将转换后的输入输出到下一层。卷积层的输入称为输入通道,输出称为输出通道。

对于卷积层,发生的转换称为卷积操作。无论如何,这是深度学习社区使用的术语。在数学上,卷积层执行的卷积运算实际上称为互相关。

如前所述,卷积神经网络能够检测图像中的模式。

让我们扩展一下我们的意思 当我们说过滤器能够检测模式时。想想任何单个图像中可能发生了什么。多个边缘、形状、纹理、对象等。这就是我们所说的模式

  • 边缘
  • 形状
  • 纹理
  • 曲线
  • 对象
  • 颜色

滤波器可以在图像中检测到的一种图案是边缘,因此该滤波器称为边缘检测器

除了边缘之外,某些过滤器可能会检测到角落。有些人可能会检测到圆圈。其他,正方形。现在这些简单的几何滤波器 就是我们在卷积神经网络开始时看到的。

网络越深入,过滤器就越复杂。在后面的图层中,我们的过滤器可能能够检测特定的物体,而不是边缘和简单的形状,如眼睛、耳朵、头发或毛皮、羽毛、鳞片和喙。

在更深的层中,过滤器能够检测到更复杂的物体,如完整的狗、猫、蜥蜴和鸟类。

三、 数据理解

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random
import itertools
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Dropout, Flatten, MaxPooling2D
import osdf = pd.read_csv("MNIST_ROI.csv")

3.1 探索性分析

df.shape

(59999, 785)

数据集包括 59,999 条记录和 785 个字段。每条记录代表手绘图数字的灰度图像,介于 0 到 9 之间。

第一列称为“结果”,是用户绘制的数字。

其余列包含关联图像的像素值。每个灰度图像的高度为 28 像素,宽度为 28 像素,总共 784 像素。

df.head()

每个像素都有一个与之关联的像素值,指示该像素的明暗度,数字越大意味着越亮。此像素值是介于 0(黑色)和 255(白色)之间的整数(包括 <>(黑色)和 <>(白色)。

df.tail()

df.info()

df.describe()

3.2 数据分析

让我们检查数据集中每个数字有多少张图像

dig = [0,1,2,3,4,5,6,7,8,9]
num = []
for i in range(0,10):num.append(len(df[df['Result']==i]))d = {'Digit': dig, 'Count': num}
df1 = pd.DataFrame(data=d)
df1

import matplotlib.pyplot as plt
import seaborn as sns
sns.barplot(x = “Count”, y = “Digit”, data = df2, orient=’h’)
plt.show()

让我们看看数据集中的哪些行中有数字“3”的图像

df[df[‘Result’]==3].head()

让我们打印第 6 行的图像

pic = df[6:7].values.reshape(785)[1:].reshape(28,28)
plt.imshow(pic,cmap='gray')

让我们看看数据集中的哪些行中有数字“5”的图像

df[df[‘Result’]==5].head()

让我们打印第 10 行的图像

pic = df[10:11].values.reshape(785)[1:].reshape(28,28)
plt.imshow(pic,cmap=’gray’)

四、数据准备

 X = df.drop(['Result'],axis=1)X.head() 

 y = df.Resulty.head() 

import sklearn.model_selection as skmodelX_train, X_test, y_train, y_test = skmodel.train_test_split(X, y, test_size=0.33, random_state=42)print("length of all data is ","{:,}".format(len(X)))
print("length of training set is","{:,}".format(len(X_train)))
print("length of test set is","{:,}".format(len(X_test))) 

X_train.head()

y_train.head()

让我们将训练集和测试集从 pandas.core.frame.DataFrame 转换为 numpy.ndarray

x_train = np.array(X_train)
y_train = np.array(y_train)
x_test = np.array(X_test)
y_test = np.array(y_test)len(X_train)

40199

让我们画一个介于 0 到 40199 之间的数字

i = random.randint(0,(len(X_train)))
i

34944

现在,让我们打印训练集中第 34944 行的图像结果

print(y_train[i])

3

让我们打印训练集中第 34944 行的图像

pic = X_train.iloc[i].values.reshape(28,28)plt.imshow(pic, cmap=’Greys’)

x_train.shape

(40199, 784)

让我们将数组重塑为 4 个 dimnsions,以便它可以与 Keras API 一起使用

x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
input_shape = (28, 28, 1)

让我们确保值是浮点数的,以便我们可以在除法后获得小数点

x_train = x_train.astype('float32')
x_test = x_test.astype('float32') 

现在,让我们通过将 RGB 代码除以最大 RGB 值来规范化 RGB 代码

x_train /= 255
x_test /= 255print('x_train shape:', x_train.shape)
print('Number of images in x_train', x_train.shape[0])
print('Number of images in x_test', x_test.shape[0])

五、建模

让我们使用顺序模型构建一个 CNN 并添加层:

model = Sequential()
model.add(Conv2D(28, kernel_size=(3,3), input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten()) # Flattening the 2D arrays for fully connected layers
model.add(Dense(128, activation=tf.nn.relu))
model.add(Dropout(0.2))
model.add(Dense(10,activation=tf.nn.softmax))

让我们编译我们的CNN

model.compile(optimizer=’adam’, loss=’sparse_categorical_crossentropy’,  metrics=[‘accuracy’])

现在,让我们训练我们的CNN

model.fit(x=x_train,y=y_train, epochs=10)

训练集准确率:99.37%

model.evaluate(x_test, y_test)

测试装置准确率:98.22%

训练集的准确率为 99.37%,而测试集的准确率为 98.22%。这表明卷积神经网络(CNN)很好地推广到新数据,而不是过度拟合。

六、评估

len(X_test)

19800

让我们画一个介于 0 到 19800 之间的数字

j = random.randint(0,(len(X_test)))
j

11092

现在,让我们对测试集中第 11092 行的图像结果进行预测

pred = model.predict(x_test[j].reshape(1, 28, 28, 1))print(pred.argmax())

6

让我们打印测试集中第 11092 行的图像

pic1 = X_test.iloc[j].values.reshape(28,28)
plt.imshow(pic1, cmap='Greys')

y_pred = model.predict(x_test)
y_pred = np.argmax(y_pred,axis=1)
y_pred.shape

(19800, )

6.1 混淆矩阵

import sklearn.metrics as skmetcm = skmet.confusion_matrix(y_true=y_test, y_pred=y_pred)def plot_confusion_matrix(cm, classes,normalize=False,title=’Confusion matrix’,cmap=plt.cm.Blues):“””This function prints and plots the confusion matrix.Normalization can be applied by setting `normalize=True`.“””plt.imshow(cm, interpolation=’nearest’, cmap=cmap)plt.title(title)plt.colorbar()tick_marks = np.arange(len(classes))plt.xticks(tick_marks, classes, rotation=45)plt.yticks(tick_marks, classes)print(‘Confusion matrix, without normalization’)
print(cm)
thresh = cm.max() / 2.for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):plt.text(j, i, cm[i, j],horizontalalignment=”center”,color=”white” if cm[i, j] > thresh else “black”)
plt.tight_layout()
plt.ylabel(‘True label’)
plt.xlabel(‘Predicted label’)cm_plot_labels = [‘0’,’1',’2',’3',’4',’5',’6',’7',’8',’9']plot_confusion_matrix(cm=cm, classes=cm_plot_labels, title=’Confusion Matrix’)

print(“\033[1m The result is telling us that we have: “,(cm[0,0]+cm[1,1]+cm[2,2]+cm[3,3]+cm[4,4]+cm[5,5]+cm[6,6]+cm[7,7]+cm[8,8]+cm[9,9]),”correct predictions.”)
print(“\033[1m The result is telling us that we have: “,(cm.sum()-(cm[0,0]+cm[1,1]+cm[2,2]+cm[3,3]+cm[4,4]+cm[5,5]+cm[6,6]+cm[7,7]+cm[8,8]+cm[9,9])),”incorrect predictions.”)
print(“\033[1m We have total predictions of: “,(cm.sum()))

6.2 计算精度、召回率、f 分数和支持

        引用Scikit Learn的话:

        精度是比率 tp / (tp + fp),其中 tp 是真阳性数,fp 是误报数。精度直观地是分类器在样本为阴性时不将其标记为阳性的能力。

        召回率是比率 tp / (tp + fn),其中 tp 是真阳性的数量,fn 是假阴性的数量。召回率直观地是分类器找到所有阳性样本的能力。

        f1 分数可以解释为精度和召回率的加权调和平均值,其中 f1 分数在 1 达到其最佳值,在 0 时达到最差分数。

        f1 分数将召回率的权重比精度高 1.0 倍,这意味着召回率和精度同样重要。

        支持是每个类在y_test中的出现次数。

print(skmet.classification_report(y_test, y_pred))

七、部署

        因此,我们的卷积神经网络(CNN)模型是一个很好的模型,可以从0到9的手绘数字的灰度图像中预测数字。现在,我们如何从新的灰度图像中预测数字?

len(X_test)

19800

        让我们画一个介于 0 到 19800 之间的数字

k = random.randint(0,(len(X_test)))
k

766

        让我们使用我们的模型预测来自 pred1 的数字

pred1 = model.predict(x_train[k].reshape(1, 28, 28, 1))
print(pred1.argmax())

7

        我们的模型说我们画了一个数字“7”的图像。因此,让我们打印此图像以查看我们的模型是否正确

pic2 = X_train.iloc[k].values.reshape(28,28)
plt.imshow(pic2, cmap='Greys')

是的!我们的模型是正确的。

八、总结

        卷积神经网络(ConvNet/CNN)是一种深度学习算法,可以接收输入图像,为图像中的各个方面/对象分配重要性(可学习的权重和偏差),并能够区分彼此。

        卷积神经网络的架构类似于人脑中神经元的连接模式,并受到视觉皮层组织的启发。

        单个神经元仅在称为感受野的视野的受限区域中对刺激做出反应。

        此类字段的集合重叠以覆盖整个视觉区域。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/118748.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

智能合约安全分析,针对 ERC777 任意调用合约 Hook 攻击

智能合约安全分析&#xff0c;针对 ERC777 任意调用合约 Hook 攻击 Safful发现了一个有趣的错误&#xff0c;有可能成为一些 DeFi 项目的攻击媒介。这个错误尤其与著名的 ERC777 代币标准有关。此外&#xff0c;它不仅仅是众所周知的黑客中常见的简单的重入问题。 这篇文章对 …

2、Nginx 安装

文章目录 2、Nginx 安装2.1 官网下载2.2 安装 nginx2.2.1 第一步2.2.2 第二步2.2.3 第三步&#xff0c;安装 nginx2.2.4 第四步&#xff0c;修改防火漆规则 【尚硅谷】尚硅谷Nginx教程由浅入深 志不强者智不达&#xff1b;言不信者行不果。 2、Nginx 安装 2.1 官网下载 nginx…

软件测试面试怎样介绍自己的测试项目?会问到什么程度?

想知道面试时该怎样介绍测试项目&#xff1f;会问到什么程度&#xff1f;那就需要换位思考&#xff0c;思考HR在这个环节想知道什么。 HR在该环节普遍想获得的情报主要是下面这2个方面&#xff1a; 1&#xff09;应聘者的具体经验和技术能力&#xff0c; 2&#xff09;应聘者的…

Python实战之数据表提取和下载自动化

在网络爬虫领域&#xff0c;动态渲染类型页面的数据提取和下载自动化是一个常见的挑战。本文将介绍如何利用Pyppeteer库完成这一任务&#xff0c;帮助您轻松地提取动态渲染页面中的数据表并实现下载自动化。 一、环境准备 首先&#xff0c;确保您已经安装了Python环境。接下来…

Android.mk开发模板

今天简单写了一个 Android.mk 的示例模板&#xff0c;供初学者参考。 本模板主要给大家示例 Android NDK 开发中的如下几个问题&#xff1a; 如何自动添加需要编译的源文件列表如何添加第三方静态库、动态库的依赖如何构造一个完整的NDK工程框架 假设我们的项目依赖 libmath.…

【Go 基础篇】Go语言结构体基本使用

在Go语言中&#xff0c;结构体是一种重要的数据类型&#xff0c;用于定义和组织一组不同类型的数据字段。结构体允许开发者创建自定义的复合数据类型&#xff0c;类似于其他编程语言中的类。本文将深入探讨Go语言中结构体的定义、初始化、嵌套、方法以及与其他语言的对比&#…

小赢科技,寻找金融科技核心价

如果说金融是经济的晴雨表&#xff0c;是通过改善供给质量以提高经济质量的切入口&#xff0c;那么金融科技公司&#xff0c;就是这一切行动的推手。上半年&#xff0c;社会经济活跃程度提高背后&#xff0c;金融科技公司既是奉献者&#xff0c;也是受益者。 8月29日&#xff0…

postgresql并行查询(高级特性)

######################## 并行查询 postgresql和Oracle一样支持并行查询的,比如select、update、delete大事无开启并行功能后,能够利用多核cpu,从而充分发挥硬件性能,提升大事物的处理效率。 pg在9.6的版本之前是不支持的并行查询的,从9.6开始支持并行查询,但是功能非常…

go学习part21(3)redis连接池

连接池 1.介绍 每次使用数据就就建立链接再关闭可以&#xff0c;但是如果有大量客户端频繁请求连接&#xff0c;大量创建连接和关闭会非常耗费资源。 所以就建立一个连接池&#xff0c;里面存放几个不关闭的连接&#xff0c;谁要用就分配给谁。 说明:通过Golang 对 Redis操…

WebGPT VS WebGPU

推荐&#xff1a;使用 NSDT编辑器 快速搭建3D应用场景 随着WebGPU的引入&#xff0c;Web开发发生了有趣的转变&#xff0c;WebGPU是一种新的API&#xff0c;允许Web应用程序直接访问设备的图形处理单元&#xff08;GPU&#xff09;。这种发展意义重大&#xff0c;因为 GPU 擅长…

sublime编辑latex 出现参考文献无法编译报错:citation “...” undefined

问题描述 使用sublime编译latex文件时&#xff0c;参考文献按照常规的方式放好&#xff0c;ctrl B 编译的时候&#xff0c;显示找不到参考文献&#xff0c;编译出的pdf文件也没有references&#xff1a; 但是把文件放到overleaf上就可以直接编译出来&#xff0c;说明是本地编…

快速为RPG辅助工具MTool增加更多快捷键(一键保存等)

起源&#xff1a;MTool是个好工具&#xff0c;本身固然好用&#xff0c;但是它本身的快捷键功能很少&#xff0c;虽然内置了一个录制工具&#xff0c;但是一个个的录&#xff0c;又麻烦&#xff0c;一般人也难以掌握 本文用快速方法增加更多快捷键&#xff0c;可以做到一键保存…

c++ qt--线程(一)(第八部分)

c qt–线程&#xff08;一&#xff09;&#xff08;第八部分&#xff09; 一.进程&#xff08;Process&#xff09; 在任务管理器中的进程页下&#xff0c;可以看到进程&#xff0c;任务管理器将进程分为了三类&#xff0c;应用、后台进程、window进程 应用&#xff1a; 打开…

【UE 材质】实现方形渐变、中心渐变材质

步骤 一、实现方形渐变 1. 新建一个材质&#xff0c;材质域选择“后期处理” 2. 通过“Mask”节点单独获取R、G通道&#xff0c;可以看到R通道是从左到右0~1之间的变化&#xff0c;对应U平铺 可以看到G通道是从上到下0~1之间的变化&#xff0c;对应V平铺 3. 完善如下节点 二、…

Leetcode1090. 受标签影响的最大值

思路&#xff1a;根据值从大到小排序&#xff0c;然后在加的时候判断是否达到标签上限即可&#xff0c;一开始想用字典做&#xff0c;但是题目说是集合却连续出现两个8&#xff0c;因此使用元组SortedList进行解决 class Solution:def largestValsFromLabels(self, values: li…

Java后端开发面试题——多线程

创建线程的方式有哪些&#xff1f; 继承Thread类 public class MyThread extends Thread {Overridepublic void run() {System.out.println("MyThread...run...");}public static void main(String[] args) {// 创建MyThread对象MyThread t1 new MyThread() ;MyTh…

aarch64-linux交叉编译libcurl带zlib和openssl

交叉编译libcurl需要依赖zlib和openssl 需要先用aarch64工具链编译zlib和openssl aarch64-linux环境搭建 下载工具链 gcc用于执行交叉编译 gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnusysroot是交叉版本的库文件集合 sysroot-glibc-linaro-2.25-2019.12-aarch64-lin…

iOS逆向进阶:iOS进程间通信方案深入探究与local socket介绍

在移动应用开发中&#xff0c;进程间通信&#xff08;Inter-Process Communication&#xff0c;IPC&#xff09;是一项至关重要的技术&#xff0c;用于不同应用之间的协作和数据共享。在iOS生态系统中&#xff0c;进程和线程是基本的概念&#xff0c;而进程间通信方案则为应用的…

在 Spring Boot 中集成 MinIO 对象存储

MinIO 是一个开源的对象存储服务器&#xff0c;专注于高性能、分布式和兼容S3 API的存储解决方案。本文将介绍如何在 Spring Boot 应用程序中集成 MinIO&#xff0c;以便您可以轻松地将对象存储集成到您的应用中。 安装minio 拉取 minio Docker镜像 docker pull minio/minio创…

Linux:内核解压缩过程简析

文章目录 1. 前言2. 背景3. zImage 的构建过程4. 内核引导过程5. 内核解压缩过程6. 内核加压缩过程小结7. 参考资料 1. 前言 限于作者能力水平&#xff0c;本文可能存在谬误&#xff0c;因此而给读者带来的损失&#xff0c;作者不做任何承诺。 2. 背景 本文基于 ARM32架构 …