TensorFlow学习:使用官方模型和自己的训练数据进行图片分类

前言

教程来源:清华大佬重讲机器视觉!TensorFlow+Opencv:深度学习机器视觉图像处理实战教程,物体检测/缺陷检测/图像识别

注:

这个教程与官网教程有些区别,教程里的api比较旧,核心思想是没有变化的。

上一篇文章 TensorFlow学习:使用官方模型进行图像分类、使用自己的数据对模型进行微调是基于官方案例来实现的分类,这次是从另一个角度来实现的分类。

基础知识

这部分基础知识之前没学过,这次正好根据视频教程简单学习一下。

Keras

简介
Keras是一个开源的深度学习框架,它是建立在Python之上的高级神经网络API。它提供了一个简单、直观的接口,使得构建、训练和部署深度学习模型变得更加容易。

TensorFlow 1.9 之后与Keras 进行了集成。在TensorFlow 中可以使用其API。

Keras相关模块

  • applications :Kears应用程序是具有预训练权重的固定架构
  • callback :在训练模型期间在某些点调用的实用程序
  • datasets :Keras 内置数据集
  • initializers :Keras初始化器,用于设置神经网络模型的权重和偏差的初始值。权重和偏差的初始值对模型的训练和收敛速度有很大的影响。
  • layers :Keras层API,layers模块提供了各种类型的层,用于搭建不同类型的神经网络架构。比如:Dense(全连接层)、Conv2D(卷积层)
  • losses:用于定义损失函数。损失函数是用来衡量模型的预测结果与真实标签之间的差异的指标。
  • metrics :用于定义评估指标,用于衡量模型的性能。比如根据准确率(accuracy)来评估模型性能
  • model :模型
  • optimizers :内置优化器
  • preprocessing:数据预处理工具
  • regularizers : 内置正规化器
  • utils :内置的一些工具类

构建神经网络模型

下面的代码是官方案例:https://tensorflow.google.cn/overview?hl=zh-cn

建议看一下视频教程里的神经网络介绍,会有一个更好的理解。

