使用 TensorFlow 实现 ZFNet 进行 MNIST 图像分类

        ZFNet(ZF-Net)是由 Matthew Zeiler 和 Rob Fergus 提出的卷积神经网络架构,它在图像分类任务中取得了显著的效果。它在标准卷积神经网络(CNN)的基础上做了一些创新,例如优化了卷积核大小和池化策略,使得网络在处理图像时表现得更加高效。

        本文将详细介绍如何使用 TensorFlow 2.x 实现 ZFNet,在 MNIST 数据集上进行图像分类,并将训练部分和测试部分分开进行讲解。

1. 环境准备

        首先,我们需要确保已安装 TensorFlow 和其他相关库。在命令行中执行以下命令进行安装:

pip install tensorflow matplotlib numpy

2. 训练部分:构建和训练 ZFNet 模型

        在训练部分,我们将加载 MNIST 数据集,构建 ZFNet 模型,并在 GPU 或 CPU 上进行训练。

2.1 加载并预处理 MNIST 数据集

        MNIST 数据集包含了 70,000 张手写数字图像,训练集包含 60,000 张,测试集包含 10,000 张。在加载数据后,我们需要对数据进行预处理:标准化和调整大小。

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from zfnet import create_zfnet_model  # 从 zfnet.py 导入模型创建函数def prepare_data():"""准备 MNIST 数据集并进行预处理:return: 训练集和测试集的图像及标签"""# 加载 MNIST 数据集(x_train, y_train), (x_test, y_test) = mnist.load_data()# 数据预处理:标准化、调整大小、添加维度x_train = x_train.astype('float32') / 255.0x_test = x_test.astype('float32') / 255.0# 调整图像大小并添加额外维度 (32x32, 1通道)x_train = tf.image.resize(x_train[..., tf.newaxis], (32, 32))x_test = tf.image.resize(x_test[..., tf.newaxis], (32, 32))# 确保数据类型是 float32x_train = tf.cast(x_train, tf.float32)x_test = tf.cast(x_test, tf.float32)# 类别标签 one-hot 编码y_train = to_categorical(y_train, 10)y_test = to_categorical(y_test, 10)return x_train, y_train, x_test, y_test

解释:
  1. 标准化:图像像素值从 [0, 255] 转换为 [0, 1],有助于加速网络训练并提高稳定性。
  2. 调整图像大小:由于 ZFNet 网络需要 32x32 的输入图像,所以我们将图像大小调整为 32x32。
  3. One-Hot 编码:标签数据转换为 One-Hot 编码格式,以便与神经网络输出匹配。

2.2 创建 ZFNet 模型

        ZFNet 是一个深度卷积神经网络,它的设计关注如何高效地提取图像特征。我们通过以下代码来构建 ZFNet 模型。

from tensorflow.keras import layers, modelsdef create_zfnet_model(input_shape=(32, 32, 1), num_classes=10):"""创建 ZFNet 模型。参数:- input_shape: 输入图像的形状,默认 (32, 32, 1)。- num_classes: 类别数目,默认 10。返回:- 返回构建好的模型。"""model = models.Sequential()# 使用 Input 层显式定义输入形状model.add(layers.Input(shape=input_shape))  # 显式指定输入形状# 特征提取部分model.add(layers.Conv2D(64, (7, 7), activation='relu', strides=2, padding='same'))model.add(layers.MaxPooling2D(pool_size=(3, 3), strides=2, padding='same'))model.add(layers.Conv2D(128, (5, 5), activation='relu', padding='same'))model.add(layers.MaxPooling2D(pool_size=(3, 3), strides=2, padding='same'))model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))model.add(layers.Conv2D(512, (3, 3), activation='relu', padding='same'))# 扁平化层model.add(layers.Flatten())# 全连接层model.add(layers.Dense(1024, activation='relu'))model.add(layers.Dropout(0.5))# 输出层model.add(layers.Dense(num_classes, activation='softmax'))return model

