T10 数据增强

文章目录

  • 一、准备环境和数据
    • 1.环境
    • 2. 数据
  • 二、数据增强(增加数据集中样本的多样性)
  • 三、将增强后的数据添加到模型中
  • 四、开始训练
  • 五、自定义增强函数
  • 六、一些增强函数

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第10周:数据增强(训练营内部成员可读)
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

本文说明了两种数据增强方式,以及如何自定义数据增强方式并将其放到我们代码当中,两种数据增强方式如下:
● 将数据增强模块嵌入model中
● 在Dataset数据集中进行数据增强

常用的tf增强函数在文末有说明

一、准备环境和数据

1.环境

import matplotlib.pyplot as plt
import numpy as np
import sys
from datetime import datetime
#隐藏警告
import warnings
warnings.filterwarnings('ignore')from tensorflow.keras import layers
import tensorflow as tfprint("--------# 使用环境说明---------")
print("Today: ", datetime.today())
print("Python: " + sys.version)
print("Tensorflow: ", tf.__version__)gpus = tf.config.list_physical_devices("GPU")
if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")# 打印显卡信息,确认GPU可用print(gpus)
else:print("Use CPU")

在这里插入图片描述

2. 数据

使用上一课的数据集,即猫狗识别2的数据集。其次,原数据集中不包括测试集,所以使用tf.data.experimental.cardinality确定验证集中有多少批次的数据,然后将其中的 20% 移至测试集。

# 从本地路径读入图像数据
print("--------# 从本地路径读入图像数据---------")
data_dir   = "D:/jupyter notebook/DL-100-days/datasets/Cats&Dogs Data2/"
img_height = 224
img_width  = 224
batch_size = 32# 划分训练集
print("--------# 划分训练集---------")
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.3,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)# 划分验证集
print("--------# 划分验证集---------")
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.3,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)# 从验证集中划20%的数据用作测试集
print("--------# 从验证集中划20%的数据用作测试集---------")
val_batches = tf.data.experimental.cardinality(val_ds)
test_ds     = val_ds.take(val_batches // 5)
val_ds      = val_ds.skip(val_batches // 5)print('验证集的批次数: %d' % tf.data.experimental.cardinality(val_ds))
print('测试集的批次数: %d' % tf.data.experimental.cardinality(test_ds))# 显示数据类别
print("--------# 显示数据类别---------")
class_names = train_ds.class_names
print(class_names)print("--------# 归一化处理---------")
AUTOTUNE = tf.data.AUTOTUNEdef preprocess_image(image,label):return (image/255.0,label)# 归一化处理
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds   = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
test_ds  = test_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)# 数据可视化
print("--------# 数据可视化---------")
plt.figure(figsize=(15, 10))  # 图形的宽为15高为10for images, labels in train_ds.take(1):for i in range(8):ax = plt.subplot(5, 8, i + 1) plt.imshow(images[i])plt.title(class_names[labels[i]])plt.axis("off")

在这里插入图片描述

二、数据增强(增加数据集中样本的多样性)

数据增强的常用方法包括(但不限于):随机平移、随机翻转、随机旋转、随机亮度、随机对比度,可以在Tf中文网的experimental/preprocessing类目下查看,也可以在Tf中文网的layers/类目下查看。

本文使用随机翻转随机旋转来进行增强:

tf.keras.layers.experimental.preprocessing.RandomFlip:水平和垂直随机翻转每个图像

tf.keras.layers.experimental.preprocessing.RandomRotation:随机旋转每个图像

# 第一个层表示进行随机的水平和垂直翻转,而第二个层表示按照 0.2 的弧度值进行随机旋转。
print("--------# 数据增强:随机翻转+随机旋转---------")
data_augmentation = tf.keras.Sequential([tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])# Add the image to a batch.
print("--------# 添加图像到batch中---------")
# Q:这个i从哪来的??????
image = tf.expand_dims(images[i], 0)print("--------# 显示增强后的图像---------")
plt.figure(figsize=(8, 8))
for i in range(9):augmented_image = data_augmentation(image)ax = plt.subplot(3, 3, i + 1)plt.imshow(augmented_image[0])plt.axis("off")
--------# 数据增强:随机翻转+随机旋转---------
--------# 添加图像到batch中---------
--------# 显示增强后的图像---------
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.

在这里插入图片描述

三、将增强后的数据添加到模型中

两种方式:

  • (1)将其嵌入model中

优点是:

● 数据增强这块的工作可以得到GPU的加速(如果使用了GPU训练的话)

注意:只有在模型训练时(Model.fit)才会进行增强,在模型评估(Model.evaluate)以及预测(Model.predict)时并不会进行增强操作。

'''
model = tf.keras.Sequential([data_augmentation,layers.Conv2D(16, 3, padding='same', activation='relu'),layers.MaxPooling2D(),
])
'''
"\nmodel = tf.keras.Sequential([\n  data_augmentation,\n  layers.Conv2D(16, 3, padding='same', activation='relu'),\n  layers.MaxPooling2D(),\n])\n"
  • (2)在Dataset数据集中进行数据增强
batch_size = 32
AUTOTUNE = tf.data.AUTOTUNEdef prepare(ds):ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)return dsprint("--------# 增强后的图像加到模型中---------")
train_ds = prepare(train_ds)

