【TensorFlow2 之011】TF 如何使用数据增强提高模型性能?

一、说明

        亮点:在这篇文章中,我们将展示数据增强技术作为提高模型性能的一种方式的好处。当我们没有足够的数据可供使用时,这种方法将非常有益。

教程概述:

  1. 无需数据增强的训练
  2. 什么是数据增强?
  3. 使用数据增强进行训练
  4. 可视化

二、没有数据增强的训练

        一个熟悉的问题是“我们为什么要使用数据增强?所以,让我们看看答案。

        为了证明这一点,我们将在TensorFlow中创建一个卷积神经网络,并在Cats-vs-Dog数据集上对其进行训练。

        首先,我们将准备用于训练的数据集。我们将首先从在线存储库下载数据集。完成此操作后,我们将继续解压缩并为训练和验证集创建路径位置。

import os
import wget
import zipfilewget.download("https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip")
100% [........................................................................] 68606236 / 68606236

Out[2]:

'cats_and_dogs_filtered.zip'

 

with zipfile.ZipFile("cats_and_dogs_filtered.zip","r") as zip_ref:zip_ref.extractall()base_dir = 'cats_and_dogs_filtered'train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

       让我们继续加载本教程所需的必要库。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimgfrom tensorflow.keras import Model
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense, Dropout, Activation
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from tensorflow.keras.preprocessing.image import ImageDataGenerator

        我们将使用“模型子类化技术”来构建模型。这样,我们应该在__init__中定义我们的层,并在调用中实现模型的前向传递。模型的输入将是大小为 \([150, 150, 3]\) 的图像。在卷积层之后,我们将利用两个完全连接的层来进行预测。这是一个二元分类问题,所以我们在输出层中只有一个神经元。

class Create_model(Model):def __init__(self, chanDim=-1):super(Create_model, self).__init__()self.conv1A = Conv2D(16, 3, input_shape = (150, 150, 3))self.act1A  = Activation("relu")self.pool1A = MaxPooling2D(2)self.conv1B = Conv2D(32, 3)self.act1B  = Activation("relu")self.pool1B = MaxPooling2D(pool_size=(2, 2))self.conv1C = Conv2D(64, 3)self.act1C  = Activation("relu")self.pool1C = MaxPooling2D(2)self.flatten = Flatten()self.dense2A = Dense(512)self.act2A  = Activation("relu")self.dense2B = Dense(1)self.sigmoid  = Activation("sigmoid")def call(self, inputs):x = self.conv1A(inputs)x = self.act1A(x)x = self.pool1A(x)x = self.conv1B(x)x = self.act1B(x)x = self.pool1B(x)x = self.conv1C(x)x = self.act1C(x)x = self.pool1C(x)x = self.flatten(x)x = self.dense2A(x)x = self.act2A(x)x = self.dense2B(x)x = self.sigmoid(x)return xmodel = Create_model()model.compile(loss='binary_crossentropy',optimizer=RMSprop(lr=0.001),metrics=['accuracy'])

        我们的图像不在一个文件中,而是在多个文件夹中。为了在这样的数据集上训练网络,我们需要使用图像数据生成器。在创建两个生成器(用于训练和验证)后,我们可以使用 fit 方法训练网络。唯一的区别是,我们不是将输入和输出分别传递给我们的网络,而是将数据生成器传递给网络。

        最好规范化像素值,以便每个像素值的值介于 0 和 1 之间,以免中断或减慢学习过程。因此,这将是传递给图像数据生成器的唯一参数。然后,我们将使用这些数据生成器遍历目录、调整图像大小和创建批处理。

train_datagen = ImageDataGenerator(rescale=1./255)
val_datagen = ImageDataGenerator(rescale=1./255)train_generator = train_datagen.flow_from_directory(train_dir,target_size=(150, 150),batch_size=20,class_mode='binary')validation_generator = val_datagen.flow_from_directory(validation_dir,target_size=(150, 150),batch_size=20,class_mode='binary')history = model.fit(train_generator,steps_per_epoch=100,epochs=15,validation_data=validation_generator,validation_steps=50,verbose=0)

        让我们检查一下模型的分类准确性和损失。

        在这里,训练集的准确性和损失都以蓝色显示,而验证集的准确度和损失都以橙色显示。

accuracy = history.history['accuracy']
val_accuracy = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(accuracy))plt.plot(epochs, accuracy)
plt.plot(epochs, val_accuracy)
plt.title('Training and validation accuracy')plt.figure()plt.plot(epochs, loss)
plt.plot(epochs, val_loss)
plt.title('Training and validation loss')

  

       从这些图中,我们可以清楚地看到,模型在训练中的表现比在验证集上的表现要好得多。那么我们能做什么呢?使用数据增强。

