AI开发:用模型来识别手写数字的完整教程含源码 - Python 机器学习

今天一起来学习scikit-learn

scikit-learn 是一个强大的 Python 机器学习库,提供多种分类、回归、聚类算法,适用于从数据预处理到模型评估的全流程。它支持简单一致的 API,适合快速构建和测试模型。

官方地址在这里,记得Mark  很有用: https://scikit-learn.org/dev/index.html

scikit-learn 在手写数字识别方面具有以下特点:

  1. 提供内置的手写数字数据集(digits),包含 1797 个 8×8 像素的灰度数字图像。
  2. 支持多种分类算法(如 SVM、决策树、kNN 等),便于快速模型选择和评估。
  3. 内置工具可进行特征提取、数据预处理和模型训练,简化流程。
  4. 提供易用接口,适合初学者学习和研究机器学习算法在数字识别中的应用。

几天我们要使用这个库来识别一张图片中的手写数字,基本的业务逻辑如下图:

这里要讲一下,AI开发应用,不需要熟知底层的模型基础技术和知识,只需要掌握库和模型的应用。我们先来看一下第一步数据加载

No. 1 加载数据:

这里定义了一个函数 load_dataset ,作用是加载和返回一个内置的手写数字数据集,供后续的机器学习模型训练和测试使用。

# 模块1: 数据加载和预处理
def load_dataset():"""加载数字数据集"""digits = datasets.load_digits()return digits.data, digits.target

详细解释

  1. 数据集来源

    • 使用的是 scikit-learn 的内置数据集 datasets.load_digits()
    • 这是一个用于手写数字识别的经典数据集,包含 0 到 9 的手写数字样本。
  2. 数据集内容

    • 数据集包含 64维特征(每个数字图像为 8×8 像素,像素值被展平为一维数组)。
    • 标签是对应数字的值(例如,0, 1, 2...9)。
  3. 返回值

    • digits.data: 2D 数组,表示数据集中的所有特征(每行是一个样本,每列是一个特征)。
    • digits.target: 一维数组,表示每个样本对应的真实标签。
  4. 函数的作用

    • 加载数据集,直接返回特征数据和目标标签,便于分割数据集或传递给模型训练函数。

举例输出

调用 load_dataset 的代码:

X, y = load_dataset()print("特征数据 X 的形状:", X.shape)
print("标签数据 y 的形状:", y.shape)
print("第一个样本的特征:\n", X[0])
print("第一个样本的标签:", y[0])

输出示例

