TensorFlow案例学习:对服装图像进行分类

前言

官方为我们提供了一个 对服装图像进行分类 的案例,方便我们快速学习

学习

预处理数据

案例中有下面这段代码

# 预处理数据,检查训练集中的第一个图像可以看到像素值处于0~255之间
plt.figure() # 创建图像窗口
plt.imshow(train_images[0]) # 显示图片
plt.colorbar()  # 在图像旁边添加颜色条
plt.grid(False) # 取消网格线
plt.show() # 显示图形窗口# 将值缩小至0~1之间,然后将其反馈到神经网络模型。训练集和测试集都需要处理
train_images = train_images / 255.0
test_images = test_images / 255.0

在这里插入图片描述

百度查了一下,将值缩小至0~1之间是为了

将训练集和测试集数据的值缩小到0~1之间是为了进行数据归一化(Normalization)。这是一个常见的预处理步骤,对于图像分类任务特别重要。
将图像的像素值缩放到0~1之间有几个好处:

  • 数值范围一致性:将所有像素值限制在0~1范围内可以确保不同样本的特征具有一致的数值区间。这有助于避免某些特征对模型训练产生过大的影响。
  • 梯度下降稳定性:在深度学习中,常用的优化算法如梯度下降依赖于权重的更新和损失函数的梯度计算。将像素值缩小到较小的范围可以使这些计算更加稳定,有助于加速模型的收敛。
  • 避免数值溢出:在一些激活函数和优化算法中,如果输入值太大,可能会导致数值溢出或不稳定的情况。将像素值限制在0~1之间可以减少这种情况的发生。

以后再遇见处理255时就明白这样做的目的了

构建模型

构建神经网络需要先配置模型的层,然后再编译模型。

设置层
神经网络的基本组成部分是层。层会从向其馈送的数据中提取表示形式。希望这些表示形式有助于解决手头上的问题。

大多数深度学习都包括将简单的层链接在一起。大多数层(如 tf.keras.layers.Dense)都具有在训练期间才会学习的参数

# 1、设置层
# tf.keras是TensorFlow中的高级API,用于构建和训练神经网络模型。它是一个基于Keras库的接口,提供了更简单、更高级的方式来定义、配置和训练神经网络模型。
# tf.keras.Sequential 用于按顺序堆叠各个神经网络层来构建模型,是一种简单的模型类型
model = tf.keras.Sequential([# 将图像格式从二维数组(28*28像素),转化为一维数组(28*28 = 784像素)。将该层视为图像中未堆叠的像素行并将其排列起来。该层没有要学习的参数,它只会重新格式化数据。tf.keras.layers.Flatten(input_shape=(28,28)), # 第二层,是一个具有128个神经元的全连接神经层tf.keras.layers.Dense(128,activation='relu'),# 第三层会返回一个长度为10的数组,每个都包含一个得分来表示当前图像属于10个类中的哪一个tf.keras.layers.Dense(10)
])

