模型训练识别手写数字(一)

  一、模型训练数据集

1. 导入所需库

import numpy as np
from sklearn.datasets import fetch_openml

numpy 是用于数值计算的库。

fetch_openml 是用于从 OpenML 下载数据集的函数。

  2. 获取 MNIST 数据集

X, y = fetch_openml('mnist_784', version=1, return_X_y=True)

fetch_openml('mnist_784', version=1, return_X_y=True) 从 OpenML 下载 MNIST 数据集。X 存储图像数据(784 个特征,28x28 像素的扁平化图像),y 存储对应的标签(数字 0 到 9)。

   3. 将像素值二值化

X[X > 0] = 1

这行代码将 X 中所有大于 0 的像素值设置为 1,二值化处理。这样处理后的图像只有两个值:0(黑色)和 1(白色),有助于简化模型的输入。

   4. 保存数据集

np.save("Data/dataset", X)
np.save("Data/class", y)

np.save("Data/dataset", X) 将图像数据保存为 dataset.npy

np.save("Data/class", y) 将标签数据保存为 class.npy

二、模型训练及预测

1. 导入所需库

import matplotlib.pyplot as plt
import numpy as np
from keras import Sequential
from keras import layers
from keras.api.models import load_model
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

matplotlib.pyplot: 用于绘图和数据可视化。

numpy: 用于处理数组和数据加载。

keras: 用于构建和训练深度学习模型。

sklearn: 提供数据划分和预处理的工具。

f1_score: 用于评估模型性能。

 2. 加载数据

X = np.load("Data/dataset.npy", allow_pickle=True)
y = np.load("Data/class.npy", allow_pickle=True)

使用 numpyload 方法加载训练数据 (X) 和标签 (y)。allow_pickle=True 允许加载包含对象的数组。

 3. One-Hot 编码

onehot = OneHotEncoder(sparse_output=False)
y = onehot.fit_transform(y.reshape(-1, 1))

OneHotEncoder: 将标签转换为独热编码格式,方便用于分类任务。每个标签会被转换为一个二进制数组。 

 4. 划分训练集和测试集

x_train, x_test, y_train, y_test = train_test_split(X, y, random_state=14)

使用 train_test_split 将数据集分为训练集和测试集,通常使用 70%-80% 的数据用于训练,其余用于测试。 

 5. 构建模型

model = Sequential()
model.add(layers.Dense(100, activation='relu', input_shape=(x_train.shape[1],)))
model.add(layers.Dense(y.shape[1], activation='softmax'))  

Sequential: 表示模型是线性的,按顺序堆叠各个层。

Dense: 添加全连接层,第一层有 100 个神经元,使用 ReLU 激活函数;第二层为输出层,使用 Softmax 激活函数,适合多类分类任务。

 6. 编译模型

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

使用 Adam 优化器,损失函数为交叉熵(适合多分类),并监控准确率。 

 7. 训练模型

model.fit(x_train, y_train, epochs=100, batch_size=32, verbose=1)

 训练模型 100 个周期,批量大小为 32,verbose=1 表示输出训练过程的信息。

 8. 评估模型

predictions = model.predict(x_test)
predictions_classes = np.argmax(predictions, axis=1)
y_test_classes = np.argmax(y_test, axis=1)print("F-score: {0:.2f}".format(f1_score(y_test_classes, predictions_classes, average='micro')))

使用测试集进行预测,并计算 F-score 作为评估指标。np.argmax 用于获取每个样本预测概率最高的类。 

 9. 保存模型

model.save("my_model.h5")

将训练好的模型保存到文件 my_model.h5 中,以便后续加载和使用。 

 10. 加载模型

loaded_model = load_model("my_model.h5")

 加载之前保存的模型,以便进行预测。

 11. 进行预测

predictions = loaded_model.predict(x_test)

使用加载的模型对测试集进行预测,获取每个样本的预测结果。

 12. 获取预测和真实标签

y_pred_classes = np.argmax(predictions, axis=1)
y_test_classes = np.argmax(y_test, axis=1)

 使用 np.argmax 从预测结果和真实标签中获取每个样本的类别索引。

  13. 可视化预测结果

