5 Tensorflow图像识别(下)模型构建

上一篇:4 Tensorflow图像识别模型——数据预处理-CSDN博客

1、数据集标签

上一篇介绍了图像识别的数据预处理,下面是完整的代码:

import os
import tensorflow as tf# 获取训练集和验证集目录
train_dir = os.path.join('cats_and_dogs_filtered/train')
validation_dir = os.path.join('cats_and_dogs_filtered/validation')# 模型参数设置
BATCH_SIZE = 100# 图片尺寸统一为150*150
IMG_SHAPE = 150# 处理图像尺寸
img_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255, horizontal_flip=True, )train_data_gen = img_generator.flow_from_directory(directory=train_dir,shuffle=True,batch_size=BATCH_SIZE,target_size=(IMG_SHAPE, IMG_SHAPE),class_mode='binary')
val_data_gen = img_generator.flow_from_directory(directory=validation_dir,shuffle=True,batch_size=BATCH_SIZE,target_size=(IMG_SHAPE, IMG_SHAPE),class_mode='binary')

上一篇提到系统的输入是“特征-标签”对,特征是输入的图片,标签就是标记该图片是猫还是狗。上面的代码如何知道输入的照片是猫还是狗?

这里用到了keras的一个函数flow_from_directory(),从目录中生成数据流,子目录会自动帮你生成标签。先看看train训练集的这两个子目录生成的标签是什么:

使用下面代码查看

print(train_data_gen.class_indices)

运行结果:

Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.
{'cats': 0, 'dogs': 1}

从运行结果可以看到,猫的照片系统自动打上了0的标签,狗的标签是1。

2、Relu激活函数

构建模型的完整代码如下:

import os
import tensorflow as tf
import numpy as np# 获取训练集和验证集目录
train_dir = os.path.join('cats_and_dogs_filtered/train')
validation_dir = os.path.join('cats_and_dogs_filtered/validation')# 模型参数设置
BATCH_SIZE = 100# 图片尺寸统一为150*150
IMG_SHAPE = 150# 处理图像尺寸
img_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255, horizontal_flip=True, )# 训练数据
train_data_gen = img_generator.flow_from_directory(directory=train_dir,shuffle=True,batch_size=BATCH_SIZE,target_size=(IMG_SHAPE, IMG_SHAPE),class_mode='binary')# 验证数据
val_data_gen = img_generator.flow_from_directory(directory=validation_dir,shuffle=True,batch_size=BATCH_SIZE,target_size=(IMG_SHAPE, IMG_SHAPE),class_mode='binary')model = tf.keras.Sequential([tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)),tf.keras.layers.MaxPooling2D(2, 2),tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),tf.keras.layers.MaxPooling2D(2, 2),tf.keras.layers.Conv2D(100, (3, 3), activation='relu'),tf.keras.layers.MaxPooling2D(2, 2),tf.keras.layers.Flatten(),tf.keras.layers.Dense(512, activation='relu'),tf.keras.layers.Dense(2)])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])EPOCHS = 20
history = model.fit_generator(train_data_gen,steps_per_epoch=int(np.ceil(2000 / float(BATCH_SIZE))),epochs=EPOCHS,validation_data=val_data_gen,validation_steps=int(np.ceil(1000 / float(BATCH_SIZE)))
)

model中加入了和之前不一样的代码:

tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)),

这里使用了卷积神经,主要是为了突出区分不同对象的特征。一张图片的信息很多的,但往往我们只需要一些特征进行训练就可以了,后续会详细介绍。

现在先介绍 activation='relu',激活函数Relu。

ReLU,全称是线性整流函数(Rectified Linear Unit),是人工神经网络中常用的激活函数。它的图像如下:

当x<=0时,f(x)=0;

当x>0时,f(x)=x;

可以运行代码看看:

例1:

import tensorflow as tfx = -19
print(tf.nn.relu(x))

运行结果:
tf.Tensor(0, shape=(), dtype=int32)

输入-19,使用relu激活函数后的结果为0

例2:

import tensorflow as tfx = 8
print(tf.nn.relu(x))

运行结果:

tf.Tensor(8, shape=(), dtype=int32)

3、损失函数

代码:


model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy']
              )
 

其中损失函数为SparseCategoricalCrossentropy,它是用于计算多分类问题的交叉熵,如果是两个或两个以上的分类问题可以始终这样设置。对其原理及计算过程的读者可以自行百度,此处不详细介绍。

4、训练过程详解
(1)训练准确率

运行上面的完整代码:

可以看到训练集和验证集的loss值在慢慢下降,准确率在提升。