解释:
  • 卷积层:通过多个卷积层提取图像的空间特征。ZFNet 采用不同大小的卷积核(如 7x7、5x5 和 3x3),通过优化的卷积结构捕捉更多层次的图像信息。
  • 池化层:最大池化层用于减少图像尺寸,并使特征保持重要信息。
  • 全连接层:通过扁平化和全连接层进一步处理特征,并输出分类结果。

2.3 编译与训练模型

        在训练之前,我们需要编译模型并选择优化器和损失函数。然后,调用 fit 函数开始训练。

def compile_model(model):"""编译模型:param model: 待编译的模型:return: 已编译的模型"""model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])return modeldef train_model(model, x_train, y_train, x_test, y_test, device, epochs=5, batch_size=128):"""在指定设备上训练模型:param model: 训练的模型:param x_train: 训练集图像:param y_train: 训练集标签:param x_test: 测试集图像:param y_test: 测试集标签:param device: 设备:param epochs: 训练轮数:param batch_size: 批处理大小"""with tf.device(device):model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test, y_test))

解释:
  • 优化器:我们使用 Adam 优化器,它具有自适应学习率,非常适合深度学习任务。
  • 损失函数categorical_crossentropy 用于多分类问题。
  • 训练:通过 model.fit() 函数训练模型,并在每个 epoch 后使用测试数据进行验证。

3. 测试部分:评估模型并进行预测

        一旦训练完成,我们将评估模型在测试集上的表现,并可视化其预测结果。

3.1 评估模型

def evaluate_model(model, x_test, y_test):"""评估模型在测试集上的表现:param model: 训练好的模型:param x_test: 测试集图像:param y_test: 测试集标签:return: 测试集上的损失和准确率"""test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)print(f"Test accuracy: {test_acc}")return test_loss, test_acc

解释:
  • 使用 evaluate() 方法评估模型的性能,返回模型的损失和准确率。

3.2 可视化预测结果

def visualize_predictions(model, x_test, y_test, num_images=6):"""可视化模型对多张测试图片的预测结果:param model: 训练好的模型:param x_test: 测试集图像:param y_test: 测试集标签:param num_images: 显示图像的数量"""predictions = model.predict(x_test[:num_images])predicted_labels = np.argmax(predictions, axis=1)actual_labels = np.argmax(y_test[:num_images], axis=1)# 绘制结果fig, axes = plt.subplots(2, 3, figsize=(10, 7))axes = axes.ravel()for i in range(num_images):ax = axes[i]# 将 Tensor 转换为 NumPy 数组,并使用 reshapeimg = x_test[i].numpy().reshape(32, 32)  # 这里调用 .numpy() 将 Tensor 转换为 NumPy 数组ax.imshow(img, cmap='gray')ax.set_title(f"Pred: {predicted_labels[i]} | Actual: {actual_labels[i]}")ax.axis('off')plt.tight_layout()plt.show()

解释:
  • 预测结果可视化:我们选择部分图像进行预测并显示模型的预测标签和真实标签,帮助分析模型的分类效果。

3.3 计算整体准确率

# 计算整体准确率accuracy = np.sum(predicted_labels == actual_labels) / len(actual_labels)print(f"Accuracy on the entire test set: {accuracy * 100:.2f}%")

解释:
  • 通过对比预测标签和实际标签,计算模型在测试集上的整体准确率。

4. 总结

        本文介绍了如何使用 TensorFlow 实现 ZFNet 网络,并在 MNIST 数据集上进行训练和测试。通过训练模型、评估性能、可视化预测结果,我们能够更好地理解 ZFNet 的优势和图像分类中的应用。

        希望这篇博客能帮助你掌握 ZFNet 的实现过程,理解其背后的原理,并能够顺利地应用到其他图像分类任务中!

        如有问题或进一步的疑问,请随时留言讨论!

完整项目:

https://github.com/qxd-ljy/ZFNet-TensorFlowicon-default.png?t=O83Ahttps://github.com/qxd-ljy/ZFNet-TensorFlowZFNet-TensorFlow: 使用 TensorFlow 实现 ZFNet 进行 MNIST 图像分类icon-default.png?t=O83Ahttps://gitee.com/qxdlll/zfnet-tensor-flow

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/473269.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何让手机ip变成动态

