基于Keras的手写数字识别(附源码)

目录

引言

为什么要创建虚拟环境,好处在哪里?

源码 

我修改的部分

调用本地数据

修改第二层卷积层


引言

本文是博主为了记录一个好的开源代码而写,下面是代码出处!强烈建议收藏!【深度学习实战—1】:基于Keras的手写数字识别(非常详细、代码开源)

写的非常好,但是复现这篇博客却让我吃了很多苦头, 大家要先下载Anaconda3然后创建一个虚拟环境,在虚拟环境里面主要下载以下三个东西版本号只要对应好,肯定能运行,其他的库少什么安装什么!如果用显卡跑模型,原博客有提及配置!

版本号
Python版本3.7.3
Keras版本2.4.3
tensorflow版本2.4.0

为什么要创建虚拟环境,好处在哪里?

在进行机器学习项目时,我们经常会遇到需要为不同的模型安装不同版本的Python或相关库的情况。这是因为每个模型可能依赖于特定版本的库,这些版本之间可能存在兼容性差异。如果不使用虚拟环境,而是在主环境中直接安装这些库,可能会遇到以下问题:

首先,当你为新的模型安装特定版本的库时,可能会覆盖掉主环境中已经存在的其他模型所需的库版本,导致之前的模型无法正常运行。

其次,不同的Python版本之间也可能存在兼容性问题。如果你直接在主环境中升级或降级Python版本,可能会影响到依赖于特定Python版本的其他项目。

为了避免这些问题,使用虚拟环境变得尤为重要。虚拟环境是一个隔离的Python环境,其中可以安装特定版本的Python和库,而不会影响到主环境或其他虚拟环境。这样,你可以为每个机器学习模型创建一个独立的虚拟环境,并在其中安装所需的Python版本和库版本,从而确保每个模型都能在其特定的环境中稳定运行。

通过这种方法,你可以轻松地管理多个项目,而无需担心库版本冲突或Python版本不兼容的问题。希望这样的解释能帮助大家更好地理解虚拟环境在机器学习项目中的重要性。