特征数据 X 的形状: (1797, 64)
标签数据 y 的形状: (1797,)
第一个样本的特征:[ 0.  0.  5. 13.  9.  1.  0.  0.  0.  0. 13. 15. 10. 15. 5.  0....
第一个样本的标签: 0

函数小结

该函数的作用是简化数据加载过程,使主程序能够直接获得数字数据的特征和标签,而无需每次重新处理数据集。

No. 2 模型训练:


# 模块3: 模型训练与保存
def train_model(model_path, x_train, y_train):"""训练支持向量机分类器并保存"""classifier = svm.SVC(gamma=0.001)start = time.perf_counter()classifier.fit(x_train, y_train)print(f"训练完成, 耗时: {time.perf_counter() - start:.4f} 秒")with open(model_path, 'wb') as f:pickle.dump(classifier, f)print(f"模型已保存到 {model_path}")

函数模块作用:模型训练与保存

该函数 train_model 的主要作用是训练一个支持向量机(SVM)分类器,并将训练好的模型保存到指定路径,以便后续直接加载使用,而无需重复训练。

分类器指的是一个用来做“分类”任务的数学模型。通俗来说,它就像一个“判断器”或者“识别器”,它根据输入的数据,给出一个分类结果。

假设你有一堆手写数字的图片,每张图片上的数字可能是 0 到 9 之间的任何一个。分类器就是通过学习这些数字的特征(比如笔画的粗细、弯曲程度等),来判断每张图片上是什么数字。它的工作流程就像是:

  1. 学习:给分类器一些带标签的数字图片(比如手写的“3”标记为数字3,“7”标记为数字7)。
  2. 识别:在训练完之后,给分类器一个新的数字图片,分类器会根据它之前学到的知识,判断这张图片上的数字是几。

在代码中,这个分类器是通过 svm.SVC 创建的,这个算法使用“支持向量机”(SVM)来分类数据。它会根据训练数据中的数字图片特征,训练出一个模型,然后用这个模型来对新的、未见过的图片进行分类预测。


详细功能分解

  1. 训练分类器:

    • 使用 sklearn.svm.SVC 创建支持向量机分类器,并设置 gamma=0.001
      在支持向量机(SVM)中,gamma 是一个超参数,用于控制高斯径向基函数(RBF)核函数的影响范围。它的值决定了模型在决策边界上的灵活性和复杂度。
      • 较小的 gamma 值(例如 gamma=0.001):
        使得每个数据点对决策边界的影响范围更广,意味着模型的决策边界更加平滑和简单,可能导致欠拟合(underfitting)。

      • 较大的 gamma 值
        会使每个数据点的影响范围变小,决策边界更加复杂,容易过拟合(overfitting)训练数据。

        gamma=0.001 的作用:

        gamma=0.001 时,模型倾向于生成较为平滑的决策边界,对数据点的变化不那么敏感。这可能有助于避免过拟合,但如果数据中存在复杂的决策边界,可能导致模型无法很好地拟合数据(欠拟合)。因此,gamma 的选择需要通过交叉验证等方法来调优,以获得最佳的模型性能。

    • 调用 fit(x_train, y_train) 方法,用训练数据 x_train 和标签 y_train 对分类器进行训练。
  2. 计算训练时间:

    • 通过 time.perf_counter() 记录训练开始和结束时间,计算并输出训练耗时,方便了解模型训练效率。
  3. 保存模型:

    • 使用 pickle.dump 将训练好的分类器对象序列化,保存到文件中(路径由 model_path 指定)。
    • 保存后的模型文件可直接加载进行预测,无需每次运行程序时重新训练。
  4. 输出训练与保存状态:

    • 打印训练完成的耗时和模型保存路径,便于用户确认训练和保存是否成功。

模块作用小结

  • 主要目标: 完成模型的训练与保存工作。
  • 适用场景: 在数字识别等任务中,训练模型通常是一次性操作,通过保存模型文件,可以将训练阶段与预测阶段分离,提高系统运行效率。
  • 模块化好处: 提高代码复用性和可读性,用户可以更方便地替换训练数据或模型参数。

No.3 加载模型:

def load_model(model_path):"""加载保存的模型"""if not os.path.exists(model_path):raise FileNotFoundError(f"模型文件 {model_path} 不存在!请先训练模型。")with open(model_path, 'rb') as f:classifier = pickle.load(f)print(f"模型已从 {model_path} 加载")return classifier

上面这个模块就不多解释了 ,就是看模型(分类器)是否存在,不存在 就训练一个保存。

No.4 加载图像并预处理:

 好,现在开始加载我们的图像了,这里我准备的是一张100*100 的png图像,里面手写了一个2,需要注意的是  scikit-learn 自带共 1797条数据(图片),每条数据由64个特征点组成(8*8像素)

def preprocess_image(image_path):"""读取并预处理图像"""source = cv2.imread(image_path)if source is None:raise FileNotFoundError(f"文件 {image_path} 未找到或无法读取!")gray = cv2.cvtColor(source, cv2.COLOR_BGR2GRAY)gray = cv2.GaussianBlur(gray, (5, 5), 0)_, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV)feature = cv2.resize(binary, (8, 8)).astype(float) / 16plt.imshow(feature, cmap='gray')plt.title("预处理后的图像")plt.show()return feature.flatten()

 函数作用:图像预处理

该函数 preprocess_image 的作用是从指定路径读取一张图像并进行一系列预处理操作,最终输出处理后的图像特征(以便用于机器学习模型的输入)。

这个函数会停在 显示一张图片,就是把我们前面的原始图进行了一些灰度处理,并描绘了一个轮廓,模型将参照这张处理后的图片标识去比对确认最终的数字

详细步骤:

  1. 读取图像:
    使用 cv2.imread(image_path) 从指定路径 image_path 读取图像文件。如果文件不存在或无法读取,会抛出 FileNotFoundError 异常。

  2. 转换为灰度图:
    cv2.cvtColor(source, cv2.COLOR_BGR2GRAY) 将读取的彩色图像转换为灰度图。这样可以简化图像处理,因为灰度图只有亮度信息,没有颜色信息。

  3. 高斯模糊:
    cv2.GaussianBlur(gray, (5, 5), 0) 对灰度图像应用高斯模糊。模糊操作有助于去除图像中的噪声,使后续的二值化更加平滑和稳定。

  4. 二值化:
    cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV) 将灰度图像进行二值化操作。具体来说,图像中像素值大于 127 的部分变为 0(黑色),小于或等于 127 的部分变为 255(白色)。使用 THRESH_BINARY_INV 使得背景为白色,前景为黑色。

  5. 调整图像大小:
    cv2.resize(binary, (8, 8)) 将二值化后的图像调整为 8x8 的大小。这一步是为了将图像转换为固定大小的特征,使得每个图像都能统一输入到机器学习模型中。

  6. 缩放特征值:
    .astype(float) / 16 将图像数据类型转换为 float 类型,并将其值缩放到一个较小的范围(0 到 15),以便适应模型的输入需求。

  7. 显示预处理结果:
    plt.imshow(feature, cmap='gray') 显示预处理后的图像,并使用 plt.title("预处理后的图像") 给图像加上标题。

  8. 返回特征:
    feature.flatten() 将 8x8 的图像矩阵展平为一维数组(64个元素),作为模型的输入特征返回。