在数字化浪潮中,手机已成为我们日常生活中不可或缺的一部分。无论是浏览网页、使用社交媒体还是进行在线购物,手机都扮演着举足轻重的角色。然而,在享受网络带来的便利时,我们也需要关注网络安全和隐私保护。静态IP地址可能让手机…

前端三大组件之CSS,三大选择器,游戏网页仿写

回顾 full stack全栈 Web前端三大组件 结构(html) 样式(css) 动作/交互(js) --- 》 框架vue&#xff0c;安哥拉 div 常用的标签 扩展标签 列表 ul/ol order——有序号 unordered——没序号的黑点 <!DOCTYPE html> <html><head><meta charset"…

CPU执行指令的过程

通过前面两篇文章的介绍&#xff0c;我们已经认识到了&#xff1a;可执行程序通过作业调度装入内存&#xff0c;操作系统为进程创建虚拟地址空间&#xff0c;分配物理内存&#xff0c;建立页表&#xff08;映射关系&#xff09;&#xff0c;申请并初始化PCB&#xff0c;开始调度…

【MySQL】InnoDB内存结构

目录 InnoDB内存结构 主要组成 缓冲池 缓冲池的作用 缓冲池的结构 缓冲池中页与页之间连接方式分析 缓冲池如何组织数据 控制块初始化 页面初始化 缓冲池中页的管理 缓冲区淘汰策略 查看缓冲池信息 总结 变更缓冲区-Chang Buffer 变更缓冲区的作用 主要配置选项…

论文笔记 SuDORMRF:EFFICIENT NETWORKS FOR UNIVERSAL AUDIO SOURCE SEPARATION

SUDORMRF: EFFICIENT NETWORKS FOR UNIVERSAL AUDIO SOURCE SEPARATION 人的精神寄托可以是音乐&#xff0c;可以是书籍&#xff0c;可以是运动&#xff0c;可以是工作&#xff0c;可以是山川湖海&#xff0c;唯独不可以是人。 Depthwise Separable Convolution 深度分离卷积&a…

SpringBoot+React养老院管理系统 附带详细运行指导视频

文章目录 一、项目演示二、项目介绍三、运行截图四、主要代码1.入住合同文件上传2.添加和修改套餐的代码3.查看入住记录代码 一、项目演示 项目演示地址&#xff1a; 视频地址 二、项目介绍 项目描述&#xff1a;这是一个基于SpringBootReact框架开发的养老院管理系统。首先…

w039基于Web足球青训俱乐部管理后台系统开发

&#x1f64a;作者简介&#xff1a;多年一线开发工作经验&#xff0c;原创团队&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的网站项目。 代码可以查看文章末尾⬇️联系方式获取&#xff0c;记得注明来意哦~&#x1f339;赠送计算机毕业设计600个选题excel文…

Go 语言已立足主流,编程语言排行榜24 年 11 月

Go语言概述 Go语言&#xff0c;简称Golang&#xff0c;是由Google的Robert Griesemer、Rob Pike和Ken Thompson在2007年设计&#xff0c;并于2009年11月正式宣布推出的静态类型、编译型开源编程语言。Go语言以其提高编程效率、软件构建速度和运行时性能的设计目标&#xff0c;…

【java】链表:判断链表是否成环

问题&#xff1a; 分析&#xff1a; 这里我们还是定义快慢双指针 。 如果有环&#xff0c;快慢指针一定会相遇。 // 构建成环链表public void makeCircle(){Node node1new Node(1);Node node2new Node(2);Node node3new Node(5);Node node4new Node(6);Node node5new …

基于视觉智能的时间序列基础模型

GitHub链接&#xff1a;ViTime: A Visual Intelligence-Based Foundation Model for Time Series Forecasting 论文链接&#xff1a;https://github.com/IkeYang/ViTime 前言 作者是来自西安理工大学&#xff0c;西北工业大学&#xff0c;以色列理工大学以及香港城市大学的研…

006.精读《Apache Paimon Docs - Concepts》

