01、Tensorflow实现二元手写数字识别

01、Tensorflow实现二元手写数字识别(二分类问题)

01、Tensorflow实现二元手写数字识别(二分类问题)
02、Tensorflow实现手写数字识别(数字0-9)


开始学习机器学习啦,已经把吴恩达的课全部刷完了,现在开始熟悉一下复现代码。对这个手写数字实部比较感兴趣,作为入门的素材非常合适。

基于Tensorflow 2.10.0

1、识别目标

识别手写仅仅是为了区分手写的0和1,所以实际上是一个二分类问题。

2、Tensorflow算法实现

STEP1:导入相关包

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import warnings
import logging
from sklearn.metrics import accuracy_score

import numpy as np:这是引入numpy库,并为其设置一个缩写np。Numpy是Python中用于大规模数值计算的库,它提供了多维数组对象及一系列操作这些数组的函数。

import tensorflow as tf:这是引入tensorflow库,并为其设置一个缩写tf。TensorFlow是一个开源的深度学习框架,它被广泛用于各种深度学习应用。

from keras.models import Sequential:这是从Keras库中引入Sequential模型。Keras是一个高级神经网络API,它可以运行在TensorFlow之上。Sequential模型是Keras中的线性堆栈模型,允许你简单地堆叠多个网络层。

from keras.layers import Dense:这是从Keras库中引入Dense层。Dense层是神经网络中的全连接层,每个输入节点与输出节点都是连接的。

from sklearn.model_selection import train_test_split:这是从scikit-learn库中引入train_test_split函数。这个函数用于将数据分割为训练集和测试集。

import matplotlib.pyplot as plt:这是引入matplotlib的pyplot模块,并为其设置一个缩写plt。Matplotlib是Python中的绘图库,而pyplot是其中的一个模块,用于绘制各种图形和图像。

import warnings:这是引入Python的标准警告库,它可以用来发出警告,或者过滤掉不需要的警告。

import logging:这是引入Python的标准日志库,用于记录日志信息,方便追踪和调试代码。

from sklearn.metrics import accuracy_score:这是从scikit-learn库中引入accuracy_score函数。这个函数用于计算分类准确率,常用于评估分类模型的性能。


STEP2:屏蔽无用警告并允许中文

logging.getLogger("tensorflow").setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)
warnings.simplefilter(action='ignore', category=FutureWarning)
# 支持中文显示
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False

logging.getLogger(“tensorflow”).setLevel(logging.ERROR):这行代码用于设置 TensorFlow 的日志级别为 ERROR。这意味着只有当 TensorFlow 中发生错误时,才会在日志中输出相关信息。较低级别的日志信息(如 WARNING、INFO、DEBUG)将被忽略。

tf.autograph.set_verbosity(0):这行代码用于设置 TensorFlow 的自动图形(Autograph)日志的冗长级别为 0。这意味着在将 Python 代码转换为 TensorFlow 图形代码时,将不会输出任何日志信息。这有助于减少日志噪音,使日志更加干净。

warnings.simplefilter(action=‘ignore’,category=FutureWarning):这行代码用于忽略所有 FutureWarning 类型的警告。在 Python中,当使用某些即将过时或未来版本中可能发生变化的特性时,通常会发出 FutureWarning。通过设置action=‘ignore’,代码将不会输出这类警告,使控制台输出更加干净。

plt.rcParams[‘font.sans-serif’]=[‘SimHei’]:这行代码用于设置 matplotlib 中的默认无衬线字体为 SimHei。SimHei 是一种常用于显示中文的字体,这样设置后,matplotlib 将在绘图时使用 SimHei 字体来显示中文,从而避免中文乱码问题。

plt.rcParams[‘axes.unicode_minus’]=False:这行代码用于解决 matplotlib
中负号显示异常的问题。默认情况下,matplotlib 可能无法正确显示负号,将其设置为 False 可以使用 ASCII字符作为负号,从而正常显示。


STEP3:导入并划分数据集

划分10%作为测试:

X, y = load_data()
print('The shape of X is: ' + str(X.shape))
print('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

STEP4:模型构建与训练

# 构建模型,三层模型进行分类,第一层输入100个神经元...
model = Sequential([tf.keras.Input(shape=(400,)),    #specify input size### START CODE HERE ###Dense(100, activation='sigmoid'),Dense(10, activation='sigmoid'),Dense(1, activation='sigmoid')### END CODE HERE ###], name = "my_model"
)
# 打印三层模型的参数
model.summary()
# 模型设定,学习率0.001,因为是分类,使用BinaryCrossentropy损失函数
model.compile(loss=tf.keras.losses.BinaryCrossentropy(),optimizer=tf.keras.optimizers.Adam(0.001),
)
# 开始训练,训练循环20
model.fit(X_train,y_train,epochs=20
)

STEP5:结果可视化与打印准确度信息
原始的输入的数据集是400 * 1000的数组,共包含1000个手写数字的数据,其中400为20*20像素的图片,因此对每个400的数组进行reshape((20, 20))可以得到原始的图片进而绘图。

# 绘制测试集的预测结果,绘制64个
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
fig.tight_layout(pad=0.1, rect=[0, 0.03, 1, 0.92])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):# Select random indicesrandom_index = np.random.randint(X_test.shape[0])# Select rows corresponding to the random indices and# reshape the imageX_random_reshaped = X_test[random_index].reshape((20, 20)).T# Display the imageax.imshow(X_random_reshaped, cmap='gray')# Predict using the Neural Networkprediction = model.predict(X_test[random_index].reshape(1, 400))if prediction >= 0.5:yhat = 1else:yhat = 0# Display the label above the imageax.set_title(f"{y_test[random_index, 0]},{yhat}")ax.set_axis_off()
fig.suptitle("真实标签, 预测的标签", fontsize=16)
plt.show()# 给出预测的测试集误差
y_pred=model.predict(X_test)
print("测试数据集准确率为:", accuracy_score(y_test, np.round(y_pred)))

3、运行结果

按照最初的划分,数据集包含1000个数据,划分10%为测试集,也就是100个数据。结果可视化随机选择其中的64个数据绘图,每个图像的上方标明了其真实标签和预测的结果,这个是一个非常简单的示例,准确度还是非常高的。
在这里插入图片描述

在这里插入图片描述

4、工程下载与全部代码

工程链接:Tensorflow实现二元手写数字识别(二分类问题)

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import warnings
import logging
from sklearn.metrics import accuracy_scorelogging.getLogger("tensorflow").setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)
warnings.simplefilter(action='ignore', category=FutureWarning)
# 支持中文显示
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False# load dataset
def load_data():X = np.load("Handwritten_Digit_Recognition_data/X.npy")y = np.load("Handwritten_Digit_Recognition_data/y.npy")X = X[0:1000]y = y[0:1000]return X, y# 加载数据集,查看数据集大小,可以看到有1000个数据集,每个输入是20*20=400大小的图片
X, y = load_data()
print('The shape of X is: ' + str(X.shape))
print('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)# # 下面画图,随便从原数据取出几个画图,可以注释
# m, n = X.shape
# fig, axes = plt.subplots(8, 8, figsize=(8, 8))
# fig.tight_layout(pad=0.1)
# for i, ax in enumerate(axes.flat):
#     # Select random indices
#     random_index = np.random.randint(m)
#     # Select rows corresponding to the random indices and
#     # 将1*400的数据转换为20*20的图像格式
#     X_random_reshaped = X[random_index].reshape((20, 20)).T
#     # Display the image
#     ax.imshow(X_random_reshaped, cmap='gray')
#     # Display the label above the image
#     ax.set_title(y[random_index, 0])
#     ax.set_axis_off()
# plt.show()# 构建模型,三层模型进行分类,第一层输入25个神经元...
model = Sequential([tf.keras.Input(shape=(400,)),    #specify input size### START CODE HERE ###Dense(100, activation='sigmoid'),Dense(10, activation='sigmoid'),Dense(1, activation='sigmoid')### END CODE HERE ###], name = "my_model"
)
# 打印三层模型的参数
model.summary()
# 模型设定,学习率0.001,因为是分类,使用BinaryCrossentropy损失函数
model.compile(loss=tf.keras.losses.BinaryCrossentropy(),optimizer=tf.keras.optimizers.Adam(0.001),
)
# 开始训练,训练循环20
model.fit(X_train,y_train,epochs=20
)# 绘制测试集的预测结果,绘制64个
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
fig.tight_layout(pad=0.1, rect=[0, 0.03, 1, 0.92])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):# Select random indicesrandom_index = np.random.randint(X_test.shape[0])# Select rows corresponding to the random indices and# reshape the imageX_random_reshaped = X_test[random_index].reshape((20, 20)).T# Display the imageax.imshow(X_random_reshaped, cmap='gray')# Predict using the Neural Networkprediction = model.predict(X_test[random_index].reshape(1, 400))if prediction >= 0.5:yhat = 1else:yhat = 0# Display the label above the imageax.set_title(f"{y_test[random_index, 0]},{yhat}")ax.set_axis_off()
fig.suptitle("真实标签, 预测的标签", fontsize=16)
plt.show()# 给出预测的测试集误差
y_pred=model.predict(X_test)
print("测试数据集准确率为:", accuracy_score(y_test, np.round(y_pred)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/209924.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用Redis构建简易社交网站(1)-创建用户与动态界面

目的 本文目的:实现简易社交网站中创建新用户和创建新动态功能。(完整代码附在文章末尾) 相关知识 本文将教会你掌握:1.redis基本命令,2.python基本命令。 redis基本命令 hget:从哈希中获取指定域的值…

java后端技术演变杂谈(未完结)

1.0版本javaWeb:原始servletjspjsbc 早期的jsp:htmljava,页面先在后端被解析,里面的java代码动态渲染完成后,成为纯html,再通过服务器发送给浏览器显示。 缺点: 服务器压力很大,因为…

python提取通话记录中的时间信息

您需要安装适合中文的SpaCy模型。您可以通过运行 pip install spacypython -m spacy download zh_core_web_sm来安装和下载所需的模型。 import spacy# 加载中文模型 nlp spacy.load(zh_core_web_sm)# 示例电话记录文本 text """ Agent: 今天我们解决一下这…

QT之QString

QT之QString 添加容器 点击栅格布局 添加容器,进行栅格布局 布局总结:每一个模块放在一个Group中,排放完之后,进行栅格布局。多个Group进行并排时,先将各个模块进行栅格布局,然后都选中进行垂直布…

华清远见嵌入式学习——C++——作业3

作业要求&#xff1a; 代码&#xff1a; #include <iostream>using namespace std;class Per { private:string name;int age;double *high;double *weight; public://有参构造函数Per(string n,int a,double h,double w):name(n),age(a),high(new double(h)),weight(ne…

14 网关实战:网关聚合API文档

上节课介绍了网关层的认证鉴权,今天这节介绍一下网关层如何聚合API接口文文档。 为什么需要聚合API接口文档? 大型微服务系统模块众多,木谷博客系统就有9个,如果这些服务的接口地址没有一个统一,那么客户端将要保存每个服务的接口地址,这个肯定是不现实。 先来看一下A…

11. 哈希冲突

上一节提到&#xff0c;通常情况下哈希函数的输入空间远大于输出空间&#xff0c;因此理论上哈希冲突是不可避免的。比如&#xff0c;输入空间为全体整数&#xff0c;输出空间为数组容量大小&#xff0c;则必然有多个整数映射至同一桶索引。 哈希冲突会导致查询结果错误&#…

机器学习的复习笔记3-回归的细谈

一、回归的细分 机器学习中的回归问题是一种用于预测连续型输出变量的任务。回归问题的类型和特点如下&#xff1a; 线性回归&#xff08;Linear Regression&#xff09;&#xff1a;线性回归是回归问题中最简单的一种方法。它假设自变量与因变量之间存在线性关系&#xff0c…

【Unity动画】状态机添加参数控制动画切换(Animator Controller)

Unity - 手册&#xff1a;动画参数 在Unity中&#xff0c;动画状态的切换是通过Animator Controller中的过渡&#xff08;Transition&#xff09;来实现的。过渡是状态之间的连接&#xff0c;控制过渡一般都是靠调用代码参数 我们来实现一个案例&#xff1a; 创建动画状态机&a…

vscode中使用luaide-lite插件断点调试cocos2dx-lua

使用quick-cocos2dx-lua&#xff0c;用了众多插件&#xff0c;包括免费的BabeLua,VS调试太慢&#xff0c;vscode上的免费的EmmyLua, 还有收费的luaide&#xff0c;都没搞出来&#xff0c;唯独这个免费luaide-lite用成功了&#xff0c;步骤也简单&#xff0c;可以断点调试&#…

数据结构第六课 -----链式二叉树的实现

作者前言 &#x1f382; ✨✨✨✨✨✨&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f382; ​&#x1f382; 作者介绍&#xff1a; &#x1f382;&#x1f382; &#x1f382; &#x1f389;&#x1f389;&#x1f389…

Java SpringBoot Controller常见写法

文章目录 环境Controller调用脚本运行结果总结 环境 系统: windows 11 工具: java, idea, git bash Controller 接口常见有以下几种方式 其中&#xff1a; Tobj 调用脚本 我的是windows 系统&#xff0c;使用 git bash 窗口运行, 用 cmd 或者 power shell 会有问题 curl …

C盘分析文件大小的软件

https://sourceforge.net/projects/windirstat/ 上面是windirstat的下载链接 界面是这样的&#xff1a; 选择C盘或者D盘&#xff0c;点击OK&#xff0c;就可以分析了 然后就可以看到哪些占比最高&#xff0c;可以针对性的清理

C#网络编程UDP程序设计(UdpClient类)

目录 一、UdpClient类 二、示例 1.源码 &#xff08;1&#xff09;Client &#xff08;2&#xff09;Server 2.生成 &#xff08;1&#xff09;先启动服务器&#xff0c;发送广播信息 &#xff08;2&#xff09;再开启客户端接听 UDP是user datagram protocol的简称&a…

整数的立方和

系列文章目录 进阶的卡莎C++_睡觉觉觉得的博客-CSDN博客数1的个数_睡觉觉觉得的博客-CSDN博客双精度浮点数的输入输出_睡觉觉觉得的博客-CSDN博客足球联赛积分_睡觉觉觉得的博客-CSDN博客大减价(一级)_睡觉觉觉得的博客-CSDN博客小写字母的判断_睡觉觉觉得的博客-CSDN博客纸币(…

bad_python

攻防世界 (xctf.org.cn) 前戏 下载文件&#xff0c;解压完成后是这个 一个pyc文件 这里要用到python的反编译 要用到的工具有两个 1.python自带的uncompyle6 2.pycdc文件——比uncompyle6强大一点 我们一个一个来尝试一下 uncompyle6&#xff1a; 我是直接在pycharm里面…

uniapp在H5端实现PDF和视频的上传、预览、下载

上传 上传页面 <u-form-item :label"(form.ququ3 1 ? 参培 : form.ququ3 2 ? 授课 : ) 证明材料" prop"ququ6" required><u-button click"upload" slot"right" type"primary" icon"arrow-upward" t…

设计模式-结构型模式之代理设计模式

文章目录 八、代理设计模式 八、代理设计模式 代理设计模式通过代理控制对象的访问&#xff0c;可以详细访问某个对象的方法&#xff0c;在这个方法调用处理&#xff0c;或调用后处理。既(AOP微实现) 。 代理有分静态代理和动态代理&#xff1a; 静态代理&#xff1a;在程序…

集成开发环境PyCharm的使用【侯小啾python基础领航计划 系列(三)】

集成开发环境 PyCharm 的使用【侯小啾python基础领航计划 系列(三)】 大家好,我是博主侯小啾, 🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹…

QT 中 QProgressDialog 进度条窗口 备查

基础API //两个构造函数 QProgressDialog::QProgressDialog(QWidget *parent nullptr, Qt::WindowFlags f Qt::WindowFlags());QProgressDialog::QProgressDialog(const QString &labelText, const QString &cancelButtonText, int minimum, int maximum, QWidget *…