函数小结:

preprocess_image 函数的作用是读取图像并通过灰度转换、模糊处理、二值化和尺寸调整等一系列步骤,将图像转化为适合机器学习模型处理的特征向量。最终,输出一个展平的图像特征向量,以便进一步的分类或其他处理。

需要注意的是,这里面的参数是可调节的,有时候需要根据实际情况多次调试参数,以使得模型的识别更加准确。

No.5 图像预测:

def predict(classifier, x_test):"""使用分类器预测测试样本"""start = time.perf_counter()predictions = classifier.predict(x_test)print(f"预测完成, 耗时: {time.perf_counter() - start:.4f} 秒")return predictions

函数作用:使用分类器进行预测

该函数 predict 的作用是使用已经训练好的分类器对测试数据 x_test 进行预测,并返回预测结果。

详细步骤:

  1. 记录开始时间:
    start = time.perf_counter() 记录开始预测的时间。time.perf_counter() 返回一个高精度的时间戳,用于计算函数执行的时间。

  2. 进行预测:
    predictions = classifier.predict(x_test) 使用输入的 classifier(即训练好的分类器)对 x_test 进行预测。x_test 是待分类的测试样本,classifier.predict() 方法会返回对每个测试样本的预测结果(例如,分类标签或类别)。

  3. 计算并打印耗时:
    print(f"预测完成, 耗时: {time.perf_counter() - start:.4f} 秒") 计算从开始到完成预测的时间差,并以秒为单位打印出来。time.perf_counter() - start 得到的时间差值就是执行预测操作所花费的时间。

  4. 返回预测结果:
    return predictions 返回预测的结果。预测结果通常是一个数组或列表,其中包含对每个测试样本的预测分类标签。

函数小结:

predict 函数的作用是:给定一个训练好的分类器和一组待分类的测试数据,利用分类器对数据进行预测,并返回预测结果。同时,它会打印预测操作所花费的时间。

No.5 结果显示:

完整的代码如下:

import os
import time
import pickle
import cv2
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import datasets
from sklearn import svm# 模块1: 数据加载和预处理
def load_dataset():"""加载数字数据集"""digits = datasets.load_digits()return digits.data, digits.target# 模块2: 图像处理
def preprocess_image(image_path):"""读取并预处理图像"""source = cv2.imread(image_path)if source is None:raise FileNotFoundError(f"文件 {image_path} 未找到或无法读取!")gray = cv2.cvtColor(source, cv2.COLOR_BGR2GRAY)gray = cv2.GaussianBlur(gray, (5, 5), 0)_, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV)feature = cv2.resize(binary, (8, 8)).astype(float) / 16plt.imshow(feature, cmap='gray')plt.title("预处理后的图像")plt.show()return feature.flatten()# 模块3: 模型训练与保存
def train_model(model_path, x_train, y_train):"""训练支持向量机分类器并保存"""classifier = svm.SVC(gamma=0.001)start = time.perf_counter()classifier.fit(x_train, y_train)print(f"训练完成, 耗时: {time.perf_counter() - start:.4f} 秒")with open(model_path, 'wb') as f:pickle.dump(classifier, f)print(f"模型已保存到 {model_path}")# 模块4: 加载模型并预测
def load_model(model_path):"""加载保存的模型"""if not os.path.exists(model_path):raise FileNotFoundError(f"模型文件 {model_path} 不存在!请先训练模型。")with open(model_path, 'rb') as f:classifier = pickle.load(f)print(f"模型已从 {model_path} 加载")return classifierdef predict(classifier, x_test):"""使用分类器预测测试样本"""start = time.perf_counter()predictions = classifier.predict(x_test)print(f"预测完成, 耗时: {time.perf_counter() - start:.4f} 秒")return predictions# 主函数: 流程控制
def main():# 数据加载与分割X, y = load_dataset()X_train, X_test, Y_train, Y_test = train_test_split(X, y, test_size=0.1, random_state=42)model_path = "hand_write_classer.cfr"# 检查模型是否已存在if os.path.exists(model_path):classifier = load_model(model_path)else:print("未找到模型文件,开始训练新模型...")train_model(model_path, X_train, Y_train)classifier = load_model(model_path)# 图像处理与预测test_image_path = "num.png"feature_vector = preprocess_image(test_image_path)prediction = predict(classifier, [feature_vector])print(f"识别结果: {prediction}")if __name__ == "__main__":main()