文章目录 1. 引言2. 基本概念2.1 基本构成2.2 Schema2.3 Snapshot2.4 Manifest2.5 Data File2.6 Table2.7 File index 3.并发控制3.1 基本概念3.2 快照冲突3.3 文件冲突 4. 总结 1. 引言 在本期的技术深度解析中&#xff0c;我们将学习并且了解Apache Paimon 的基本概念&#…

RedHat7—Linux中kickstart自动安装脚本制作

本实验使用虚拟机版本为rhel7&#xff0c;从rhel7后的版本kickstart工具进行收费使用。 1.在VMware关闭dhcp自动获取ip地址功能 2.安装并启动httpd [rootlocalhost ~]# yum install httpd [rootlocalhost ~]# systemctl start httpd [rootlocalhost ~]#systemctl stop firewal…

数据集的重要性:如何构建AIGC训练集

文章目录 一、为什么数据集对AIGC如此重要&#xff1f;1. 数据决定模型的知识边界2. 数据质量直接影响生成效果3. 数据集多样性提升模型鲁棒性 二、构建AIGC训练集的关键步骤1. 明确目标任务和生成需求2. 数据源的选择3. 数据清洗与预处理4. 数据标注5. 数据增强 三、针对不同类…

结构化需求分析与设计

前言: 感觉书本上和线上课程, 讲的太抽象, 不好理解, 但软件开发不就是为了开发应用程序吗?! 干嘛搞这么抽象,对吧, 下面是个人对于软件开发的看法, 结合我的一些看法, 主打简单易懂, 当然,我一IT界小菜鸟, 对软件开发的认识也很浅显, 这个思维导图也仅仅是现阶段我的看…

docker-hub 无法访问,使用windows魔法拉取docker images再上传到linux docker环境中

云机的服务器是可以docker拉取镜像的&#xff0c;但是本地的虚拟机、物理服务器等网络环境不好的情况&#xff0c;是无法访问docker-hub的&#xff0c;即使更换了docker镜像源国内源也无法使用。 本文章使用 在魔法网络环境下的windows&#xff0c;下载docker images后&#xf…

LlamaIndex+本地部署InternLM实践

1.环境配置 1.1 配置基础环境 这里以在 Intern Studio 服务器上部署 LlamaIndex 为例。 首先&#xff0c;打开 Intern Studio 界面&#xff0c;点击 创建开发机 配置开发机系统 填写 开发机名称 后&#xff0c;点击 选择镜像 使用 Cuda11.7-conda 镜像&#xff0c;然后在资源…

MySql 日期周处理方式

MySql 日期周处理方式 最近在做数仓相关工作&#xff0c;最近遇到 几个问题&#xff0c; 1、计算指定日期是一年中的第几周&#xff0c;周一为周的第一天 2、计算周的开始时间&#xff0c;结束时间 3、计算周对应的年 比如 2023-01-01 WEEKOFYEAR(2023-01-01) 是2022年的52周&…

AI驱动的桌面笔记应用Reor

网友 竹林风 说&#xff0c;已经成功的用 mxbai-embed-large 映射到 text-embedding-ada-002&#xff0c;并测试成功了。不愧是爱折腾的人&#xff0c;老苏还没时间试&#xff0c;因为又找到了另一个支持 AI 的桌面版笔记 Reor Reor 简介 什么是 Reor ? Reor 是一款由人工智…

每日一博 - Java的Shallow Copy和Deep Copy

文章目录 概述创建对象的5种方式1. 通过new关键字2. 通过Class类的newInstance()方法3. 通过Constructor类的newInstance方法4. 利用Clone方法5. 反序列化 Clone方法基本类型和引用类型浅拷贝深拷贝如何实现深拷贝1. 让每个引用类型属性内部都重写clone()方法2. 利用序列化 概述…

Rewar Model的输出(不包含训练)

这里写自定义目录标题 介绍模型推理的输出过程方案原始Token输出RM输出&#xff08;回归任务&#xff09; 介绍 奖励函数模型 (Reward Model) 是人工智能 (AI) 中的一种方法&#xff0c;模型因其对给定提示的响应而获得奖励或分数。现在的文章清一色的讲解RM的训练&#xff0c…