三、什么是数据增强?

        大多数计算机视觉任务需要大量数据,而数据增强是用于提高计算机视觉系统性能的技术之一。计算机视觉是一项相当复杂的任务。对于输入图像,算法必须找到一种模式来理解图片中的内容。

        在实践中,拥有更多数据将有助于几乎所有的计算机视觉任务。今天,计算机视觉的状态需要更多的数据来解决大多数计算机视觉问题。对于卷积神经网络的所有应用来说,这可能不是真的,但对于计算机视觉领域来说确实如此。

        当我们训练计算机视觉模型时,数据增强通常会有所帮助。无论我们使用迁移学习还是从头开始训练模型,都是如此。

        因此,数据增强是一种技术,可以在不收集新数据的情况下显着增加可用于训练的数据的多样性。您可以在此处找到有关数据增强理论方面的更多信息。

四、 使用数据增强进行训练

        乍一看,数据增强可能听起来很复杂,但幸运的是,TensorFlow 允许我们有效地实现它。

        因此,我们将像以前一样使用图像数据生成器,但我们将添加重新缩放、旋转、移动、缩放、缩放和翻转。再次重要的是要说,此过程仅适用于训练集,而不应用于验证。

        在这里,我们不仅要调整大小,还要添加旋转(以度为单位的范围)、高度和宽度偏移(以像素为单位的范围)、剪切范围(以度为单位的逆时针方向的角度)、缩放范围和翻转。

train_datagen = ImageDataGenerator(rescale=1./255,rotation_range=40,width_shift_range=0.2,height_shift_range=0.2,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,)val_datagen = ImageDataGenerator(rescale=1./255)train_generator = train_datagen.flow_from_directory(train_dir,target_size=(150, 150),batch_size=20,class_mode='binary')validation_generator = val_datagen.flow_from_directory(validation_dir,target_size=(150, 150),batch_size=20,class_mode='binary')

Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.

        此外,Dropout 层将被添加到我们的神经网络中。我们将丢弃\(50%\)个起始神经元。添加 Dropout 图层将有助于防止过度拟合。

class Create_model(Model):def __init__(self, chanDim=-1):super(Create_model, self).__init__()self.conv1A = Conv2D(16, 3, input_shape = (150, 150, 3))self.act1A  = Activation("relu")self.pool1A = MaxPooling2D(2)self.conv1B = Conv2D(32, 3)self.act1B  = Activation("relu")self.pool1B = MaxPooling2D(pool_size=(2, 2))self.conv1C = Conv2D(64, 3)self.act1C  = Activation("relu")self.pool1C = MaxPooling2D(2)self.flatten = Flatten()self.dense2A = Dense(512)self.act2A  = Activation("relu")self.dropout = Dropout(0.5)self.dense2B = Dense(1)self.sigmoid  = Activation("sigmoid")def call(self, inputs):x = self.conv1A(inputs)x = self.act1A(x)x = self.pool1A(x)x = self.conv1B(x)x = self.act1B(x)x = self.pool1B(x)x = self.conv1C(x)x = self.act1C(x)x = self.pool1C(x)x = self.flatten(x)x = self.dense2A(x)x = self.act2A(x)x = self.dropout(x)x = self.dense2B(x)x = self.sigmoid(x)return xmodel = Create_model()model.compile(loss='binary_crossentropy',optimizer=RMSprop(lr=0.001),metrics=['accuracy'])

现在我们可以训练网络了。

history = model.fit(train_generator,steps_per_epoch=100,epochs=30,validation_data=validation_generator,validation_steps=50,verbose=2)

Train for 100 steps, validate for 50 steps
Epoch 1/30
100/100 - 92s - loss: 0.9063 - accuracy: 0.5125 - val_loss: 0.7271 - val_accuracy: 0.5000
Epoch 2/30
100/100 - 56s - loss: 0.7020 - accuracy: 0.5625 - val_loss: 0.6551 - val_accuracy: 0.5480
Epoch 3/30
100/100 - 56s - loss: 0.6815 - accuracy: 0.5950 - val_loss: 0.6253 - val_accuracy: 0.6600
Epoch 4/30
100/100 - 57s - loss: 0.6594 - accuracy: 0.6220 - val_loss: 0.6262 - val_accuracy: 0.6350
Epoch 5/30
100/100 - 56s - loss: 0.6352 - accuracy: 0.6485 - val_loss: 0.5916 - val_accuracy: 0.6890
Epoch 6/30
100/100 - 56s - loss: 0.6336 - accuracy: 0.6675 - val_loss: 0.5774 - val_accuracy: 0.6790
Epoch 7/30
100/100 - 57s - loss: 0.6383 - accuracy: 0.6570 - val_loss: 0.5830 - val_accuracy: 0.6980