源码 

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from keras.datasets import mnist
from sklearn.metrics import confusion_matrix
import seaborn as sns
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from keras.utils import np_utils
import tensorflow as tfconfig = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)# 设定随机数种子,使得每个网络层的权重初始化一致
# np.random.seed(10)# x_train_original和y_train_original代表训练集的图像与标签, x_test_original与y_test_original代表测试集的图像与标签
(x_train_original, y_train_original), (x_test_original, y_test_original) = mnist.load_data()
# 假设你已经知道mnist.npz文件的路径
# file_path = 'mnist.npz'  # 替换为你的mnist.npz文件的实际路径
#
# # 加载npz文件
# with np.load(file_path, allow_pickle=True) as f:
#     x_train_original = f['x_train']
#     y_train_original = f['y_train']
#     x_test_original = f['x_test']
#     y_test_original = f['y_test']"""
数据可视化
"""# 单张图像可视化
def mnist_visualize_single(mode, idx):if mode == 0:plt.imshow(x_train_original[idx], cmap=plt.get_cmap('gray'))title = 'label=' + str(y_train_original[idx])plt.title(title)plt.xticks([])  # 不显示x轴plt.yticks([])  # 不显示y轴plt.show()else:plt.imshow(x_test_original[idx], cmap=plt.get_cmap('gray'))title = 'label=' + str(y_test_original[idx])plt.title(title)plt.xticks([])  # 不显示x轴plt.yticks([])  # 不显示y轴plt.show()# 多张图像可视化
def mnist_visualize_multiple(mode, start, end, length, width):if mode == 0:for i in range(start, end):plt.subplot(length, width, 1 + i)plt.imshow(x_train_original[i], cmap=plt.get_cmap('gray'))title = 'label=' + str(y_train_original[i])plt.title(title)plt.xticks([])plt.yticks([])plt.show()else:for i in range(start, end):plt.subplot(length, width, 1 + i)plt.imshow(x_test_original[i], cmap=plt.get_cmap('gray'))title = 'label=' + str(y_test_original[i])plt.title(title)plt.xticks([])plt.yticks([])plt.show()mnist_visualize_multiple(mode=0, start=0, end=4, length=2, width=2)
# 原始数据量可视化
print('训练集图像的尺寸:', x_train_original.shape)
print('训练集标签的尺寸:', y_train_original.shape)
print('测试集图像的尺寸:', x_test_original.shape)
print('测试集标签的尺寸:', y_test_original.shape)"""
数据预处理
"""
#
# 从训练集中分配验证集
x_val = x_train_original[50000:]
y_val = y_train_original[50000:]
x_train = x_train_original[:50000]
y_train = y_train_original[:50000]
print('======================')
# 打印验证集数据量
print('验证集图像的尺寸:', x_val.shape)
print('验证集标签的尺寸:', y_val.shape)
print('======================')
# 将图像转换为四维矩阵(nums,rows,cols,channels), 这里把数据从unint类型转化为float32类型, 提高训练精度。
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32')
x_val = x_val.reshape(x_val.shape[0], 28, 28, 1).astype('float32')
x_test = x_test_original.reshape(x_test_original.shape[0], 28, 28, 1).astype('float32')
#
# 原始图像的像素灰度值为0-255,为了提高模型的训练精度,通常将数值归一化映射到0-1。
x_train = x_train / 255
x_val = x_val / 255
x_test = x_test / 255
#
print('训练集传入网络的图像尺寸:', x_train.shape)
print('验证集传入网络的图像尺寸:', x_val.shape)
print('测试集传入网络的图像尺寸:', x_test.shape)
# #
# 图像标签一共有10个类别即0-9,这里将其转化为独热编码(One-hot)向量
y_train = np_utils.to_categorical(y_train)
print(y_train[0])y_val = np_utils.to_categorical(y_val)
y_test = np_utils.to_categorical(y_test_original)#
# """
# 定义网络模型
# """
#
#
def CNN_model():model = Sequential()model.add(Conv2D(filters=16, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)))  # 卷积层model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))  # 池化层# model.add(Conv2D(filters=32, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)))  # 卷积层model.add(Conv2D(filters=32, kernel_size=(5, 5), activation='relu'))  # 卷积层model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))  # 池化层model.add(Flatten())  # 平铺层model.add(Dense(100, activation='relu'))  # 全连接层model.add(Dense(10, activation='softmax'))  # 全连接层print(model.summary())return model#
#
# """
# 训练网络
# """
#
model = CNN_model()
# #
# 编译网络(定义损失函数、优化器、评估指标)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])# 开始网络训练(定义训练数据与验证数据、定义训练代数,定义训练批大小) 原来20
train_history = model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=20, batch_size=32, verbose=2)# 模型保存
model.save('handwritten_numeral_recognition.h5')#
#
# #
# #
# 定义训练过程可视化函数(训练集损失、验证集损失、训练集精度、验证集精度)
def show_train_history(train_history, train, validation):plt.plot(train_history.history[train])plt.plot(train_history.history[validation])plt.title('Train History')plt.ylabel(train)plt.xlabel('Epoch')plt.legend(['train', 'validation'], loc='best')plt.show()show_train_history(train_history, 'accuracy', 'val_accuracy')
show_train_history(train_history, 'loss', 'val_loss')# 输出网络在测试集上的损失与精度
score = model.evaluate(x_test, y_test)
print('Test loss:', score[0])
print('Test accuracy:', score[1])# 测试集结果预测
predictions = model.predict(x_test)
predictions = np.argmax(predictions, axis=1)
print('前9张图片预测结果:', predictions[:9])# 预测结果图像可视化
def mnist_visualize_multiple_predict(start, end, length, width):for i in range(start, end):plt.subplot(length, width, 1 + i)plt.imshow(x_test_original[i], cmap=plt.get_cmap('gray'))title_true = 'true=' + str(y_test_original[i])title_prediction = ',' + 'prediction' + str(model.predict_classes(np.expand_dims(x_test[i], axis=0)))title = title_true + title_predictionplt.title(title)plt.xticks([])plt.yticks([])plt.show()mnist_visualize_multiple_predict(start=0, end=9, length=3, width=3)# 混淆矩阵
cm = confusion_matrix(y_test_original, predictions)
cm = pd.DataFrame(cm)
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']def plot_confusion_matrix(cm):plt.figure(figsize=(10, 10))sns.heatmap(cm, cmap='Oranges', linecolor='black', linewidth=1, annot=True, fmt='', xticklabels=class_names,yticklabels=class_names)plt.xlabel("Predicted")plt.ylabel("Actual")plt.title("Confusion Matrix")plt.show()plot_confusion_matrix(cm)

我修改的部分

调用本地数据