plt.figure(figsize=(12, 6))for i in range(20):plt.subplot(4, 5, i + 1)plt.imshow(x_test[i].reshape(28, 28), cmap='gray')  # 假设输入是28x28的图像plt.title(f'True: {y_test_classes[i]}\nPred: {y_pred_classes[i]}')plt.axis('off')plt.tight_layout()
plt.show()

创建一个图形窗口,设置大小为 12x6。

使用 subplot 在 4 行 5 列的网格中绘制 20 个图像。

每个子图中显示测试样本的图像、真实标签和预测标签。

imshow 将图像进行灰度显示,axis('off') 隐藏坐标轴。

tight_layout() 调整子图参数,以避免重叠。

show() 显示图形。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/458822.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Boot与Flyway实现自动化数据库版本控制

一、为什么使用Flyway 最简单的一个项目是一个软件连接到一个数据库,但是大多数项目中我们不仅要处理我们开发环境的副本,还需要处理其他很多副本。例如:开发环境、测试环境、生产环境。想到数据库管理,我们立刻就能想到一系列问…

Ovis原理解读: 多模态大语言模型的结构嵌入对齐

论文:https://arxiv.org/pdf/2405.20797 github:https://github.com/AIDC-AI/Ovis 在多模态大语言模型 (MLLM) 中,不同的嵌入策略有显著的区别。以下是使用基于连接器的方法与 Ovis 方法的比较: 基于连接器的方法-优缺点(connector-based …

斜杠往哪斜、路径绝对还是相对,终端目录切换不再迷茫

目录 路径表示绝对路径相对路径两者区别 路径中斜杠的用法正反斜杠对比表一个常见的问题 终端切换目录常用cd指令同一盘符内跨盘符 路径表示 在计算机文件系统中,路径是用来指定文件或目录位置的一种方式。路径可以是绝对路径或相对路径: 绝对路径 绝…

cmake 编译 vtk

1. 下载 VTK 源码 vtk 源码,点击:官网下载 在官网选择合适的版本下载,这里下载的是 vtk 8.2.0 版本 2. 下载 CMake CMake 工具,点击:镜像下载 下载版本比较新的 CMake 版本,这里使用的是 CMake 3.29.2 版…

在不支持AVX的linux上使用PaddleOCR

背景 公司的虚拟机CPU居然不支持avx, 默认的paddlepaddle的cpu版本又需要有支持avx才行,还想用PaddleOCR有啥办法呢? 是否支持avx lscpu | grep avx 支持avx的话,会显示相关信息 如果不支持的话,python运行时导入paddle会报错 怎么办呢 方案一 找公司it,看看虚拟机为什么…

C++基础:constexpr,类型转换和选择语句

constexpr 提到constexpr&#xff0c;我们会发现它和const类比 常和const类比constexpr符号常量必须给定一个在编译时已知的值&#xff0c; 若某个变量初始化时的值在编译时未知&#xff0c;但初始化后绝不变。 #include<iostream> #include<vector> #include&l…

【机器学习基础】激活函数

激活函数 1. Sigmoid函数2. Tanh&#xff08;双曲正切&#xff09;函数3. ReLU函数4. Leaky ReLU函数 1. Sigmoid函数 观察导数图像在我们深度学习里面&#xff0c;导数是为了求参数W和B&#xff0c;W和B是在我们模型model确定之后&#xff0c;找出一组最优的W和B&#xff0c;使…

Go使用exec.Command() 执行脚本时出现:file or directory not found

使用 Go 提供的 exec.Command() 执行脚本时出现了未找到脚本的 bug&#xff0c;三个排查思路 &#xff1a; exec.Command(execName, args…) 脚本名字不允许相对路径 exec.Command(execName, args…) execName 只能有脚本名&#xff0c;不允许出现参数 如果你是使用 Windows …

Python爬虫:商品详情的“八卦记者”

亲爱的代码侦探们&#xff0c;今天咱们不聊那些让人头秃的bug&#xff0c;也不谈那些让人眼花的架构图。咱们来聊聊那些在代码世界里挖掘商品秘密的“八卦记者”——Python爬虫。 Python爬虫&#xff1a;商品详情的“八卦记者” 想象一下&#xff0c;你在代码的世界里&#xf…

[笔记] ffmpeg docker编译环境搭建