# 第一步,加载数据集、并进行归一化
mnist = tf.keras.datasets.mnist(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0# 第二步,构建神经网络模型
model = tf.keras.models.Sequential([# 将输入的图像数据展平为一维数组tf.keras.layers.Flatten(input_shape=(28, 28)),# 创建一个有128个神经元和ReLU激活函数的全连接层,用于提取图像特征tf.keras.layers.Dense(128, activation='relu'),# 使用Dropout层,以防止过拟合tf.keras.layers.Dropout(0.2),# 最后一层是具有10个神经元和softmax激活函数的全连接层,用于输出分类的概率分布。10 是因为有10中分类类别tf.keras.layers.Dense(10, activation='softmax')
])
# 第三步,配置模型的优化器、损失函数和评估指标。
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 第四步,训练模型,训练5轮,在每一轮训练时会将所有数据进行分组,每一个组里有128张图片,批次最好是 2的次方,符合计算机2进制运算
model.fit(x_train, y_train, epochs=5, batch_size=128)
# 第五步,对模型进行测试,输出损失值、准确率
model.evaluate(x_test, y_test)

在这里插入图片描述

为什么使用relu激活函数
在构建神经网络模型时,选择激活函数通常是基于以下几个因素:

  • 非线性性质:激活函数的非线性性质是神经网络能够学习和表示复杂函数关系的关键。因为多个线性层的组合仍然是线性的,所以我们需要使用非线性函数来引入非线性变换 。常见的非线性激活函数包括ReLU(Rectified Linear Unit)、Sigmoid、Tanh等。

  • 梯度消失和梯度爆炸问题:在深层神经网络中,梯度的传播可能会出现梯度消失或梯度爆炸的问题。梯度消失指的是在反向传播过程中,梯度逐渐减小到接近零,导致底层的权重更新非常缓慢。梯度爆炸指的是梯度逐渐增大,导致底层的权重更新非常迅速。合适的激活函数可以缓解这些问题。例如,ReLU激活函数能够有效地抑制梯度消失和梯度爆炸。

  • 计算效率:激活函数的计算效率也是选择的一个因素。某些激活函数的计算比较简单,能够加速模型的训练和推理过程。

根据具体的任务和网络结构,选择合适的激活函数是一个实验性过程。在实践中,ReLU是最常用的激活函数,但也可以根据需求尝试其他的激活函数来提升模型性能。

为什么使用softmax激活函数

在构建分类模型时,常常使用softmax函数作为最后一层的激活函数。softmax函数将神经网络的输出转化为概率分布,用于多类别分类任务。

softmax函数将输入的向量转化为一个概率分布向量,其中每个元素表示对应类别的概率。具体地,对于输出层的每个神经元的输出值,softmax函数将其转化为一个在0到1之间的实数,且所有元素的和为1。这样做的好处是可以直接解释模型的输出结果,可以理解为每个类别的置信度或概率。

卷积神经网络

原理见:https://www.bilibili.com/video/BV1ee411K7WU?p=36&vd_source=fd72ff60b43cc949b3316d103871c31c

基本结构
卷积神经网络一般用于解决图片方面的问题。卷积神经网络主要有一下几个结构:

  • 卷积层:提取输入的不同特征
  • 池化层:减少图片的特征数量,避免全连接层参数过多
  • 全连接层:全连接层通常紧跟在卷积层和池化层之后,它将卷积层和池化层的输出进行扁平化,然后将其连接到一个或多个全连接层,最终输出预测结果。

卷积神经网络API

  • Conv2D:实现卷积
  • MaxPool2D:池化操作

例如:

# 设置卷积核为32,卷积核大小为5*5,卷积核步长为1,采用same填充方式,通道数放在最后,使用relu激活函数
tf.keras.layers.Conv2D(32, kernel_size=5, strides=1, padding='same',data_format='channels_last', activation='relu')
# 设置池化窗口为2*2,池化操作步长为2,采用same填充方式
tf.keras.layers.MaxPool2D(pool_size=2,strides=2,padding='same')

在卷积层中,在图像分类任务中,常见的kernel_size取值为3或5,而在物体检测任务中,通常会选择更大的kernel_size。通常建议使用奇数大小的kernel_size,可以保证中心对齐、避免边缘问题等

卷积层中,卷积核的数量是一个重要的超参数,会影响模型的性能和效果。通常情况下,卷积层中的卷积核数量会逐渐增加。一种常见的做法是从较少的卷积核数量开始,逐渐增加卷积核的数量,直到达到满足性能要求的水平。

在池化层中,pool_size参数表示池化窗口的大小。常见的pool_size取值包括2x2、3x3和4x4等

图片介绍

组成特征
组成一张图片的的特征值是所有的像素值,有三个维度:图片长度、图片宽度、图片通道数。

描述一个像素点,如果是灰度图,那么只需要一个数值来描述它,就是单通道。如果一个像素点,有RGB三种颜色来描述它,那就是三通道

  • 灰度图:单通道
  • 彩色图片:三通道

在TensorFlow中图片会用张量来表示

  • 单张图片:(高、宽、通道数)
  • 多张图片:(一个批次的图片数量,高、宽、通道数)

图片读取处理

读取图片

import tensorflow as tf # 加载图片,并加图片大小设置为224 * 224
image = tf.keras.preprocessing.image.load_img('./images/flower.jpg',target_size=(224,224))print("图片:",image)

不同的模型对输入的图片大小有不同的要求,需要调整图片大小使其符合模型的输入。
在这里插入图片描述
将图片转换为数组格式
读取的图片不能直接使用,需要将其转换成数组格式(张量)

# 转换成数组
img_arr = tf.keras.preprocessing.image.img_to_array(image)
print("图片形状:", img_arr)

在这里插入图片描述
有些模型还会对数组进行归一化,img_arr = img_arr / 255.0 。除以255是因为三原色值是0~255 。

注: img_to_array 有第二个参数为格式化方式,值是channels_first 或者 channels_last。即图片的通道数是在前面还是后面,不同框架可能会有不同的要求,TensorFlow默认为通道数在后。

图片形状
模型对图片的输入一般是三维或者四维的,可以进行查看或修改,以保证符合模型的要求

# 加载图片,并加图片大小设置为224 * 224
image = tf.keras.preprocessing.image.load_img('./images/flower.jpg', target_size=(224, 224))print("图片:", image)# 转换成数组
img_arr = tf.keras.preprocessing.image.img_to_array(image)print("图片形状:", img_arr.shape) # 三维 (224, 224, 3)# 有些模型需要四维模型,可以进行转换
new_img = img_arr.reshape(1,img_arr.shape[0],img_arr.shape[1],img_arr.shape[2])
print("四维:", new_img.shape)  # (1, 224, 224, 3)

在这里插入图片描述

图片分类

这里只简单介绍一下基于mobilenet_v2来进行迁移学习。在TensorFlow学习:使用官方模型进行图像分类、使用自己的数据对模型进行微调 中介绍过一种方式,文章中的方式是来自于官方文档。

这里的方式是来源于视频教程:模型定义

训练模型

import tensorflow as tf
# matplotlib是用于绘制图表和可视化数据的库
import matplotlib.pylab as plt
import datetime# 加载内置的模型,include_top=False不使用默认的分类
base_model = tf.keras.applications.mobilenet_v2.MobileNetV2(include_top=False)# 冻结模型训练数据,冻结模型结构是为了保持预训练模型的权重不受训练的影响
# 训练数据少时只需要训练全连接层即可
for layer in base_model.layers:layer.trainable = False# 初始化类,并归一化
train_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255.0)
test_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255.0)
# 读取训练集
train = train_generator.flow_from_directory(directory='data/train',  # 文件目录target_size=(224, 224),  # 处理图片大小,(h,w)batch_size=32,  # 批次数量class_mode='categorical' # 设置类别模式为,根据文件夹确定类别
)
# 读取验证集
test = test_generator.flow_from_directory(directory='data/validation', # 文件目录target_size=(224, 224),  # 处理图片大小,(h,w)batch_size=32,  # 批次数量class_mode='categorical' # 设置类别模式为,根据文件夹确定类别
)#print(train, test)
print(base_model.summary())
#print("输入:",base_model)# 微调模型
x = base_model.outputs[0]   # 移除分类后的模型输出
#print('x:', x)
# 输出到全连接层,加上全局池化
x = tf.keras.layers.GlobalAveragePooling2D()(x)
# 添加一个有1024个神经元使用relu激活函数的全连接层
x = tf.keras.layers.Dense(1024, activation='relu')(x)
y_predict = tf.keras.layers.Dense(2, activation='softmax')(x)  # 全连接层,这里两个神经元是因为只有图片只有两类# 新模型
new_model = tf.keras.models.Model(inputs=base_model.inputs, outputs=y_predict)
print("新模型:",new_model)# 编译模型
new_model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])# 记录训练日志
log_dir = "logs/fit/" + datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
# 用于在训练过程中收集模型指标和摘要数据,并将其写入TensorBoard日志文件中
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir= log_dir,histogram_freq=1
)
history = new_model.fit_generator(train,epochs=10,validation_data=test,callbacks=[tensorboard_callback])# 导出模型
export_path = 'tmp/cat_dog_model'
new_model.save(export_path)