# x_train_original和y_train_original代表训练集的图像与标签, x_test_original与y_test_original代表测试集的图像与标签
# (x_train_original, y_train_original), (x_test_original, y_test_original) = mnist.load_data()
# 假设你已经知道mnist.npz文件的路径
file_path = 'mnist.npz'  # 替换为你的mnist.npz文件的实际路径# 加载npz文件
with np.load(file_path, allow_pickle=True) as f:x_train_original = f['x_train']y_train_original = f['y_train']x_test_original = f['x_test']y_test_original = f['y_test']

因为原来的代码是每次运行都请求下载网上的在线数据,这是没必要的,当你运行了一次,可以把数据存在本地,然后以后本地调用

修改第二层卷积层

def CNN_model():model = Sequential()model.add(Conv2D(filters=16, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)))  # 卷积层model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))  # 池化层# model.add(Conv2D(filters=32, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)))  # 卷积层model.add(Conv2D(filters=32, kernel_size=(5, 5), activation='relu'))  # 卷积层model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))  # 池化层model.add(Flatten())  # 平铺层model.add(Dense(100, activation='relu'))  # 全连接层model.add(Dense(10, activation='softmax'))  # 全连接层print(model.summary())return model

 原文中的第二层卷积层的输入是规定为(28,28,1),但是这是有问题的,应该是不设置参数,这样子的话,会自动将第一个池化层的输出当作输入

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/331491.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Web MVC(2)

响应 Http响应的结果可以是数据也可以是静态页面可以针对响应设置状态码 Header信息 返回静态页面注解RestController和Controller 我们创建一个前端页面 package com.example.demo.demos.web.controller;import org.springframework.web.bind.annotation.RequestMapping; i…

Rolla‘s homework:Image Processing with Python Final Project

对比学习Yolo 和 faster rcnn 两种目标检测 要求 Image Processing with Python Final Project Derek TanLoad several useful packages that are used in this notebook:Image Processing with Python Final Project Project Goals: • Gain an understanding of the object …

海康威视硬盘录像机NVR连接公网视频监控平台,注册失败,抓包发现有403 forbidden的问题解决