让我们看看结果。在这里,训练集的准确性和损失都以蓝色显示,而验证集的准确度和损失都以橙色显示。

accuracy = history.history['accuracy']
val_accuracy = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(accuracy))plt.plot(epochs, accuracy)
plt.plot(epochs, val_accuracy)
plt.title('Training and validation accuracy')plt.figure()plt.plot(epochs, loss)
plt.plot(epochs, val_loss)
plt.title('Training and validation loss')
Text(0.5, 1.0, 'Training and validation loss')

现在的结果好多了。但是,我们模型的准确性还不完美。

我们将在下一篇文章中使用迁移学习来解决这个问题。

五、 可视化

到目前为止,我们只是在讨论如何创建增强图像,但让我们看看它们的外观。

为此,我们需要使用来自生成器的一个图像并“循环它”。这将遍历生成器并执行增强。下面我们展示了这些图像样本的可视化。

augmented_images = [train_generator[0][0][0] for i in range(12)]
plt.figure(figsize=(8,6))for i in range(12):plt.subplot(3, 4, i+1)image = augmented_images[i]image = image.reshape(150, 150, 3)plt.imshow(image)
pyplot.show()
狗增强图像
增强图像

六、总结

        总而言之,我们已经学会了如何使用数据增强技术来提高模型的性能。在数据稀缺或数据收集成本高昂的情况下,我们可以使用这种方法。但是,请注意,我们不能将数据集扩充到非常大的比例。此方法有其局限性。在下一篇文章中,我们将展示如何应用迁移学习的过程。

有关该主题的更多资源:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/156829.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

隐马尔可夫模型(一)Evaluation

前提 import torch import torch.nn.functional as F N 3 # 离散隐变量可以取到的值的个数 M 2 # 可观测变量个数 pi F.softmax(torch.randn((N,1),dtypetorch.float32),dim0) # 初始状态概率矩阵 A F.softmax(torch.randn((N,N),dtypetorch.float32),dim-1) # 转移…

Linux8yum安装mysql5.7版本流程

Linux8yum安装mysql Linux8yum安装报错解决 yum安装流程 首先下载mysql的yum配置 wget http://dev.mysql.com/get/mysql57-community-release-el7-11.noarch.rpm安装mysql源 yum -y install mysql57-community-release-el7-11.noarch.rpm安装mysql yum -y install mysql-s…

【目标检测】大图包括标签切分,并转换成txt格式

前言 遥感图像比较大,通常需要切分成小块再进行训练,之前写过一篇关于大图裁切和拼接的文章【目标检测】图像裁剪/标签可视化/图像拼接处理脚本,不过当时的工作流是先将大图切分成小图,再在小图上进行标注,于是就不考…

Elasticsearch:使用 Langchain 和 OpenAI 进行问答

这款交互式 jupyter notebook 使用 Langchain 将虚构的工作场所文档拆分为段落 (chunks),并使用 OpenAI 将这些段落转换为嵌入并将其存储到 Elasticsearch 中。然后,当我们提出问题时,我们从向量存储中检索相关段落,并使用 langch…

Google SGE 正在添加人工智能图像生成器,现已推出:从搜索中的生成式 AI 中获取灵感的新方法

🌷🍁 博主猫头虎 带您 Go to New World.✨🍁 🦄 博客首页——猫头虎的博客🎐 🐳《面试题大全专栏》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺 &a…

1688拍立淘接口,按图搜索1688商品接口,图片搜索商品接口,1688API接口

按图搜索1688商品的方法如下: 打开1688平台,点击首页右上角的搜索框,进入搜索页面。 点击搜索框右侧的相机图标,选择“拍照”或“相册”,上传你想要搜索的图片。 等待图片上传完成,系统会自动识别图片中的…

基于Effect的组件设计 | 京东云技术团队

Effect的概念起源 从输入输出的角度理解Effect https://link.excalidraw.com/p/readonly/KXAy7d2DlnkM8X1yps6L 编程中的Effect起源于函数式编程中纯函数的概念 纯函数是指在相同的输入下,总是产生相同的输出,并且没有任何副作用(side effect)的函数。…

Qt QMultiMap