这段代码我相信很多人跟我一样都有些疑问,还好现在有gpt,不然都不知道上哪里去找答案。下面是我的一些疑问及gpt的回答:

  • 为什么只有三层。答:在神经网络中,层数的选择是一个灵活的设计选择,取决于特定问题的复杂性和数据集的特征。选择三层可能是为了简化模型或者问题本身不需要更多层
  • 第二层为什么是tf.keras.layers.Dense(128)。答:选择128个神经元是基于对问题复杂性的估计和经验。如果问题比较复杂或数据集较大,增加神经元数量可以增加模型的容量,提高模型的表示能力。
  • 第三层为什么是tf.keras.layers.Dense(10)。答:因为这是一个分类问题,这个案例中有10个分类。每个神经元对应一个类别,并输出相应类别的预测概率。
  • tf.keras.layers.Dense(128)是计算的来的吗。答:通常需要根据实际问题和数据集来进行调整。增加神经元的数量可以增加模型的容量和学习能力,但也可能导致过拟合。过拟合是指模型在训练数据上表现良好,但在新数据上表现较差。建议先从较小的数量开始,然后逐渐增加,直到模型的性能不再提高或开始出现过拟合为止。
  • 模型的最后一层是输出层吗。答:模型的最后一层通常是输出层。输出层的神经元数量通常与你要解决的问题相关。对于分类任务,输出层的神经元数量应该等于类别的数量。对于二分类任务,可以使用一个神经元来表示两个类别的概率。对于多分类任务,可以使用多个神经元,每个神经元表示一个类别的概率。在使用tf.keras``构建模型时,你可以使用tf.keras.layers.Dense`来定义输出层,并使用适当的激活函数来产生输出。

编译模型

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数 - 测量模型在训练期间的准确程度。你希望最小化此函数,以便将模型“引导”到正确的方向上。
  • 优化器 - 决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标 - 用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
# 2、编译模型
model.compile(optimizer='adam', # 指定优化器,adam是常用的优化器,可以自适应的调整学习率loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), # 指定损失函数,这里使用了稀疏分类交叉熵损失函数metrics=['accuracy'] # 指定评估模型性能的指标,这里使用准确率
)

训练模型

训练神经网络模型需要执行以下步骤:

  • 将训练数据馈送给模型。在本例中,训练数据位于 train_images 和 train_labels 数组中。
  • 模型学习将图像和标签关联起来。
  • 要求模型对测试集(在本例中为 test_images 数组)进行预测。
  • 验证预测是否与 test_labels 数组中的标签相匹配。
# 1、将训练数据反馈给模型
# model.fit用于将模型与训练数据进行拟合,这里是将所有样本迭代10次
model.fit(train_images,train_labels,epochs=10)

如下图:
在这里插入图片描述

# 2、在测试数据集上评估准确率,verbose=2参数表示以详细模式输出评估过程
test_loss,test_acc = model.evaluate(test_images,test_labels,verbose=2)
print("损失率:",test_loss,"准确率:",test_acc)

如下图:
在这里插入图片描述

进行预测

# 进行预测
# 模型经过训练后,您可以使用它对一些图像进行预测。附加一个 Softmax 层,将模型的线性输出 logits 转换成更容易理解的概率
probability_model = tf.keras.Sequential([model,tf.keras.layers.Softmax()])
# 预测图片
predictions = probability_model.predict(test_images)print("第一个预测结果:",predictions[0])

预测结果是一个包含 10 个数字的数组。它们代表模型对 10 种不同服装中每种服装的“置信度”。您可以看到哪个标签的置信度值最大:

np.argmax(predictions[0])

使用训练好的模型

现在模型已经训练好了,我们可以基于模型对单个图像进行预测

# 使用训练好的模型
# 加载图片
img = Image.open('pics/shirt.png') 
# 调整大小
img = img.resize((28,28))
# 将彩色图片转为灰度图片
img_gray = img.convert('L')
# 将图像转换为 NumPy 数组,并反转颜色
img_arr = np.array(img_gray)
img_arr = 255 - img_arr
# 将图像像素值归一化到0~1
img_arr = img_arr / 255.0
# 将图像形状调整为(128288)
img_arr = img_arr.reshape(1,28,28)
# 可以保存处理后的文件,也可以进行预测
# np.save('abc.npy',img_arr)
# tf.keras 模型经过了优化,可同时对一个批或一组样本进行预测。因此,即便您只使用一个图像,您也需要将其添加到列表中
#img_arr = tf.keras.preprocessing.image.img_to_array(img)res = probability_model.predict(img_arr)
print("预测结果是:",res,class_names[np.argmax(res[0])])# 可视化显示
font = FontProperties()
font.set_family('Microsoft YaHei')
plt.figure() # 创建图像窗口
plt.xticks([])
plt.yticks([])
plt.grid(False) # 取消网格线
plt.imshow(img_arr[0]) # 显示图片
plt.xlabel(class_names[np.argmax(res[0])],fontproperties=font)
plt.show() # 显示图形窗口

这块是最复杂的,搞了好久才成功。你加载的图片是彩色的,你必须将图片变成灰度的,并且是28*28像素的图片,也就是你的图片要处理成符合这个模型的图片才行。

但是最终结果其实也不是很准确,根本原因是你的图片处理后,能够获取的特征就很少了,这样会导致判断错误。

结果
在这里插入图片描述

遇到的问题

问题1
在执行(train_images, train_labels), (test_images,test_labels) = fashion_mnist.load_data()时提示

Exception: URL fetch failure on https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz: None – [WinError 10054] 远程主机强迫关闭了 一个现有的连接。

这是加载数据集时失败了,国内访问下载谷歌的数据总会出现这样的问题。

解决:
1、打开数据集官方网站 https://github.com/zalandoresearch/fashion-mnist,将下面这4个数据下载到本地放到项目里

在这里插入图片描述
2、加载本地数据

import gzip
import numpy as npdef load_data():# 加载训练集图像数据with gzip.open('train-images-idx3-ubyte.gz', 'rb') as f:train_images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)# 加载训练集标签数据with gzip.open('train-labels-idx1-ubyte.gz', 'rb') as f:train_labels = np.frombuffer(f.read(), np.uint8, offset=8)# 加载测试集图像数据with gzip.open('t10k-images-idx3-ubyte.gz', 'rb') as f:test_images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)# 加载测试集标签数据with gzip.open('t10k-labels-idx1-ubyte.gz', 'rb') as f:test_labels = np.frombuffer(f.read(), np.uint8, offset=8)return (train_images, train_labels), (test_images, test_labels)# 调用加载数据函数
(train_images, train_labels), (test_images, test_labels) = load_data()

问题2
验证前25个图像,设置中文乱码。教程中的使用的是英文,我这里尝试了一下中文,中文乱码
在这里插入图片描述
解决:设置中文字体

# 字体属性
from matplotlib.font_manager import FontProperties# 验证训练集中的前25个图像,并显示其名称
font = FontProperties()
font.set_family('Microsoft YaHei')
plt.figure(figsize=(10,10))
for i in range(25):plt.subplot(5,5,i+1) # 按照 5*5进行显示plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(class_names[train_labels[i]],fontproperties=font)
plt.show()

在这里插入图片描述

完整代码

# 导入 TensorFlow 重命名
import tensorflow as tf# numpy是科学计算库,matplotlib是用于绘制图表和可视化数据的库
import numpy as np
import matplotlib.pylab as plt
# 字体属性
from matplotlib.font_manager import FontProperties# 用于加载文件
import gzip# 用于处理图片
from PIL import Image# 用于加载数据集的函数
def load_data():# 加载训练集图像数据with gzip.open('train-images-idx3-ubyte.gz', 'rb') as f:train_images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)# 加载训练集标签数据with gzip.open('train-labels-idx1-ubyte.gz', 'rb') as f:train_labels = np.frombuffer(f.read(), np.uint8, offset=8)# 加载测试集图像数据with gzip.open('t10k-images-idx3-ubyte.gz', 'rb') as f:test_images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)# 加载测试集标签数据with gzip.open('t10k-labels-idx1-ubyte.gz', 'rb') as f:test_labels = np.frombuffer(f.read(), np.uint8, offset=8)return (train_images, train_labels), (test_images, test_labels)print("tf版本:",tf.__version__)# 导入数据集,TensorFlow 内置的数据集
fashion_mnist = tf.keras.datasets.fashion_mnist
# 将训练数据、测试数据取出,保存的元组里
(train_images, train_labels), (test_images,test_labels) = load_data()# 映射标签类,用于后面绘制图像使用
class_names = ['T恤/上衣', '裤子', '套头衫', '连衣裙', '外套', '凉鞋', '衬衫', '运动鞋', '包', '短靴']# 会打印出(60000, 28, 28),官方文档解释为训练集中有 60,000 个图像,每个图像由 28 x 28 的像素表示
print("训练数据集数据:",train_images.shape)# 预处理数据,检查训练集中的第一个图像可以看到像素值处于0~255之间
# plt.figure() # 创建图像窗口
# plt.imshow(train_images[0]) # 显示图片
# plt.colorbar()  # 在图像旁边添加颜色条
# plt.grid(False) # 取消网格线
# plt.show() # 显示图形窗口# 将值缩小至0~1之间,然后将其反馈到神经网络模型。训练集和测试集都需要处理
train_images = train_images / 255.0
test_images = test_images / 255.0# 验证训练集中的前25个图像,并显示其名称
# font = FontProperties()
# font.set_family('Microsoft YaHei')
# plt.figure(figsize=(10,10))
# for i in range(25):
#     plt.subplot(5,5,i+1) # 按照 5*5进行显示
#     plt.xticks([])
#     plt.yticks([])
#     plt.grid(False)
#     plt.imshow(train_images[i], cmap=plt.cm.binary)
#     plt.xlabel(class_names[train_labels[i]],fontproperties=font)
# plt.show()# 构建模型# 1、设置层
# tf.keras是TensorFlow中的高级API,用于构建和训练神经网络模型。它是一个基于Keras库的接口,提供了更简单、更高级的方式来定义、配置和训练神经网络模型。
# tf.keras.Sequential 用于按顺序堆叠各个神经网络层来构建模型,是一种简单的模型类型
model = tf.keras.Sequential([# 将图像格式从二维数组(28*28像素),转化为一维数组(28*28 = 784像素)。将该层视为图像中未堆叠的像素行并将其排列起来。该层没有要学习的参数,它只会重新格式化数据。tf.keras.layers.Flatten(input_shape=(28,28)), # 第二层,是一个具有128个神经元的全连接神经层tf.keras.layers.Dense(128,activation='relu'),# 第三层会返回一个长度为10的数组,每个都包含一个得分来表示当前图像属于10个类中的哪一个tf.keras.layers.Dense(10)
])# 2、编译模型
model.compile(optimizer='adam', # 指定优化器,adam是常用的优化器,可以自适应的调整学习率loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), # 指定损失函数,这里使用了稀疏分类交叉熵损失函数metrics=['accuracy'] # 指定评估模型性能的指标,这里使用准确率
)# 训练模型# 1、将训练数据反馈给模型
# model.fit用于将模型与训练数据进行拟合,这里是将所有样本迭代10次
model.fit(train_images,train_labels,epochs=10)# 2、在测试数据集上评估准确率,verbose=2参数表示以详细模式输出评估过程
test_loss,test_acc = model.evaluate(test_images,test_labels,verbose=2)
print("损失率:",test_loss,"准确率:",test_acc)# 进行预测
# 模型经过训练后,您可以使用它对一些图像进行预测。附加一个 Softmax 层,将模型的线性输出 logits 转换成更容易理解的概率
probability_model = tf.keras.Sequential([model,tf.keras.layers.Softmax()])
# 预测图片
predictions = probability_model.predict(test_images)print("第一个预测结果:",predictions[0],'类别是:',class_names[np.argmax(predictions[0])])# 使用训练好的模型
# 加载图片
img = Image.open('pics/shirt.png') 
# 调整大小
img = img.resize((28,28))
# 将彩色图片转为灰度图片
img_gray = img.convert('L')
# 将图像转换为 NumPy 数组,并反转颜色
img_arr = np.array(img_gray)
img_arr = 255 - img_arr
# 将图像像素值归一化到0~1
img_arr = img_arr / 255.0
# 将图像形状调整为(128288)
img_arr = img_arr.reshape(1,28,28)
# 可以保存处理后的文件,也可以进行预测
# np.save('abc.npy',img_arr)res = probability_model.predict(img_arr)
print("预测结果是:",res,class_names[np.argmax(res[0])])# 可视化显示
font = FontProperties()
font.set_family('Microsoft YaHei')
plt.figure() # 创建图像窗口
plt.xticks([])
plt.yticks([])
plt.grid(False) # 取消网格线
plt.imshow(img_arr[0]) # 显示图片
plt.xlabel(class_names[np.argmax(res[0])],fontproperties=font)
plt.show() # 显示图形窗口

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/152403.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【RabbitMQ】初识消息队列 MQ,基于 Docker 部署 RabbitMQ,探索 RabbitMQ 基本使用,了解常见的消息类型

文章目录 前言一、初识消息队列 MQ1.1 同步通信1.2 异步通信1.3 MQ 常见框架及其对比 二、初识 RabbitMQ2.1 什么是 RabbitMQ2.2 RabbitMQ 的结构 三、基于 Docker 部署 RabbitMQ四、常见的消息类型五、示例:在 Java 代码中通过 RabbitMQ 发送消息5.1 消息发布者5.2…

软件测试「转行」答疑(未完更新中)

⭐ 专栏简介 软件测试行业「转行」答疑: 如果你对于互联网的职业了解一知半解!不知道行业的前景如何?对于众说纷纭的引流博主说法不知所措!不确定这个行业到底适不适合自己? 那么这一篇文章可以告诉你所有真实答案&a…

【网络安全】关于CTF那些事儿你都知道吗?

关于CTF那些事儿你都知道吗? 前言CTF那些事儿内容简介读者对象专家推荐 本文福利 前言 CTF比赛是快速提升网络安全实战技能的重要途径,已成为各个行业选拔网络安全人才的通用方法。但是,本书作者在从事CTF培训的过程中,发现存在几…

<el-input> textarea文本域显示滚动条(超过高度就自动显示)+ <el-input >不能正常输入,输入了也不能删除的问题

需求&#xff1a;首先是给定高度&#xff0c;输入文本框要自适应这个高度。文本超出高度就会显示滚动条否则不显示。 <el-row class"textarea-row"><el-col :span"3" class"first-row-title">天气</el-col><el-col :span&…

Selenium进行无界面爬虫开发

在网络爬虫开发中&#xff0c;利用Selenium进行无界面浏览器自动化是一种常见且强大的技术。无界面浏览器可以模拟真实用户的行为&#xff0c;解决动态加载页面和JavaScript渲染的问题&#xff0c;给爬虫带来了更大的便利。本文将为您介绍如何利用Selenium进行无界面浏览器自动…

如何绘制Top级美图?20+案例分享

如何绘制Top级美图&#xff1f;20案例分享 #R语言绘图128个 #图表美化47个 工欲善其事&#xff0c;必先利其器&#xff01; R语言绘图爱好者赶紧看过来&#xff01;画图时选择称手的R包&#xff0c;是高效绘制美图的First Step&#xff01;今天分享一波科研美图绘制所需R包…

TensorFlow入门(九、张量及操作函数介绍)

在TensorFlow程序中,所有的数据都由tensor数据结构来代表。即使在计算图中,操作间传递的数据也是Tensor tensor在TensorFlow中并不是直接采用数组的形式,它只是对TensorFlow中计算结果的引用。也就是说在张量中并没有真正保存数字,它保存的是如何得到这些数字的计算过程 一个…

WebSocket ----苍穹外卖day8

介绍 实现步骤 各个模块详解 OnOpen OnOpen:标记一个方法作为处理WebSocket连接打开的方法 当一个客户端与服务器建立 WebSocket 连接时&#xff0c;服务器会接收到一个连接请求。一旦服务器接受了这个连接请求&#xff0c;一个 WebSocket 连接就会被建立。这时&#xff0c;被…

Git仓库迁移记录

背景&#xff1a;gitlab私服上面&#xff0c;使用 import project的方式&#xff0c;从旧项目迁移到新地址仓库&#xff0c;但是代码一直没拉过去。所以使用命令的方式&#xff0c;进行代码迁移。 第一步&#xff1a;使用git clone --mirror git地址&#xff0c;进行代码克隆 …

如何让 Llama2、通义千问开源大语言模型快速跑在函数计算上?

作者&#xff1a;寒斜 阿里云智能技术专家 「本文是“在 Serverless 平台上构建 AIGC 应用”系列文章的第一篇文章。」 前言 随着 ChatGPT 以及 Stable Diffusion&#xff0c;Midjourney 这些新生代 AIGC 应用的兴起&#xff0c;围绕 AIGC 应用的相关开发变得越来越广泛&…

【一周安全资讯1007】多项信息安全国家标准10月1日起实施;GitLab发布紧急安全补丁修复高危漏洞

要闻速览 1.以下信息安全国家标准10月1日起实施 2.GitLab发布紧急安全补丁修复高危漏洞 3.主流显卡全中招&#xff01;GPU.zip侧信道攻击可泄漏敏感数据 4.MOVEit漏洞导致美国900所院校学生信息发生大规模泄露 5.法国太空和国防供应商Exail遭黑客攻击&#xff0c;泄露大量敏感…

三模块七电平级联H桥整流器电压平衡控制策略Simulink仿真

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

竞赛选题 深度学习 YOLO 实现车牌识别算法

文章目录 0 前言1 课题介绍2 算法简介2.1网络架构 3 数据准备4 模型训练5 实现效果5.1 图片识别效果5.2视频识别效果 6 部分关键代码7 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 基于yolov5的深度学习车牌识别系统实现 该项目较…

Linux虚拟机克隆之后使用ip addr无法获取ip地址

Linux虚拟机克隆之后使用ip addr无法获取ip地址 因为克隆得到的虚拟机&#xff0c;与原先的linux系统是一模一样的包括MAC地址和IP地址。需要修改信息。 设置IP地址&#xff1a; 使用vi命令打开linux的网卡 //ifcfg-enth0是虚拟网卡的名称&#xff0c;如果你的不叫这个名字&a…

[数据结构]迷宫问题求解

目录 数据结构——迷宫问题求解&#xff1a;&#xff1a; 1.迷宫问题 2.迷宫最短路径问题 数据结构——迷宫问题求解&#xff1a;&#xff1a; 1.迷宫问题 #include <stdio.h> #include <string.h> #include <stdlib.h> #include <assert.h> #includ…

拼多多API接口的使用方针如下:

了解拼多多API接口 拼多多API接口是拼多多网提供的一种应用程序接口&#xff0c;允许开发者通过程序访问拼多多网站的数据和功能。通过拼多多API接口&#xff0c;开发者可以开发各种应用程序&#xff0c;如店铺管理工具、数据分析工具、购物比价工具等。在本章中&#xff0c;我…

1.6 IntelliJ IDEA开发工具

前言&#xff1a; ### 1.6 IntelliJ IDEA开发工具笔记 - **背景**&#xff1a; - 使用基础文本编辑器如记事本编写Java代码虽然可行&#xff0c;但存在效率低下且难以调试的问题。 - 集成开发环境 (IDE) 可以有效地提高Java程序的开发效率。 - **常见Java IDE**&#xf…

基于springboot实现自习室预订系统的设计与实现项目【项目源码+论文说明】

基于springboot实现自习室预订系统的设计与实现演示 摘要 在网络高速发展的时代&#xff0c;众多的软件被开发出来&#xff0c;给学生带来了很大的选择余地&#xff0c;而且人们越来越追求更个性的需求。在这种时代背景下&#xff0c;学院只能以学生为导向&#xff0c;所以自习…

C# 通过winmm枚举音频设备

文章目录 前言一、如何实现&#xff1f;1、添加依赖&#xff08;1&#xff09;、nuget安装winmm的封装库&#xff08;2&#xff09;、补充接口2、定义实体3、实现枚举 二、完整代码三、使用示例总结 前言 使用C#做音频录制时需要获取音频设备信息&#xff0c;比如使用ffmpeg进…