目录 一、问题描述 二、问题定位 1、查看DVR的配置 2、查看需要使用的端口是否开放 3、查看日志 4、抓包 (1)找出错误 (2)查看数据包内容 三、问题分析 1、国标28181中的域的概念 2、域应该如何定义 (1&am…

蓝桥杯备赛——DP【python】

一、小明的背包1 试题链接:https://www.lanqiao.cn/problems/1174/learning/ 问题描述 输入实例 5 20 1 6 2 5 3 8 5 15 3 3 输出示例 37 问题分析 这里我们要创建一个DP表,DP(i,j)表示处理到第i个物品时消耗j体…

STM32学习和实践笔记(30):窗口看门狗(WWDG)实验

1.WWDG介绍 1.1 WWDG简介 上一章我们已经介绍了IWDG,知道它的工作原理就是一个12位递减计数器不断递减计数,当减到0之前还未进行喂狗的话,产生一个MCU复位。 窗口看门狗WWDG其实和独立看门狗类似,它是一个7位递减计数器不断的往…

C语言之指针进阶(3),函数指针

目录 前言: 一、函数指针变量的概念 二、函数指针变量的创建 三、函数指针变量的使用 四、两段特殊代码的理解 五、typedef 六、函数指针数组 总结: 前言: 本文主要讲述C语言指针中的函数指针,包括函数指针变量的概念、创建…

aws msk加密方式和问控制连接方式

msk加密方式 msk提供了两种加密方式 静态加密传输中加密 创建集群时可以指定加密方式,参数如下 aws kafka create-cluster --cluster-name "ExampleClusterName" --broker-node-group-info file://brokernodegroupinfo.json --encryption-info file:/…

【基于springboot+vue的房屋租赁系统】

介绍 本系统是基于springbootvue的房屋租赁系统,数据库为mysql,可用于日常学习和毕设,系统分为管理员、房东、用户,部分截图如下所示: 部分界面截图 用户 管理员 联系我 微信:Zzllh_

Wpf 使用 Prism 实战开发Day24

自定义询问窗口 当需要关闭系统或进行删除数据或进行其他操作的时候&#xff0c;需要询问用户是否要执行对应的操作。那么就需要一个弹窗来给用户进行提示。 一.添加自定义询问窗口视图 (MsgView.xaml) 1.首先&#xff0c;添加一个自定义询问窗口视图 (MsgView.xaml) <Use…

qmt量化教程4----订阅全推数据

文章链接 qmt量化教程4----订阅全推数据 (qq.com) 上次写了订阅单股数据的教程 量化教程3---miniqmt当作第三方库设置&#xff0c;提供源代码 全推就主动推送&#xff0c;当行情有变化就会触发回调函数&#xff0c;推送实时数据&#xff0c;可以理解为数据驱动类型&#xff0…

使用 Flask 和 Celery 构建异步任务处理应用

文章目录 什么是 Flask&#xff1f;什么是 Celery&#xff1f;如何在 Flask 中使用 Celery&#xff1f;步骤 1&#xff1a;安装 Flask 和 Celery步骤 2&#xff1a;创建 Flask 应用程序步骤 3&#xff1a;运行 Celery Worker步骤 4&#xff1a;启动 Flask 应用程序 结论 在构建…

C# NetworkStream 流的详解与示例

文章目录 一、NetworkStream类的基本概念1.1 NetworkStream类概述1.2 NetworkStream类属性1.3 NetworkStream类方法 二、NetworkStream的连接方式三、NetworkStream的传输模式四、NetworkStream类示例服务器端代码&#xff1a;客户端代码&#xff1a; 五、总结 在C#中&#xff…

刷代码随想录有感(77):回溯算法——含有重复元素的全排列

题干&#xff1a; 代码&#xff1a; class Solution { public:vector<int> tmp;vector<vector<int>> res;void backtracking(vector<int> nums, vector<int> used){if(tmp.size() nums.size()){res.push_back(tmp);return;}sort(nums.begin(),…

iCloud 照片到 Android 指南:帮助您快速将照片从 iCloud 传输到安卓手机

​ 概括 iOS 和 Android 之间的传输是一个复杂的老问题。将 iCloud 照片传输到 Android 似乎是不可能的。放心。现在的高科技已经解决了这个问题。尽管 Apple 和 Android 不提供传输工具&#xff0c;但您仍然有其他有用的选项。这篇文章与您分享了 5 个技巧。因此&#xff0c;…

云部署最简单python web

最近在玩云主机&#xff0c;考虑将简单的web应用装上去&#xff0c;通过广域网访问一下&#xff0c;代码很简单&#xff0c;所以新手几乎不会碰到什么问题。 from flask import Flaskapp Flask(__name__)app.route(/) def hello_world():return Hello, World!app.route(/gree…

plsql 学习

过程化编程语言 赋值&#xff1a;&#xff1a; ||&#xff1a;连接符号 dbms_output.put_line() :输出的语句 var_name ACCOUNTLIBRARY.USERNAME%type; 变量名&#xff1b;某个表的数据类型&#xff1b;赋值给变量名 用下面的方法更好用 异常exception 循…

Linux网络编程:HTTP协议

前言&#xff1a; 我们知道OSI模型上层分为应用层、会话层和表示层&#xff0c;我们接下来要讲的是主流的应用层协议HTTP&#xff0c;为什么需要这个协议呢&#xff0c;因为在应用层由于操作系统的不同、开发人员使用的语言类型不同&#xff0c;当我们在传输结构化数据时&…

算法打卡 Day9(字符串KMP 算法)-实现 strStr+ 重复的子字符串

KMP 算法 KMP 算法解决的是字符串匹配的问题&#xff0c;其经典思想是&#xff1a;当出现的字符串不匹配时&#xff0c;可以记录一部分之前已经匹配的文本内容&#xff0c;利用这些信息避免从头再去做匹配。 前缀表 next 数组就是一个前缀表。前缀表是用来回退的&#xff0c…

【启明智显技术分享】SOM2D02-2GW核心板适配ALSA(适用Sigmastar ssd201/202D)

提示&#xff1a;作为Espressif&#xff08;乐鑫科技&#xff09;大中华区合作伙伴及sigmastar&#xff08;厦门星宸&#xff09;VAD合作伙伴&#xff0c;我们不仅用心整理了你在开发过程中可能会遇到的问题以及快速上手的简明教程供开发小伙伴参考。同时也用心整理了乐鑫及星宸…

TypeScript学习日志-第三十二天(infer关键字)

infer关键字 一、作用与使用 infer 的作用就是推导泛型参数&#xff0c;infer 声明只能出现在 extends 子语句中&#xff0c;使用如下&#xff1a; 可以看出 已经推导出类型是 User 了 二、协变 infer 的 协变会返回联合类型&#xff0c;如图&#xff1a; 三、逆变 infer…