QMultiMap 文章目录 QMultiMap摘要QMultiMapQMultiMap 特点代码示例 关键字: Qt、 QMultiMap、 容器、 键值、 键值重复 摘要 今天在观摩小伙伴撸代码的时候,突然听到了QMultiMap自己使用Qt开发这么就,竟然都不知道,所以趁没…

sentinel的启动与运行

首先我们github下载sentinel Releases alibaba/Sentinel (github.com) 下载好了后输入命令让它运行即可,使用cmd窗口输入一下命令即可 java -Dserver.port8089 -jar sentinel-dashboard-1.8.6.jar 账号密码默认都是sentinel 启动成功后登录进去效果如下

ABAP 采购组 条目 Z001 不存在T161内-请检查输入

背景:在ALV报表更改PR采购组 做法:ALV报表取出PR相关数据,直接将采购组列设置为可编辑,然后设置按钮更改逻辑。 操作:将采购组值更新(从原来500改为600),然后点更改功能按钮&#xf…

汽车一键启动点火开关按键一键启动按钮型号规格

汽车点火开关/移动管家一键启动按键/汽车改装引擎启动按钮型号:YD828溥款开关 一键启动按钮(适用于配套启动主机使用或原车一键启动开关更换) 1.适合配套专用板板安装 2.开孔器开孔安装 3.原车钥匙位安装 外观:黑色 按钮上有3种不…

SpringBoot 前端406 后端Could not find acceptable representation

原因:返回对象没有get方法,无法转成JSON格式

Elasticsearch:什么是检索增强生成 - RAG?

在人工智能的动态格局中,检索增强生成(Retrieval Augmented Generation - RAG)已经成为游戏规则的改变者,彻底改变了我们生成文本和与文本交互的方式。 RAG 使用大型语言模型 (LLMs) 等工具将信息检索的能力与自然语言生成无缝结合…

Linux系统下centos中在线添加硬盘后不重启在线扩容linux系统目录不重启系统

Centos7 在线添加硬盘不重启系统 CentOS 7在线添加新磁盘,无需重启 现有环境基本都是线下server以及线上虚拟机等,几乎都支持热插拔,热扩容,所以在线添加新磁盘就尤为重要,这样可以无需中断当前服务或进程也可对其进行添加硬盘操作。 1.添加硬盘: 虚拟机在线状态下对其进行添加…

centos下安装配置redis7

1、找个目录下载安装包 sudo wget https://download.redis.io/release/redis-7.0.0.tar.gz 2、将tar.gz包解压至指定目录下 sudo mkdir /home/redis sudo tar -zxvf redis-7.0.0.tar.gz -C /home/redis 3、安装gcc-c yum install gcc-c 4、切换到redis-7.0.0目录下 5、修改…

【RabbitMQ】docker rabbitmq集群 docker搭建rabbitmq集群

docker rabbitmq集群 docker搭建rabbitmq集群 RabbitMQ提供了两种常用的集群模式 1.普通集群模式 2.镜像集群模式 普通集群模式只能同步主节点上的交换机和队列信息,但对于队列中的消息不做同步,主节点宕机也不能进行切换(故障转移&#xff…

【Python】PaddleOCR文字识别国产之光 从安装到pycharm中测试 (保姆级图文)

目录 官方项目地址Python环境搭建(也就是使用Anaconda的python)1. 安装Anaconda1. 打开终端并创建conda环境 安装PaddlePaddle(CPU演示)安装PaddleOCR whl包如果安装shapely库报错(我没有报错,其他类似库安…

Pygame中将鼠标形状设置为图片2-2

3 编写主程序 在主程序中,首先创建屏幕并且完成一些准备工作,之后在while循环中不断更新sprite实例即可。 3.1 创建屏幕及准备工作 创建屏幕及准备工作的代码如图5所示。 图5 创建屏幕及准备工作 其中,第20行代码调用pygame.mouse模块中的…

pycharm设置pyuic和pyrcc

pyuic设置 适合任何虚拟环境,直接用虚拟环境的python解决一切。。。 E:\anaconda3\envs\qt5\python.exe-m PyQt5.uic.pyuic $FileName$ -o $FileNameWithoutExtension$.py$FileDir$pyrcc设置 E:\anaconda3\envs\qt5\python.exe-m PyQt5.pyrcc_main $FileName$ -o…

《机器学习》第5章 神经网络

文章目录 5.1 神经元模型5.2 感知机与多层网络5.3 误差逆传播算法5.4 全局最小与局部最小5.5 其他常见神经网络RBF网络ART网络SOM网络级联相关网络Elman网络Boltzmann机 5.6 深度学习 5.1 神经元模型 神经网络是由具有适应性的简单单元组成的广泛并行互连的网络,它…