这种方式需要按照固定的目录结果,如下
在这里插入图片描述
导出的模型
在这里插入图片描述
使用训练好的模型,进行预测

from matplotlib.font_manager import FontProperties
import tensorflow as tf
# matplotlib是用于绘制图表和可视化数据的库
import matplotlib.pylab as plt
import numpy as np#1、加载本地图片,并将其处理为224*224
image = tf.keras.preprocessing.image.load_img('./images/cat.png',target_size=(224,224))
# 2、转成数组
image = tf.keras.preprocessing.image.img_to_array(image)
print("图片形状:",image.shape)
# 3、扩展维度
image = image.reshape(1,image.shape[0],image.shape[1],image.shape[2])
# 4、处理输入,因为我们是基于mobilenet_v2训练的,因此可以使用mobilenet_v2处理图片
image = tf.keras.applications.mobilenet_v2.preprocess_input(image)
# 5、加载模型
model = tf.keras.models.load_model('./tmp/cat_dog_model')
# 6、预测
predictions = model.predict(image)
index  = np.argmax(predictions,axis=1)[0]
label = ['猫','狗'][index]
print("预测结果:",predictions,index,label)
#7、可视化显示
font = FontProperties()
font.set_family('Microsoft YaHei')
plt.figure() # 创建图像窗口
plt.xticks([])
plt.yticks([])
plt.grid(False) # 取消网格线
plt.imshow(image[0]) # 显示图片
plt.xlabel(label[0],fontproperties=font)
plt.show() # 显示图形窗口

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/172351.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