划线部分是最后一个epoch的训练结果:

accuracy:0.8175,也就是说你的神经网络在分类训练数据方面的准确率约为82%;

val_accuracy:0.7160,在验证集的准确率约为72%

(2)batch_size批次大小

代码中batche_size设置的大小为100,意思是每批次生成的样本数量为100。

例如上述代码的train训练集一共有2000张图片,一个周期(epoch)分20个批次(2000/100=20)样本数据进行训练,每个批次训练完后利用优化器更新模型参数。

所以一个周期(epoch)的模型参数更新次数就是20:2000/batch_size=20

截图中红色部分,就是一个epoch分了20个批次用来更新模型参数。

训练结果会因为模型的参数的设置、训练集图片的数量等等原因结果大不相同,学习的时候可以自己动手去调整模型参数来看看训练结果。



 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/186266.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQL Server SSIS ETL job执行相关操作

创建SSIS项目 Excel导入SQL Server 构建Excel源 配置Excel源信息 配置SQL Server目标 双击“ADO NET目标” job执行 新建job 右键“SQL Server代理”的“作业”&#xff0c;点击“新建作业”&#xff0c;弹出“新建作业”的选项页 首先是“常规”选项页&#xff0c;…

50代码审计-PHP无框架项目SQL注入挖掘

代码设计分为有框架和无框架 挖掘技巧&#xff1a;随机挖掘&#xff0c;定点挖掘&#xff0c;批量挖掘&#xff08;用工具帮助扫描探针&#xff0c;推荐工具&#xff1a;fortify&#xff0c;seay系统&#xff09;。 1.教学计划&#xff1a; ---审计项目漏洞 Demo->审计思…

卡尔曼滤波EKF

目录 一、概述 二、卡尔曼滤波的5个公式 三、应用案例&#xff1a;汽车运动 四、应用案例&#xff1a;温度估计 五、总结 一、概述 初学者对于卡尔曼滤波5个公式有点懵&#xff0c;本文先接地气地介绍5个公式&#xff0c;然后举两个常用例子加强理解&#xff0c;同时附有M…

微服务中配置文件(YAML文件)和项目依赖(POM文件)的区别与联系

实际上涉及到了微服务架构中的两个重要概念&#xff1a;服务间通信和项目依赖管理。在微服务架构中&#xff0c;一个项目可以通过两种方式与另一个项目建立依赖关系&#xff1a;通过配置文件&#xff08;如YAML文件&#xff09;和通过项目依赖&#xff08;如POM文件&#xff09…

Excel 转 Json 、Node.js实现(应用场景:i18n国际化)

创作灵感来源于在线转换是按照换行符去转换excel内容换行符后很难处理 本文是按单元格转换 const xlsx require(node-xlsx) const fs require(fs) const xlsxData xlsx.parse(./demo.xlsx) // 需要转换的excel文件// 数据处理 方便粘贴复制 const data xlsxData[2].data …

计算机考研408有多难?25考研经验贴,开个好头很有必要

前言 大家好&#xff0c;我是陈橘又青&#xff0c;相信关注我的各位小伙伴们中&#xff0c;大多都是在计算机专业的大学生吧&#xff01; 每天都有许多人在后台私信我&#xff0c;问我要不要考研&#xff0c;我想说这个东西是因人而异的&#xff0c;像我本人就选择了就业&…

vue分片上传视频并转换为m3u8文件并播放

开发环境&#xff1a; 基于若依开源框架的前后端分离版本的实践&#xff0c;后端java的springboot&#xff0c;前端若依的vue2&#xff0c;做一个分片上传视频并分段播放的功能&#xff0c;因为是小项目&#xff0c;并没有专门准备文件服务器和CDN服务&#xff0c;后端也是套用…

Unity | Shader(着色器)和material(材质)的关系

一、前言 在上一篇文章中 【精选】Unity | Shader基础知识&#xff08;什么是shader&#xff09;_unity shader_菌菌巧乐兹的博客-CSDN博客 我们讲了什么是shader&#xff0c;今天我们讲一下shder和material的关系 二、在unity中shader的本质 unity中&#xff0c;shader就…

RuoYi-Vue 在Swagger和Postman中 上传文件测试方案

RequestPart是Spring框架中用于处理multipart/form-data请求中单个部分的注解。在Spring MVC中&#xff0c;当处理文件上传或其他类型的多部分请求时&#xff0c;可以使用RequestPart注解将请求的特定部分绑定到方法参数上。 使用RequestPart注解时&#xff0c;需要指定要绑定…