文章目录 环境参考dockerfile 文件步骤常见问题docker 构建镜像出现 INTERNAL_ERROR 失败? 总结 环境 docker 环境 系统centos 7.9 (无所谓了 你用docker编译就无所谓系统了) ffmpeg3.3 参考 https://blog.csdn.net/jiedichina/article/details/71438112 dockerfile 文件 …

《等保测评新视角:安全与发展的双赢之道》

在数字化转型的浪潮中&#xff0c;企业面临的不仅是技术革新的挑战&#xff0c;更有信息安全的严峻考验。等保测评&#xff0c;作为国家网络安全等级保护的一项重要措施&#xff0c;不仅为企业的安全护航&#xff0c;更成为推动企业高质量发展的新引擎。本文将从全新的视角&…

如何将markdown文件转换为pdf

最近笔者在用vscode写markdown&#xff0c;但是提交时往往需要交pdf。所以就涉及到如何将markdown转化为pdf格式。 首先&#xff0c;需要在vscode上安装插件 markdown Preview Enhanced 之后在vscode的右上角即可看到下述图标&#xff0c;点击&#xff0c;vscode右半面就会显示…

【论文阅读】PGAN

1. WHY 问题 图像超分辨率一直是一个热门研究课题&#xff0c;具有重要的应用价值。基于生成对抗网络GAN的单幅图像超分辨率方法显示重建图像与人类视觉特征更一致。因此&#xff0c;基于 GAN 的网络优化已成为图像超分辨率的主流。然而&#xff0c;一些最新研究表明&#xf…

【JIT/极态云】技术文档--函数设计

一、简介 函数是计算机编程中非常重要的概念。它是一段代码&#xff0c;可以在程序中多次调用&#xff0c;用于完成特定的任务。 函数通常接受输入参数&#xff0c;执行特定的操作&#xff0c;并返回一个结果。这个结果可以被程序中的其他代码使用。 二、新建函数 在函数列表…

Springboot整合spring-boot-starter-data-elasticsearch

前言 <font style"color:rgb(36, 41, 47);">spring-boot-starter-data-elasticsearch</font> 是 Spring Boot 提供的一个起始依赖&#xff0c;旨在简化与 Elasticsearch 交互的开发过程。它集成了 Spring Data Elasticsearch&#xff0c;提供了一套完整…

mysql-Innodb锁相关内容

1、InnoDB存储引擎包含的锁类型 共享锁&#xff08;S锁&#xff09;和排他锁&#xff08;X锁&#xff09;意向锁记录锁间隙锁Next-key锁插入意向锁Auto-INC 锁空间索引的谓词锁 2、共享锁&#xff08;S锁&#xff09;和排他锁&#xff08;X锁&#xff09;-- 锁定数据行 共享…

使用Git进行团队协作开发

使用Git进行团队协作开发 Git简介 安装Git 在Windows上安装Git 在macOS上安装Git 在Linux上安装Git 设置Git用户信息 创建Git仓库 基本Git命令 添加文件 提交更改 查看状态 克隆仓库 推送更改 获取更改 分支管理 创建分支 切换分支 合并分支 删除分支 解决合并冲突 检查冲突…

docker安装、设置非sudo执行、卸载

安装 sudo snap install docker 设置docker非sudo执行 sudo groupadd docker sudo usermod -aG docker $USER newgrp docker sudo chown root:docker /var/run/docker.sock 卸载docker 1.删除docker及安装时自动安装的所有包 apt-get autoremove docker docker-ce docker-…

数据结构_二叉树

二叉树的性质 满二叉树 完全二叉树 完全二叉树的特点 二叉树的存储结构 顺序存储 链式存储 二叉链表 三叉链表 二叉树遍历算法 先序遍历 先序遍历&#xff1a;ABDC 中序遍历 后序遍历 层次遍历

Win11安装基于WSL2的Ubuntu

1. 概述 趁着还没有完全忘记&#xff0c;详细记录一下在Win11下安装基于WSL2的Ubuntu的详细过程。不得不说WSL2现在被微软开发的比较强大了&#xff0c;还是很值得安装和使用的&#xff0c;笔者就通过WSL2安装的Ubuntu成功搭建了ROS环境。 2. 详论 2.1 子系统安装 在Win11搜…