在这里插入图片描述

四、开始训练

# 设置模型
print("--------# 设置模型---------")
model = tf.keras.Sequential([layers.Conv2D(16, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(32, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(64, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dense(len(class_names))
])# 设置编译参数
# ● 损失函数(loss):用于衡量模型在训练期间的准确率。
# ● 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
# ● 评价函数(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
print("--------# 设置编译器参数---------")
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])print("--------# 开始训练---------")
epochs=20
history = model.fit(train_ds,validation_data=val_ds,epochs=epochs
)print("--------# 查看训练结果---------")
loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)

在这里插入图片描述

五、自定义增强函数

print("--------# 自定义增强函数---------")
import random
# 这是大家可以自由发挥的一个地方
def aug_img(image):seed = (random.randint(0,9), 0)# 随机改变图像对比度stateless_random_brightness = tf.image.stateless_random_contrast(image, lower=0.1, upper=1.0, seed=seed)return stateless_random_brightnessimage = tf.expand_dims(images[3]*255, 0)
print("Min and max pixel values:", image.numpy().min(), image.numpy().max())plt.figure(figsize=(8, 8))
for i in range(9):augmented_image = aug_img(image)ax = plt.subplot(3, 3, i + 1)plt.imshow(augmented_image[0].numpy().astype("uint8"))plt.axis("off")# Q: 将自定义增强函数应用到我们数据上呢?
# 请参考上文的 preprocess_image 函数,将 aug_img 函数嵌入到 preprocess_image 函数中,在数据预处理时完成数据增强就OK啦。

在这里插入图片描述
在这里插入图片描述

