基于TensorFlow的手写体数字识别训练与测试

需求:

  • 选择一个最简单的细分方向,初步了解AI图像识别的训练、测试过程
  • TensorFlow、PyTorch、c,三种代码方案,先从TensorFlow入手
  • 探讨最基本问题的优化问题

总结:

  • 基于TensorFlow的python代码库自带了mnist 训练数据集、测试数据集。避免了自己去收集图像、标注的问题。
  • 利用chatgpt逐步完善代码,输出图像(字符方式、bmp方式)辅助分析
  • x为0-9的图像、y为对应数字标签0-9,train训练集60000个,test测试集10000个
  • 实际测试结果能达到98%成功识别率,但是剩下的2%错得也很离谱,有优化的空间。
  • 每次训练、测试的结果,存在差别,并不是完全一样的结果,TensorFlow算法中可能存在随机数
  • 测试失败的数字2中,部分与训练集比较类似,直观看起来不应该失败

代码和注释

# 环境: 20241030 win10 vs2022 python3.9.13
# 安装tensorflow: pip install tensorflow
# vs2022时,在 C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python39_64\Scripts 下运行import os
import numpy as np
import PIL.Image as Image# 显示图像
#import matplotlib.pyplot as plt# oneDNN: Intel 推出的一款深度学习性能优化库,可以加速深度学习计算。
# 1启用/0禁用 oneDNN 优化
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'# 第一次import tensorflow耗时较长
import tensorflow as tf
from tensorflow.keras import layers, models# 检查GPU是否可用
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))def display_mnist_image_console(image):# 设置字符映射,空格代表最暗,#代表最亮chars = " .:-=+*#%@"# 归一化图像到0-9的整数范围normalized_image = (image / 255 * (len(chars) - 1)).astype(int)# 使用字符映射显示图像for row in normalized_image:print("".join(chars[pixel] for pixel in row))def save_mnist_image_as_bmp(image, filename="1.bmp"):"""将MNIST图像保存为BMP格式Args:image: MNIST图像数据,形状为(28, 28)filename: 保存的文件名"""# 确保图像数据在0-255范围内image = np.clip(image, 0, 255).astype(np.uint8)# 将图像数据转换为PIL Image对象img = Image.fromarray(image, 'L')  # 'L'表示灰度图像# 保存图像img.save(filename)# 定义MNIST数据集的下载地址
mnist_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz"# 检查本地是否存在MNIST数据集文件
data_dir = os.path.dirname(os.path.abspath(__file__))
data_file = os.path.join(data_dir, "mnist.npz")if not os.path.exists(data_file):print(f"本地未找到MNIST数据集,正在从 {mnist_url} 下载...")# 使用tensorflow自带的下载函数下载数据集tf.keras.utils.get_file(filename="mnist.npz", origin=mnist_url, extract=True)
else:print(f"本地已存在MNIST数据集,将使用本地文件 {data_file}")# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_test_original = x_test.copy()  # 创建 x_test 的备份
# 下载mnist.npz文件,解压后是x_train.npy x_test.npy等4个文件
# .npy 文件是 NumPy(Numerical Python)的一种自描述二进制文件格式。# 输出数据集基本信息
# x图像、y标签,train训练集60000个,test测试集10000
print(f"训练集图像形状:{x_train.shape},数据类型{x_train.dtype};标签形状:{y_train.shape},数据类型{y_train.dtype}")
print(f"测试集图像形状:{x_test.shape},数据类型{x_test.dtype};标签形状:{y_test.shape},数据类型{y_test.dtype}")# 输出更多详细信息
print(f"\n标签{y_train[0]}对应的图像示例:")
#print(x_train[0]) # 这个图像示例是数字528*28的灰度图
display_mnist_image_console(x_train[0])
save_mnist_image_as_bmp(x_train[0])# 输出图像的最小值和最大值
print(f"\n训练集图像像素值的最小值:{np.min(x_train)};最大值:{np.max(x_train)}") # 0 - 255# x图像 归一化
x_train, x_test = x_train / 255.0, x_test / 255.0  # 定义一个简单的CNN模型
model = models.Sequential([ # Sequential: 创建一个顺序模型,即神经网络的层按顺序堆叠。layers.Flatten(input_shape=(28, 28)), # Flatten: 将输入的 28x28 的二维图像展平为一维向量,以便输入到全连接层。layers.Dense(128, activation='relu'), # Dense: 全连接层,神经元之间全连接。128: 输出神经元的数量,即隐藏层的神经元数量。activation='relu': 使用 ReLU 作为激活函数,引入非线性。layers.Dropout(0.2), # Dropout 层,随机丢弃部分神经元,防止过拟合。每次训练时,随机丢弃 20% 的神经元。layers.Dense(10, activation='softmax') # Dense(10, activation='softmax'): 输出层,有 10 个神经元,对应 10 个数字分类。使用 softmax 激活函数,将输出转换为概率分布。
])# 编译模型
model.compile(optimizer='adam', # 使用 Adam 优化器,一种常用的优化算法。loss='sparse_categorical_crossentropy', # 使用稀疏分类交叉熵作为损失函数,适用于多分类问题且标签是整数的情况。metrics=['accuracy']) # 评估指标为准确率。# 训练模型
model.fit(x_train, y_train, epochs=5) # 训练 5 个 epoch,每个 epoch 遍历一遍整个训练集。训练5次。# 评估模型的性能,并输出损失和准确率。
# 损失(loss): 模型在测试集上的平均损失值,反映了模型预测值与真实值之间的差异。损失越小,说明模型预测越准确。
# 准确率(accuracy): 模型在测试集上预测正确的样本比例,直接反映了模型的分类性能。
#model.evaluate(x_test, y_test) # 评估模型的性能,并输出损失和准确率。
loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
print(f"\n模型评估 - 损失: {loss:.4f}, 准确率: {accuracy:.4f}")# 预测测试集标签
predictions = model.predict(x_test)
predicted_labels = np.argmax(predictions, axis=1)# 初始化错误样本计数
wrong_count = 0
total_count = len(x_test)# 遍历测试集,输出识别错误的样本
print("\n识别错误的样本:")
for i in range(len(x_test)):if predicted_labels[i] != y_test[i]:  # 判断是否识别错误wrong_count += 1print(f"\n样本索引: {i} 模型预测结果: {predicted_labels[i]}, 正确结果: {y_test[i]}")display_mnist_image_console(x_test_original[i])  # 显示图像# 输出错误样本总数和总样本数
print(f"\n总共 {total_count} 个样本,识别错误 {wrong_count} 个")
# 总共 10000 个样本,识别错误 222 个。  部分识别错误的明显不应该错。

示例图像:
在这里插入图片描述

识别错误的图像举例:
在这里插入图片描述
在这里插入图片描述
输出指定img列表到bmp文件

def save_images_to_bmp(images, labels, filename, max_per_row=50):"""将图像保存到 BMP 文件中,每行最多 max_per_row 张图像。:param images: 图像数组,形状为 (n, 28, 28):param labels: 标签数组,形状为 (n,):param filename: 保存的 BMP 文件名:param max_per_row: 每行最大图像数量"""img_count = len(images)rows = (img_count + max_per_row - 1) // max_per_rowimg_width, img_height = 28, 28# 创建画布canvas_width = max_per_row * img_widthcanvas_height = rows * img_heightcanvas = Image.new("L", (canvas_width, canvas_height), color=255)  # 灰度图# 绘制每张图片for idx, img in enumerate(images):x_offset = (idx % max_per_row) * img_widthy_offset = (idx // max_per_row) * img_height# img_pil = Image.fromarray((img * 255).astype(np.uint8))  # 恢复像素值范围 0-255img_pil = Image.fromarray(img, 'L')canvas.paste(img_pil, (x_offset, y_offset))# 保存到文件canvas.save(filename)print(f"保存 {filename} 成功!")# 遍历训练集,分类存储
print("\n遍历训练集,分类存储:")
train_images = {i: [] for i in range(10)}
for i in range(len(x_train)):label = y_train[i] # 标签train_images[label].append(x_train[i])# 保存训练集图像
for digit in range(10):# 保存正确分类的样本if train_images[digit]:save_images_to_bmp(train_images[digit],[digit] * len(train_images[digit]),f"train_{digit}.bmp")# 预测测试集标签
predictions = model.predict(x_test)
predicted_labels = np.argmax(predictions, axis=1)# 初始化错误样本计数
wrong_count = 0
total_count = len(x_test)
# 初始化存储字典
correct_images = {i: [] for i in range(10)}
wrong_images = {i: [] for i in range(10)}# # 遍历测试集,输出识别错误的样本
# print("\n识别错误的样本:")
# for i in range(len(x_test)):
#     if predicted_labels[i] != y_test[i]:  # 判断是否识别错误
#         wrong_count += 1
#         print(f"\n样本索引: {i} 模型预测结果: {predicted_labels[i]}, 正确结果: {y_test[i]}")
#         display_mnist_image_console(x_test_original[i])  # 显示图像# 遍历测试集,分类存储识别结果
print("\n识别错误的样本统计汇总:")
for i in range(len(x_test)):label = y_test[i]                # 真实标签predicted = predicted_labels[i]  # 模型预测结果if predicted == label:correct_images[label].append(x_test_original[i])else:wrong_images[label].append(x_test_original[i])wrong_count += 1print(f"样本索引: {i} 模型预测结果: {predicted_labels[i]}, 正确结果: {y_test[i]}")# display_mnist_image_console(x_test_original[i])  # 显示图像# 保存图像
for digit in range(10):# 保存正确分类的样本if correct_images[digit]:save_images_to_bmp(correct_images[digit],[digit] * len(correct_images[digit]),f"test_{digit}.bmp")# 保存错误分类的样本if wrong_images[digit]:save_images_to_bmp(wrong_images[digit],[digit] * len(wrong_images[digit]),f"test_error_{digit}.bmp",max_per_row=10)

以数字2为例,以下分别为训练集图像、测试集通过的图像、测试集失败的图像:
训练集
测试集通过的
测试集失败的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/481871.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

YOLO系列论文综述(从YOLOv1到YOLOv11)【第11篇:YOLO变体——YOLO+Transformers、DAMO、PP、NAS】

YOLO变体 1 DAMO-YOLO2 PP-YOLO, PP-YOLOv2, and PP-YOLOE2.1 PP-YOLO数据增强和预处理2.2 PP-YOLOv22.3 PP-YOLOE 3 YOLO-NAS4 YOLO Transformers5 YOLOv1-v8及变体网络结构总结 YOLO系列博文: 【第1篇:概述物体检测算法发展史、YOLO应用领域、评价指标…

SE16N 外键校验报错问题

问题: SE16N维护时,偶尔有一些莫名奇妙的校验报错,条目XX在表XX中不存在,但是实际数据时存在的。 分析: DEBUG过程中,定位到数据校验部分,发现当外键定义的关联字段中存在某些不在对应维护表中…

【数据结构】二叉搜索树(二叉排序树)

🌟🌟作者主页:ephemerals__ 🌟🌟所属专栏:数据结构 目录 前言 一、什么是二叉搜索树 二、二叉搜索树的实现 节点 属性和接口的声明 插入 查找 删除 拷贝构造 析构 中序遍历 三、二叉搜索树的…

【接口自动化测试】一文从3000字从0到1详解接口测试用例设计

接口自动化测试是软件测试中的一种重要手段,它能有效提高测试效率和测试覆盖率。在进行接口自动化测试之前,首先需要进行接口测试用例的设计。本文将从0到1详细且规范的介绍接口测试用例设计的过程,帮助读者快速掌握这一技能。 一、了解接口…

使用 PDF API 合并 PDF 文件

内容来源: 如何在 Mac 上合并 PDF 文件 1. 注册与认证 您可以注册一个免费的 ComPDFKit API 帐户,该帐户允许您在 30 天内免费无限制地处理 1,000 多个文档。 ComPDFKit API 使用 JSON Web Tokens 方法进行安全身份验证。从控制面板获取您的公钥和密钥&…

微服务即时通讯系统的实现(服务端)----(2)

目录 1. 语音识别子服务的实现1.1 功能设计1.2 模块划分1.3 模块功能示意图1.4 接口的实现 2. 文件存储子服务的实现2.1 功能设计2.2 模块划分2.3 模块功能示意图2.4 接口的实现 3. 用户管理子服务的实现3.1 功能设计3.2 模块划分3.3 功能模块示意图3.4 数据管理3.4.1 关系数据…

Scala—列表(可变ListBuffer、不可变List)用法详解

Scala集合概述-链接 大家可以点击上方链接,先对Scala的集合有一个整体的概念🤣🤣🤣 在 Scala 中,列表(List)分为不可变列表(List)和可变列表(ListBuffer&…

Android 系统之Init进程分析

1、Init进程流程 2、Init细节逻辑 2.1 Init触发shutdown init进程触发系统重启是一个很合理的逻辑,为什么合理? init进程是android世界的一切基石,如果android世界的某些服务或者进程出现异常,那么会导致整个系统无法正常使用…

NVR录像机汇聚管理EasyNVR多个NVR同时管理基于B/S架构的技术特点与能力应用

EasyNVR视频融合平台基于云边端协同设计,能够轻松接入并管理海量的视频数据。该平台兼容性强、拓展灵活,提供了视频监控直播、录像存储、云存储服务、回放检索以及平台级联等一系列功能。B/S架构使得EasyNVR实现了视频监控的多元化兼容与高效管理。 其采…

使用ffmpeg命令实现视频文件间隔提取帧图片

将视频按每隔五秒从视频中提取一张图片 使用 ffmpeg 工具,通过设置 -vf(视频过滤器)和 -vsync 选项 命令格式 ffmpeg -i input_video.mp4 -vf "fps1/5" output_%03d.png 解释: -i input_video.mp4:指定输…

Java安全—原生反序列化重写方法链条分析触发类

前言 在Java安全中反序列化是一个非常重要点,有原生态的反序列化,还有一些特定漏洞情况下的。今天主要讲一下原生态的反序列化,这部分内容对于没Java基础的来说可能有点难,包括我。 序列化与反序列化 序列化:将内存…

现代网络架构PCI DSS合规范围确定和网络分割措施实施探讨

本文为atsec和作者技术共享类文章,旨在共同探讨信息安全业界的相关话题。未经许可,任何单位及个人不得以任何方式或理由对本文的任何内容进行修改。转载请注明:atsec信息安全和作者名称 1 引言 支付卡行业数据安全标准 (P…

docker快速部署gitlab

文章目录 场景部署步骤默认账号密码效果 场景 新增了一台机器, 在初始化本地开发环境,docker快速部署gitlab 部署步骤 编写dockerfile version: 3.7services:gitlab:image: gitlab/gitlab-ce:latestcontainer_name: gitlabrestart: alwayshostname: gitlabenviron…

Kubernetes 01

MESOS:APACHE 分布式资源管理框架 2019-5 Twitter退出,转向使用Kubernetes Docker Swarm 与Docker绑定,只对Docker的资源管理框架,阿里云默认Kubernetes Kubernetes:Google 10年的容器化基础框架,borg…

芯科科技率先支持Matter 1.4,推动智能家居迈向新高度

Matter 1.4引入核心增强功能、支持新设备类型,持续推进智能家居互联互通 近日,连接标准联盟(Connectivity Standard Alliance,CSA)发布了Matter 1.4标准版本。作为连接标准联盟的重要成员之一,以及Matter标…

Redis 分布式锁实现方案

一、概述 分布式锁,即分布式系统中的锁。在单体应用中我们通过锁解决的是控制共享资源访问的问题,而分布式锁,就是解决了分布式系统中控制共享资源访问的问题。与单体应用不同的是,分布式系统中竞争共享资源的最小粒度从线程升级…

数据结构-最小生成树

一.最小生成树的定义 从V个顶点的图里生成的一颗树,这颗树有V个顶点是连通的,有V-1条边,并且边的权值和是最小的,而且不能有回路 二.Prim算法 Prim算法又叫加点法,算法比较适合稠密图 每次把边权最小的顶点加入到树中&#xff0…

ASP.NET Web(.Net Framework) Http服务器搭建以及IIS站点发布

ASP.NET Web(.Net Framework) Http服务器搭建以及IIS站点发布 介绍创建ASP.NET Web (.Net Framework)http服务器创建项目创建脚本部署Http站点服务器测试 Get测试编写刚才的TestWebController.cs代码如下测试写法1测试写法2 Post测…

【AI系统】昇腾 AI 架构介绍

昇腾 AI 架构介绍 昇腾计算的基础软硬件是产业的核⼼,也是 AI 计算能⼒的来源。华为,作为昇腾计算产业⽣态的⼀员,是基础软硬件系统的核⼼贡献者。昇腾计算软硬件包括硬件系统、基础软件和应⽤使能等。 而本书介绍的 AI 系统整体架构&#…

org.apache.commons.lang3包下的StringUtils工具类的使用

前言 相信平时在写项目的时候,一定使用到StringUtils.isEmpty();StringUtils.isBlank();但是你真的了解他们吗? 也许你两个都不知道,也许你除了isEmpty/isNotEmpty/isNotBlank/isBlank外,并不知道还有isAnyEmpty/isNon…