一款功能强大的web目录扫描器专业版

dirpro 简介 dirpro 是一款由 python 编写的目录扫描器&#xff0c;操作简单&#xff0c;功能强大&#xff0c;高度自动化。 自动根据返回状态码和返回长度&#xff0c;对扫描结果进行二次整理和判断&#xff0c;准确性非常高。 已实现功能 可自定义扫描线程 导入url文件进…

SNP应邀参加2023中国企业数字化转型峰会暨赛意用户大会

创新驱动科技&#xff0c;数智驱动未来。如今&#xff0c;我国产业数字化进程提速升级&#xff0c;数字产业化规模持续壮大。数据显示&#xff0c;2022年&#xff0c;我国数字经济规模达50.2万亿元&#xff0c;总量稳居世界第二。数字经济已经成为推动传统产业转型升级、促进高…

php 二分查询算法实现

原理&#xff1a;二分查找算法&#xff08;Binary Search&#xff09;是一种针对有序数组的查找算法。它的原理是通过将查找区间逐渐缩小一半来快速定位要查找的目标值。 应用场景&#xff1a; 数据库或文件系统索引查找&#xff1a;在数据库或文件系统中&#xff0c;索引是有…

AJAX 入门笔记

课程地址 AJAX Asynchronous JavaScript and XML&#xff08;异步的 JavaScript 和 XML&#xff09; AJAX 不是新的编程语言&#xff0c;而是一种使用现有标准的新方法 AJAX 最大的优点是在不重新加载整个页面的情况下&#xff0c;可以与服务器交换数据并更新部分网页内容 XML…

C/C++特殊求和 2021年6月电子学会青少年软件编程(C/C++)等级考试一级真题答案解析

目录 C/C幻数求和 一、题目要求 1、编程实现 2、输入输出 二、算法分析 三、程序编写 四、程序说明 五、运行结果 六、考点分析 C/C幻数求和 2021年6月 C/C编程等级考试一级编程题 一、题目要求 1、编程实现 如果一个数能够被7整除或者十进制表示中含有数字7&…

P02项目诊断报警组件(学习操作日志记录、单元测试开发)

★ P02项目诊断报警组件 诊断报警组件的主要功能有&#xff1a; 接收、记录硬件设备上报的报警信息。从预先设定的错误码对照表中找到对应的声光报警和蜂鸣器报警策略&#xff0c;结合当前的报警情况对设备下发报警指示。将报警消息发送到消息队列&#xff0c;由其它组件发送…

高级运维学习(十五)Zabbix监控(二)

一 Zabbix 报警机制 1 基本概念 自定义的监控项默认不会自动报警首页也不会提示错误需要配置触发器与报警动作才可以自动报警 2 概念介绍 &#xff08;1&#xff09;触发器 (trigger) 表达式&#xff0c;如内存不足300M&#xff0c;用户超过30个等 当触发条件发生后&a…

mac 卸载第三方输入法

输入法设置里的移除&#xff0c;并不是真的卸载&#xff0c;点击还是能添加回来 在活动监视器里强制退出此输入法在访达界面使用快捷键 ShiftcommandG在弹出的对话框内输入以下路径&#xff08;/资源库/Input Methods&#xff09;&#xff0c;再点击下面的前往找到你要卸载的输…

【第2章 Node.js基础】2.3 Node.js事件机制

2.3 Node.js事件机制 学习目标 &#xff08;1&#xff09;理解Node.js的事件机制&#xff1b; &#xff08;2&#xff09;掌握事件的监听与触发的用法。 文章目录 2.3 Node.js事件机制什么是事件机制为什么要有事件机制事件循环事件的监听与触发EventEmitter类常用API 什么是…

LabVIEW如何才能得到共享变量的引用

LabVIEW如何才能得到共享变量的引用 有一个LabVIEW 库文件 (.lvlib) &#xff0c;其中有一些定义好的共享变量。但需要得到每个共享变量的引用以便在程序运行时访问其属性。 共享变量的属性定义在“变量”类属性节点中。为了访问变量类&#xff0c;共享变量的引用必须连接到变…

Leetcode Hot 100之四:283. 移动零+11. 盛最多水的容器

283.移动零 题目&#xff1a; 给定一个数组 nums&#xff0c;编写一个函数将所有 0 移动到数组的末尾&#xff0c;同时保持非零元素的相对顺序。 请注意 &#xff0c;必须在不复制数组的情况下原地对数组进行操作。 示例 1: 输入: nums [0,1,0,3,12] 输出: [1,3,12,0,0] …