GoLong的学习之路(八)语法之Map

文章目录 Map初始化方式判断某个键是否存在map的遍历对value值遍历。对key值遍历 使用delete()函数删除键值对按照指定顺序遍历map元素为map的切片值为切片类型的map 做个题吧 Map 哈希表是一种巧妙并且实用的数据结构。它是一个无序的key/value对的集合,其中所有的…

Python环境下LaTeX数学公式转图像方案调研与探讨

目录 引言方案一:基于LaTeX环境方案二:基于KaTeX(推荐) 方案三:基于Matplotlib写在最后 引言 近来,涉及到一些公式识别的项目,输入是公式的图像,输出是LaTeX格式的数学公式字符串。 这类项目一般都采用深…

Powershell脚本自动备份dhcp数据库

文章目录 为什么要备份DHCP数据库呢?在PowerShell中自动备份DHCP数据库1,创建备份目录2,判断备份路径是否存在3,备份DHCP数据库4,完整自动备份脚本5,安排定期备份 推荐阅读 为什么要备份DHCP数据库呢&#…

国密 SM2 SSL 证书 Nginx 安装指南 linux版

一、获取国密证书 1、在您完成申请西部GDCA服务器证书的流程后,下载证书将获取一个证书包,有以下 *.***.com_sign.crt:签名证书 *.***.com_sign.key:签名证书私钥 *.***.com_encrypt.crt:加密证书 *.***.com_encr…

基于鸡群算法的无人机航迹规划-附代码

基于鸡群算法的无人机航迹规划 文章目录 基于鸡群算法的无人机航迹规划1.鸡群搜索算法2.无人机飞行环境建模3.无人机航迹规划建模4.实验结果4.1地图创建4.2 航迹规划 5.参考文献6.Matlab代码 摘要:本文主要介绍利用鸡群算法来优化无人机航迹规划。 1.鸡群搜索算法 …

红队专题-从零开始VC++C/S远程控制软件RAT-MFC-远控介绍及界面编写

红队专题 招募六边形战士队员[1]远控介绍及界面编写1.远程控制软件演示及教程简要说明主程序可执行程序 服务端生成器主机上线服务端程序 和 服务文件管理CMD进程服务自启动主程序主对话框操作菜单列表框配置信息 多线程操作非模式对话框 2.环境:3.界面编程新建项目…

JavaScript_Pig Game切换当前玩家

const current0El document.getElementById(current--0); const current1El document.getElementById(current--1); if (dice ! 1) {currentScore dice;current0El.textContent currentScore;} else {} });这是我们上个文章写的代码,这个代码明显是有问题的&…

[量化投资-学习笔记003]Python+TDengine从零开始搭建量化分析平台-Grafana画K线图

在前面两个笔记: PythonTDengine从零开始搭建量化分析平台-数据存储 PythonTDengine从零开始搭建量化分析平台-MA均线的多种实现方式 中有提到使用 Grafana 画图,不过画的都是均线。除了均线,Grafana 非常人性的提供了 K线图模块 搭配 TDeng…

VScode 调试 linux内核

VScode 调试 linux内核 这里调试的 linux 内核是通过 LinuxSD卡(rootfs)运行的内核 gdb 命令行调试 编辑 /home/tyustli/.gdbinit 文件,参考 【GDB】 .gdbinit 文件 set auto-load safe-path /home/tyustli/code/open_source/kernel/linux-6.5.7/.gdbinit在 lin…

Cross Site Scripting (XSS)

攻击者会给网站发送可疑的脚本,可以获取浏览器保存的网站cookie, session tokens, 或者其他敏感的信息,甚至可以重写HTML页面的内容。 背景 XSS漏洞有不同类型,最开始发现的是存储型XSS和反射型XSS,2005,Am…