这里的图像自己准备吧,用画图工具,画布尺寸100*100 ,再手写数字。

好了,今天的学习就到此结束啦!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/484679.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Docker】创建Docker并部署Web站点

要在服务器上创建Docker容器,并在其中部署站点,你可以按照以下步骤操作。我们将以Flask应用为例来说明如何完成这一过程。 1. 准备工作 确保你的服务器已经安装了Docker。如果没有,请根据官方文档安装: Docker 安装指南 2. 创…

cgo内存泄漏排查

示例程序&#xff1a; package main/* #include <stdlib.h> #include <string.h> #include <stdio.h> char* cMalloc() {char *mem (char*)malloc(1024 * 1024 * 16);return mem; } void cMemset(char* mem) {memset(mem, -, 1024 * 1024 * 16); } int arr…

在做题中学习(76):颜色分类

解法&#xff1a;三指针 思路&#xff1a;用三个指针&#xff0c;把数组划分为三个区域&#xff1a; for循环遍历数组&#xff0c;i遍历数组&#xff0c;left是0区间的末尾&#xff0c;right是2区间的开头&#xff0c;0 1 2区间成功被划分 而上面的图画是最终实现的图样&…

性能测试基础知识jmeter使用

博客主页&#xff1a;花果山~程序猿-CSDN博客 文章分栏&#xff1a;测试_花果山~程序猿的博客-CSDN博客 关注我一起学习&#xff0c;一起进步&#xff0c;一起探索编程的无限可能吧&#xff01;让我们一起努力&#xff0c;一起成长&#xff01; 目录 性能指标 1. 并发数 (Con…

AWS创建ec2实例并连接成功

aws创建ec2实例并连接 aws创建ec2并连接 1.ec2创建前准备 首先创建一个VPC隔离云资源并且有公有子网 2.创建EC2实例 1.启动新实例或者创建实例 2.创建实例名 3.选择AMI使用linux(HVM) 4.选择实例类型 5.创建密钥对下载到本地并填入密钥对名称 6.选择自己创建的VPC和公有子网…

Flutter提示错误:无效的源发行版17

错误描述 Flutter从3.10.1 升级到3.19.4&#xff0c;在3.10.1的时候一切运行正常&#xff0c;但是当我将Flutter版本升级到3.19.4后&#xff0c;出现了下方的错误 FAILURE: Build failed with an exception.* What went wrong: Execution failed for task :device_info_plus:…

Android ConstraintLayout 约束布局的使用手册

目录 前言 一、ConstraintLayout基本介绍 二、ConstraintLayout使用步骤 1、引入库 2、基本使用&#xff0c;实现按钮居中。相对于父布局的约束。 3、A Button 居中展示&#xff0c;B Button展示在A Button正下方&#xff08;距离A 46dp&#xff09;。相对于兄弟控件的约束…

2025年申报建筑工程副高职称需要准备什么材料呢?

2025年湖北职称评审可以开始准备了&#xff0c;千万不要等到明临时报名开始才想起来准备哟&#xff0c;都是要提前的。 职称分为五个级别&#xff1a;技术员、初级职称、中级职称、副高职称、高级职称 我们最常见评审的就是中级职称和副高职称评审了&#xff0c;今天一起来看下…

“原批教育家”原批之星鲁健的杰作——原批俱乐部

伟大的原批教育家——原批之星&#xff0c;名为鲁健&#xff0c;是一位在南京邮电大学智能科学与技术专业中崭露头角的杰出人物。他不仅以其卓越的黑客技术和对网络正义的执着而闻名&#xff0c;更是“远古四神”之一&#xff0c;以其对原批之力的深刻理解和不同见解&#xff0…

底层逻辑之:欧拉-拉格朗日方程(Euler-Lagrange equations)变分法(Calculus of Variations)的核心思想