# 从本地路径读入图像数据
print("--------# 从本地路径读入图像数据---------")
data_dir   = "D:/jupyter notebook/DL-100-days/datasets/Cats&Dogs Data2/"
img_height = 224
img_width  = 224
batch_size = 32# 划分训练集
print("--------# 划分训练集---------")
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.3,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)# 划分验证集
print("--------# 划分验证集---------")
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.3,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)# 从验证集中划20%的数据用作测试集
print("--------# 从验证集中划20%的数据用作测试集---------")
val_batches = tf.data.experimental.cardinality(val_ds)
test_ds     = val_ds.take(val_batches // 5)
val_ds      = val_ds.skip(val_batches // 5)print('验证集的批次数: %d' % tf.data.experimental.cardinality(val_ds))
print('测试集的批次数: %d' % tf.data.experimental.cardinality(test_ds))# 显示数据类别
print("--------# 显示数据类别---------")
class_names = train_ds.class_names
print(class_names)print("--------# 归一化处理---------")
AUTOTUNE = tf.data.AUTOTUNEprint("--------# 将自定义增强函数应用到数据上---------")
def preprocess_image(aug_img,label):return (aug_img/255.0,label)# 归一化处理
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds   = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
test_ds  = test_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)# 数据可视化
print("--------# 数据可视化---------")
plt.figure(figsize=(15, 10))  # 图形的宽为15高为10for images, labels in train_ds.take(1):for i in range(8):ax = plt.subplot(5, 8, i + 1) plt.imshow(images[i])plt.title(class_names[labels[i]])plt.axis("off")# 设置模型
print("--------# 设置模型---------")
model = tf.keras.Sequential([layers.Conv2D(16, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(32, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(64, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dense(len(class_names))
])# 设置编译参数
# ● 损失函数(loss):用于衡量模型在训练期间的准确率。
# ● 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
# ● 评价函数(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
print("--------# 设置编译器参数---------")
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])print("--------# 开始训练---------")
epochs=20
history = model.fit(train_ds,validation_data=val_ds,epochs=epochs
)print("--------# 查看训练结果---------")
loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)

使用自定义增强函数增强后的数据重新训练的结果:
在这里插入图片描述

六、一些增强函数

在这里插入图片描述
(1)随机亮度(RandomBrightness)

tf.keras.layers.RandomBrightness( factor, value_range=(0, 255), seed=None, **kwargs )

(2)随机对比度(RandomContrast)

tf.keras.layers.RandomContrast( factor, seed=None, **kwargs )

(3)随机裁剪(RandomCrop)

tf.keras.layers.RandomCrop( height, width, seed=None, **kwargs )

(4)随机翻转(RandomFlip)

tf.keras.layers.RandomFlip( mode=HORIZONTAL_AND_VERTICAL, seed=None, **kwargs )
(5)随机高度(RandomHeight)和随机宽度(RandomWidth)

tf.keras.layers.RandomHeight( factor, interpolation='bilinear', seed=None, **kwargs )

tf.keras.layers.RandomWidth( factor, interpolation='bilinear', seed=None, **kwargs )

(6)随机平移(RandomTranslation)

tf.keras.layers.RandomTranslation( height_factor, width_factor, fill_mode='reflect', interpolation='bilinear', seed=None, fill_value=0.0, **kwargs )

(7)随机旋转(RandonRotation)

tf.keras.layers.RandomRotation( factor, fill_mode='reflect', interpolation='bilinear', seed=None, fill_value=0.0, **kwargs )

(8)随机缩放(RandonZoom)

tf.keras.layers.RandomZoom( height_factor, width_factor=None, fill_mode='reflect', interpolation='bilinear', seed=None, fill_value=0.0, **kwargs )

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/198865.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Scrapy----Scrapy简介

文章目录 概述与应用背景架构和组件功能和特点社区生态概述与应用背景 Scrapy,一个高效、灵活、且强大的Web爬取框架,被广泛应用于数据抓取和网页内容的结构化提取。它是用Python编写的,支持多平台运行,适用于数据挖掘、在线零售信息收集、历史数据存档等多种场景。Scrapy…

thinkphp8 DB_PREFIX 属性

设计表的时候使用**_user, **就是前缀,DB_PREFIX就是默认把前缀给去掉 在config/database.php prefix,改成你的前缀,数据库的表重命名‘ltf_user’ 代码调用 $user Db::name("user")->select();return json($user);之前是使用…

4.1 Linux操作系统

个人主页:Lei宝啊 愿所有美好如期而遇 我们上次了解了从操作系统往下的部分:Linux进程前导知识 这一次,我们将正式开始进程以及操作系统(OS)及其以上的部分。 我们将操作系统的内存管理,进程管理,文件管理等等称为操…

安防视频监控平台EasyCVR服务器部署后出现报错,导致无法级联到域名服务器,该如何解决?

视频监控平台EasyCVR能在复杂的网络环境中,将分散的各类视频资源进行统一汇聚、整合、集中管理,在视频监控播放上,安防监控平台可支持1、4、9、16个画面窗口播放,可同时播放多路视频流,也能支持视频定时轮播。视频监控…

Git详解

Git是一个开源的分布式版本控制系统,常用于软件开发中对代码版本管理。Git具有版本控制、协作开发、分支管理、代码审查等功能,能够记录每次代码修改的内容和时间,并能够回滚到任意历史版本,方便团队协作和代码维护。 Git的基本概…

在通用jar包中引入其他spring boot starter,并在通用jar包中直接配置这些starter的yml相关属性

场景 我在通用jar包中引入 spring-boot-starter-actuator 这样希望引用通用jar的所有服务都可以直接使用 actuator 中的功能, 问题在于,正常情况下,actuator的配置都写在每个项目的yml文件中,这就意味着,虽然每个项目…

数据结构【DS】栈

共享栈 共享栈的目的是什么? 目的:有效利用存储空间。 共享栈的存取数据时间复杂度为? 存取数据时间复杂度为O(1) 共享栈如何判空?如何判满? 两个栈的栈顶指针都指向栈顶元素,𝑡𝑜𝑝…

4-5学生分数对应的成绩

![#include<stdio.h> int main(){float score;char grade;for(int i0;i<7;i){printf("请输入成绩&#xff1a;");scanf("%f",&score);while(score>100||score<0){printf("\n输入的成绩有误&#xff0c;请重新输入&#xff1a;&quo…

【hacker送书第2期】计算机考研精炼1000题

第2期图书推荐 计算机考研难点《计算机考研精炼1000题》揭秘问答1. 为什么是1000题&#xff1f;2. 有什么优势&#xff1f;3. 编写团队水平如何&#xff1f;4. 题目及解析品质如何&#xff1f;可以试读吗&#xff1f; 参与方式 计算机考研难点 难度高&#xff01;知识点多&…

day29_Servlet

今日内容 零、 复习昨日 一、Servlet 零、 复习昨日 一、Servlet 1.1 Servlet介绍 javaweb开发,就是需要服务器接收前端发送的请求,以及请求中的数据,经过处理(jdbc操作),然后向浏览器做出响应. 我们要想在服务器中写java代码来接收请求,做出响应,我们的java代码就得遵循tomca…

旋极携手西班牙SoC-e公司,为中国客户提供高效可靠TSN通讯解决方案

2023年2月&#xff0c;旋极信息与西班牙SoC-e公司正式签订战略合作协议&#xff0c;成为其在中国区重要合作伙伴。 SoC-e是一家世界领先的基于FPGA技术的以太网通讯解决方案供应商&#xff0c;是一系列IP核开发领域的先锋&#xff0c;为关键任务实施网络化、同步性和安全性提供…

深入理解 synchronized 原理

目录 一. 前言 二. Java对象的内存布局 2.1. 对象头 2.2. Mark Word 2.3. Class Metadata Pointer 2.4. Length 三. 偏向锁 3.1. 偏向锁的工作流程 3.2. 偏向失效 3.2.1. 误区一 3.3. 偏向撤销 3.3.1. 误区一 3.4. 偏向撤销的底层实现 3.5. HashCode与偏向撤销 …

【Dubbo】Dubbo负载均衡实现解析

&#x1f4eb;作者简介&#xff1a;小明java问道之路&#xff0c;2022年度博客之星全国TOP3&#xff0c;专注于后端、中间件、计算机底层、架构设计演进与稳定性建设优化&#xff0c;文章内容兼具广度、深度、大厂技术方案&#xff0c;对待技术喜欢推理加验证&#xff0c;就职于…

动态库符号抢占问题分析

背景 前段时间在北汽项目中&#xff0c;遇到了一个奇怪现象&#xff0c;困扰了大家较长时间。最终在和同事的不懈努力下&#xff0c;从根因上解决了该问题&#xff0c;过程中也学习到了很多。在此&#xff0c;记录并分享&#xff0c;希望能够帮助大家。 问题描述 作为OTA服务的…

客户管理系统大盘点!推荐这五款

客户管理系统大盘点&#xff01;推荐这五款。 客户管理系统也就是CRM&#xff0c;可以说是企业刚需&#xff0c;国内外的客户管理系统也是数不胜数&#xff0c;到底有哪些是真正好用&#xff0c;值得推荐的呢&#xff1f;本文将为大家推荐这5款好用的客户管理系统&#xff1a;…

R语言:利用biomod2进行生态位建模

在这里主要是分享一个不错的代码&#xff0c;喜欢的可以慢慢研究。我看了一遍&#xff0c;觉得里面有很多有意思的东西&#xff0c;供大家学习和参考。 利用PCA轴总结的70个环境变量&#xff0c;利用biomod2进行生态位建模&#xff1a; #------------------------------------…

elasticsearch 概述

初识elasticsearch 了解ES elasticsearch的作用 elasticsearch是一款非常强大的开源搜索引擎&#xff0c;具备非常多强大功能&#xff0c;可以帮助我们从海量数据中快速找到需要的内容 例如&#xff1a; 在GitHub搜索代码 在电商网站搜索商品 ELK技术栈 elasticsearc…

基于FPGA的五子棋(论文+源码)

1.系统设计 在本次设计中&#xff0c;整个系统硬件框图如下图所示&#xff0c;以ALTERA的FPGA作为硬件载体&#xff0c;VGA接口&#xff0c;PS/2鼠标来完成设计&#xff0c;整个系统可以完成人人对战&#xff0c;人机对战的功能。系统通过软件编程来实现上述功能。将在硬件设计…

centos的root密码忘记或失效的解决办法

目录 前言1 单机维护模式2 利用具有管理员权限的用户切换到root用户3 救援模式 前言 在Linux系统中&#xff0c;root用户是最高权限的用户&#xff0c;可以执行任何命令和操作。但是&#xff0c;如果我们忘记了root用户的密码&#xff0c;或者需要修改root用户的密码&#xff…

Spring Boot 项目部署方案!打包 + Shell 脚本部署详解

文章目录 概要一 、profiles指定不同环境的配置二、maven-assembly-plugin打发布压缩包三、 分享shenniu_publish.sh程序启动工具四、linux上使用shenniu_publish.sh启动程序 概要 本篇和大家分享的是springboot打包并结合shell脚本命令部署&#xff0c;重点在分享一个shell程…