Linux中shell脚本中的运算

目录 一、运算符号 二、运算指令 三、练习 一、运算符号 加法-减法*乘法/除法%除法后的余数**乘方自加一--自减一 <小于<小于等于>大于>大于等于等于ji&#xff0c;jji*jj*i/jj/i%jj%i 二、运算指令 (()) ##((a12)) let ##let a12 expr ##expr 1 2 …

【数据结构】交换排序

⭐ 作者&#xff1a;小胡_不糊涂 &#x1f331; 作者主页&#xff1a;小胡_不糊涂的个人主页 &#x1f4c0; 收录专栏&#xff1a;浅谈数据结构 &#x1f496; 持续更文&#xff0c;关注博主少走弯路&#xff0c;谢谢大家支持 &#x1f496; 冒泡、快速排序 1. 冒泡排序2. 快速…

城市群(Megalopolis)/城际(inter-city)OD相关研究即Open Access数据集调研

文章目录 1 城市群/城际OD定义2 理论模型与分析方法2.1 重力模型 Gravity Model2.2 干预机会模型 Intervening Opportunities Model2.3 辐射模型 Radiation Model 3 Issues related to OD flows3.1 OD Prediction3.2 OD Forecasting3.3 OD Construction3.4 OD Estimation 4 OD …

基于单片机的智能电子鼻的设计

欢迎大家点赞、收藏、关注、评论啦 &#xff0c;由于篇幅有限&#xff0c;只展示了部分核心代码。 技术交流认准下方 CSDN 官方提供的联系方式 文章目录 概要 一、智能电子鼻系统的设计方案1.1智能电子鼻系统的设计思路1.2智能电子鼻系统的设计流程图1.3智能电子鼻系统的硬件数…

source insight4菜单工具按钮变乱恢复

目录 1&#xff1a;问题现象2&#xff1a;修改方式2.1 找到config_all.xml2.2 修改config_all.xml 1&#xff1a;问题现象 在source insight4点击工具按钮的时候&#xff0c;把工具全部都折叠了&#xff0c;然后手动拉出来的时候就乱了。 2&#xff1a;修改方式 2.1 找到con…

【多线程面试题 三】、 run()和start()有什么区别?

文章底部有个人公众号&#xff1a;热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享&#xff1f; 踩过的坑没必要让别人在再踩&#xff0c;自己复盘也能加深记忆。利己利人、所谓双赢。 面试官&#xff1a; run()和start()有什么区…

ffmpeg中examples编译报不兼容错误解决办法

ffmpeg中examples编译报不兼容错误解决办法 参考examples下的README可知&#xff0c;编译之前需要设置 PKG_CONFIG_PATH路径。 export PKG_CONFIG_PATH/home/user/work/ffmpeg/ffmpeg/_install_uclibc/lib/pkgconfig之后执行make出现如下错误&#xff1a; 基本都是由于库的版…

stm32的ADC采样率如何通过Time定时器进行控制

ADC采样率是个跟重要的概念. 手册上说可以通过Timer定时器进行触发ADC采样. 可我这边悲剧的是, 无论怎么样. ADC都会进行采样. 而且就算是TIM停掉也是一样会进行采样. 这就让我摸不着头脑了… 我想通过定时器动态更改ADC的采样频率. 结果不随我愿… 这到底是什么问题呢? 一…

哈希算法:如何防止数据库中的用户信息被脱库?

文章来源于极客时间前google工程师−王争专栏。 2011年CSDN“脱库”事件&#xff0c;CSDN网站被黑客攻击&#xff0c;超过600万用户的注册邮箱和密码明文被泄露&#xff0c;很多网友对CSDN明文保存用户密码行为产生了不满。如果你是CSDN的一名工程师&#xff0c;你会如何存储用…

uniapp实现webview页面关闭功能

实现思路&#xff1a; 1.关闭按钮是使用原生button添加的close属性。&#xff08;见page.json页面&#xff09; 2.监听关闭按钮的方法。&#xff08;onNavigationBarButtonTap&#xff09; 3.写实现关闭webview所有页面的逻辑。 废话不多说&#xff0c;直接上代码 1.page.…