0前言&#xff1a; 0.1 17世纪的泛函&#xff08;Functional&#xff09;分析与变分法&#xff08;Calculus of Variations&#xff09; 在17世纪&#xff0c;数学家们开始遇到一些需要处理函数集合的问题&#xff0c;这些问题涉及到函数的极值、曲线的长度、曲面的面积等。这…

大数据实验E5HBase:安装配置,shell 命令和Java API使用

实验目的 熟悉HBase操作常用的shell 命令和Java API使用&#xff1b; 实验要求 掌握HBase的基本操作命令和函数接口的使用&#xff1b; 实验平台 操作系统&#xff1a;Linux&#xff08;建议Ubuntu16.04或者CentOS 7 以上&#xff09;&#xff1b;Hadoop版本&#xff1a;3…

微信小程序3-显标记信息和弹框

感谢阅读&#xff0c;初学小白&#xff0c;有错指正。 一、实现功能&#xff1a; 在地图上添加标记点后&#xff0c;标记点是可以携带以下基础信息的&#xff0c;如标题、id、经纬度等。但是对于开发来说&#xff0c;这些信息还不足够&#xff0c;而且还要做到点击标记点时&a…

一个有意思pytorch的简单应用小实验

通过一个简单的脚本&#xff0c;来学习pytorch的基本应用&#xff0c;比如&#xff1a;前向传播、反向传播、学习率以及预测、模型的基本原理和套路。 得到结果。。。保存模型。。。输入参数。。。预测。。。像不像&#xff1f;。。。像多少&#xff1f;。。。 设计目标&#x…

SpringBoot 分层解耦

从没有分层思想到传统 Web 分层&#xff0c;再到 Spring Boot 分层架构 1. 没有分层思想 在最初的项目开发中&#xff0c;很多开发者并没有明确的分层思想&#xff0c;所有逻辑都堆砌在一个类或一个方法中。这样的开发方式通常会导致以下问题&#xff1a; 代码混乱&#xff1…

2024 数学建模国一经验分享

2024 数学建模国一经验分享 背景&#xff1a;武汉某211&#xff0c;专业&#xff1a;计算机科学 心血来潮&#xff0c;就从学习和组队两个方面指点下后来者&#xff0c;帮新人避坑吧 2024年我在数学建模比赛中获得了国一&#xff08;教练说论文的分数是湖北省B组第一&#xff0…

Linux 35.6 + JetPack v5.1.4之RTP实时视频Python框架

Linux 35.6 JetPack v5.1.4之RTP实时视频Python框架 1. 源由2. 思路3. 方法论3.1 扩展思考 - 慎谋而后定3.2 扩展思考 - 拒绝拖延或犹豫3.3 扩展思考 - 哲学思考3.4 逻辑实操 - 方法论 4 准备5. 分析5.1 gst-launch-1.05.1.1 xvimagesink5.1.2 nv3dsink5.1.3 nv3dsink sync05…

渤海证券基于互联网环境的漏洞主动防护方案探索与实践

来源&#xff1a;中国金融电脑 作者&#xff1a;渤海证券股份有限公司信息技术总部 刘洋 伴随互联网业务的蓬勃发展&#xff0c;证券行业成为黑客进行网络攻击的重要目标之一&#xff0c;网络攻击的形式也变得愈发多样且复杂。网络攻击如同悬于行业之上的达摩克利斯之剑&…

隐私安全大考,Facebook 如何应对?

随着数字时代的到来和全球互联网用户的快速增长&#xff0c;隐私安全问题已上升为网络世界的重要议题。社交媒体巨头Facebook因其庞大的用户群体和大量的数据处理活动&#xff0c;成为隐私问题的聚焦点。面对隐私安全的大考&#xff0c;Facebook采取了一系列策略来应对这些挑战…

04 创建一个属于爬虫的主虚拟环境

文章目录 回顾conda常用指令创建一个爬虫虚拟主环境Win R 调出终端查看当前conda的虚拟环境创建 spider_base 的虚拟环境安装完成查看环境是否存在 为 pycharm 配置创建的爬虫主虚拟环境选一个盘符来存储之后学习所写的爬虫文件用 pycharm 打开创建的文件夹pycharm 配置解释器…

旅游管理系统的设计与实现

文末获取源码和万字论文&#xff0c;制作不易&#xff0c;感谢点赞支持。 毕 业 设 计&#xff08;论 文&#xff09; 题目&#xff1a;旅游管理系统的设计与实现 摘 要 如今社会上各行各业&#xff0c;都喜欢用自己行业的专属软件工作&#xff0c;互联网发展到这个时候&#…