tensorflow案例7--数据增强与测试集, 训练集, 验证集的构建

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

前言

  • 这次主要是学习数据增强, 训练集 验证集 测试集的构建等等的基本方法, 数据集还是用的上一篇的猫狗识别;
  • 基础篇还剩下几个, 后面的难度会逐步提升;
  • 欢迎收藏 + 关注, 本人会持续更新.

文章目录

  • 1. 简介
    • 数据增强
    • 训练集划分
  • 2. 案例测试
    • 1. 数据处理
      • 1. 导入库
      • 2. 导入数据(训练集 测试集 验证集)
      • 3. 数据部分展示
      • 4. 数据归一化与内存加速
      • 5. 数据增强
      • 6. 将增强数据融合到原始数据中
    • 2. 模型创建
    • 3. 模型训练
      • 1. 超参数设置
      • 2. 模型训练
      • 3. 模型测试
    • 4. 其他方法增强数据

1. 简介

数据增强

💙 有时候数据很好, **就可以通过在原有的基础上做一些操作, ** 从而增加数据的数量, 使训练模型更加有效.

📶 对于基础的的增强, 一般就是旋转, 在pytorch中一般是用transforms.Compose进行处理, 在tensorflow中,一般用的是tf.keras.layers.experimental.preprocessing.RandomFliptf.keras.layers.experimental.preprocessing.RandomRotation 进行数据增强, 👁 具体做法请看案例

当然还有其他的方法进行增强, 比如说添加噪音, 👓 详情请看第四节, 4. 其他方法数据增强

数据增强加入模型中

一般有两个方法:

  1. 加入数据集(本文用的方法)
  2. 加入到模型中, 让模型训练的时候, 开始进行数据增强, 这个本文不介绍

注意: tensorflow和numpy版本问题不同, 可能会出现比较多数据方面的错误, 本人这个案例最后也是在云平台上跑通的.

训练集划分

简单说一下训练集, 测试集, 验证集的区别:

  • 训练集: 用来训练模型的, 确定神经网络的各种参数, 相当于我们学习一样
  • 验证集: 在训练集中, 通过验证模型效果, 来调整模型参数, 这个就相当于我们月考一样
  • 测试集: 这个就是验证模型是都具有效果, 适用于其他数据, 这个就相当于我们大考

👀 在tensorflow中, 我们可以通过tf.keras.preprocessing.image_dataset_from_directory创建训练集和验证集, 但是不能创建测试集, 创建测试集的方法, 需要我们后面对数据进行分类, 如下:

val_batches = tf.data.experimental.cardinality(val_ds)
# 创建测试集,  方法: 将验证集合拆成 5 分, 测试集占一份, 验证集占 4 份
test_ds = val_ds.take(val_batches // 2)    # 取前 * 批次
val_ds = val_ds.skip(val_batches // 2)     # 除了前 * 批次

解释:

  • tf.data.experimental.cardinality获取数据批次大小
  • .take : 取前n批数据
  • .skip : 取除了前n批次数据

2. 案例测试

本次案例是对猫狗图像进行分类, 和上一期很像, 但是这个模型使用比较简单.

注意: 不同池化层, 效果有时候天差地别, 比如说: 这个案例用的是最大池化, 但是用平均池化的话, 效果极差

1. 数据处理

1. 导入库

import tensorflow as tf 
from tensorflow.keras import layers, models, datasets 
import numpy as np gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]tf.config.experimental.set_memory_growth(gpu0, True)   # 输出存储在GPUtf.config.set_visible_devices([gpu0], "GPU")          # 选择第一块GPUgpus
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

2. 导入数据(训练集 测试集 验证集)

# 查看数据目录
import os, pathlibdata_dir = "./data/"
data_dir = pathlib.Path(data_dir)classnames = [str(path) for path in os.listdir(data_dir)]
classnames
['cat', 'dog']
# 创建训练集和验证集batch_size = 32
image_width, image_height = 224, 224train_ds = tf.keras.preprocessing.image_dataset_from_directory('./data/',subset='training',validation_split=0.3,batch_size=batch_size,image_size=(image_width, image_height),shuffle=True,seed=42
)val_ds = tf.keras.preprocessing.image_dataset_from_directory('./data',subset='validation',validation_split=0.3,batch_size=batch_size,image_size=(image_width, image_height),shuffle=True,seed=42
)
Found 600 files belonging to 2 classes.
Using 420 files for training.
Found 600 files belonging to 2 classes.
Using 180 files for validation.

在tensorflow没有提供直接分割测试集的函数,但是可以通过分割验证集的方法进行创建测试集

val_batches = tf.data.experimental.cardinality(val_ds)
# 创建测试集,  方法: 将验证集合拆成 5 分, 测试集占一份, 验证集占 4 份
test_ds = val_ds.take(val_batches // 2)    # 取前 * 批次
val_ds = val_ds.skip(val_batches // 2)     # 取除了前 * 批次print("test batches: %d"%tf.data.experimental.cardinality(test_ds))
print("val batches: %d"%tf.data.experimental.cardinality(val_ds))
test batches: 3
val batches: 3

训练集: 验证集: 测试集 = 0.7 : 0.15 : 0.15

3. 数据部分展示

# 数据规格展示
for images, labels in train_ds.take(1):print("image: [N, W, H, C] ", images.shape)print("labels: ", labels)break
image: [N, W, H, C]  (32, 224, 224, 3)
labels:  tf.Tensor([0 1 1 0 0 0 1 0 0 0 0 1 0 1 0 0 0 1 1 1 1 0 1 1 0 1 1 0 1 0 1 0], shape=(32,), dtype=int32)
# 部分图片数据展示
import matplotlib.pyplot as plttrain_one_batch = next(iter(train_ds))plt.figure(figsize=(20, 10))images, labels = train_one_batchfor i in range(20):plt.subplot(5, 10, i + 1)plt.title(classnames[labels[i]])plt.imshow(images[i].numpy().astype('uint8'))plt.axis('off')plt.show()


在这里插入图片描述

4. 数据归一化与内存加速

from tensorflow.data.experimental import AUTOTUNE # 像素归一化, ---> [0, 1]
normalization_layer = layers.experimental.preprocessing.Rescaling(1.0 / 255)# 训练集、测试集像素归一化
train_ds = train_ds.map(lambda x, y : (normalization_layer(x), y))
val_ds = val_ds.map(lambda x, y : (normalization_layer(x), y))
test_ds = test_ds.map(lambda x, y : (normalization_layer(x), y))# 设置内存加速
AUTOTUNE = tf.data.experimental.AUTOTUNE # 打乱顺序加速, 测试集就不必了哈
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

5. 数据增强

我们可以使用 tf.keras.layers.experimental.preprocessing.RandomFliptf.keras.layers.experimental.preprocessing.RandomRotation 进行数据增强.

  • tf.keras.layers.experimental.preprocessing.RandomFlip:水平和垂直随机翻转每个图像.
  • tf.keras.layers.experimental.preprocessing.RandomRotation:随机旋转每个图像.
# 封装整合
data_augmentation = tf.keras.Sequential([tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),   # 垂直和水平反转tf.keras.layers.experimental.preprocessing.RandomRotation(0.2)                      # 随机翻转
])test_datas = next(iter(train_ds))test_images, test_labels = test_datas# 随机选取一个
test_image = tf.expand_dims(test_images[i], 0)plt.figure(figsize=(8, 8))
for i in range(9):augmented_image = data_augmentation(test_image)   # 旋转ax = plt.subplot(3, 3, i + 1)plt.imshow(augmented_image[0])                 plt.axis("off")


在这里插入图片描述

6. 将增强数据融合到原始数据中

batch_size = 32
AUTOTUNE = tf.data.AUTOTUNEdef prepare(ds):ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)return ds# 增强
train_ds = prepare(train_ds)

2. 模型创建

model = models.Sequential([# 第一层要输入维度layers.Conv2D(16, (3, 3), activation='relu', input_shape=(image_width, image_height, 3)),layers.MaxPooling2D((2,2)),layers.Conv2D(32, (3, 3), activation='relu'),layers.MaxPooling2D((2,2)),layers.Dropout(0.3),layers.Conv2D(32, (3, 3), activation='relu'),layers.MaxPooling2D((2,2)),layers.Dropout(0.3),layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dense(len(classnames))
])model.summary()
Model: "sequential_1"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================conv2d (Conv2D)             (None, 222, 222, 16)      448       max_pooling2d (MaxPooling2D  (None, 111, 111, 16)     0         )                                                               conv2d_1 (Conv2D)           (None, 109, 109, 32)      4640      max_pooling2d_1 (MaxPooling  (None, 54, 54, 32)       0         2D)                                                             dropout (Dropout)           (None, 54, 54, 32)        0         conv2d_2 (Conv2D)           (None, 52, 52, 32)        9248      max_pooling2d_2 (MaxPooling  (None, 26, 26, 32)       0         2D)                                                             dropout_1 (Dropout)         (None, 26, 26, 32)        0         flatten (Flatten)           (None, 21632)             0         dense (Dense)               (None, 128)               2769024   dense_1 (Dense)             (None, 2)                 258       =================================================================
Total params: 2,783,618
Trainable params: 2,783,618
Non-trainable params: 0
_________________________________________________________________

3. 模型训练

1. 超参数设置

opt = tf.keras.optimizers.Adam(learning_rate=0.001)  # 学习率:0.001model.compile(optimizer = opt,loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics = ['accuracy']
)

2. 模型训练

epochs=20history = model.fit(train_ds,validation_data=val_ds,epochs=epochs,verbose=1
)
Epoch 1/20
2024-11-22 18:03:21.866630: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8101
2024-11-22 18:03:23.553540: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
14/14 [==============================] - 4s 36ms/step - loss: 0.7059 - accuracy: 0.5643 - val_loss: 0.6646 - val_accuracy: 0.6667
Epoch 2/20
14/14 [==============================] - 0s 27ms/step - loss: 0.6125 - accuracy: 0.6381 - val_loss: 0.6096 - val_accuracy: 0.7143
Epoch 3/20
14/14 [==============================] - 0s 15ms/step - loss: 0.5027 - accuracy: 0.7714 - val_loss: 0.5646 - val_accuracy: 0.7500
Epoch 4/20
14/14 [==============================] - 0s 14ms/step - loss: 0.4723 - accuracy: 0.7952 - val_loss: 0.5496 - val_accuracy: 0.7500
Epoch 5/20
14/14 [==============================] - 0s 14ms/step - loss: 0.4395 - accuracy: 0.7857 - val_loss: 0.6267 - val_accuracy: 0.7024
Epoch 6/20
14/14 [==============================] - 0s 13ms/step - loss: 0.3721 - accuracy: 0.8262 - val_loss: 0.5001 - val_accuracy: 0.7619
Epoch 7/20
14/14 [==============================] - 0s 14ms/step - loss: 0.4041 - accuracy: 0.8238 - val_loss: 0.4595 - val_accuracy: 0.7857
Epoch 8/20
14/14 [==============================] - 0s 13ms/step - loss: 0.3195 - accuracy: 0.8643 - val_loss: 0.4247 - val_accuracy: 0.8095
Epoch 9/20
14/14 [==============================] - 0s 13ms/step - loss: 0.3010 - accuracy: 0.8738 - val_loss: 0.3674 - val_accuracy: 0.8452
Epoch 10/20
14/14 [==============================] - 0s 14ms/step - loss: 0.3190 - accuracy: 0.8762 - val_loss: 0.3660 - val_accuracy: 0.8452
Epoch 11/20
14/14 [==============================] - 0s 15ms/step - loss: 0.2864 - accuracy: 0.8690 - val_loss: 0.3529 - val_accuracy: 0.8333
Epoch 12/20
14/14 [==============================] - 0s 13ms/step - loss: 0.2532 - accuracy: 0.8762 - val_loss: 0.2737 - val_accuracy: 0.8929
Epoch 13/20
14/14 [==============================] - 0s 13ms/step - loss: 0.2374 - accuracy: 0.9000 - val_loss: 0.2939 - val_accuracy: 0.8810
Epoch 14/20
14/14 [==============================] - 0s 15ms/step - loss: 0.2216 - accuracy: 0.8976 - val_loss: 0.2952 - val_accuracy: 0.8810
Epoch 15/20
14/14 [==============================] - 0s 13ms/step - loss: 0.2365 - accuracy: 0.9095 - val_loss: 0.2559 - val_accuracy: 0.9167
Epoch 16/20
14/14 [==============================] - 0s 13ms/step - loss: 0.2114 - accuracy: 0.9071 - val_loss: 0.2702 - val_accuracy: 0.8929
Epoch 17/20
14/14 [==============================] - 0s 15ms/step - loss: 0.2075 - accuracy: 0.9024 - val_loss: 0.2353 - val_accuracy: 0.9286
Epoch 18/20
14/14 [==============================] - 0s 13ms/step - loss: 0.1850 - accuracy: 0.9262 - val_loss: 0.1927 - val_accuracy: 0.9524
Epoch 19/20
14/14 [==============================] - 0s 13ms/step - loss: 0.1318 - accuracy: 0.9524 - val_loss: 0.1837 - val_accuracy: 0.9286
Epoch 20/20
14/14 [==============================] - 0s 15ms/step - loss: 0.1561 - accuracy: 0.9476 - val_loss: 0.1951 - val_accuracy: 0.9643

3. 模型测试

loss, acc = model.evaluate(test_ds)
print("Loss: ", loss)
print("Accuracy: ", acc)
3/3 [==============================] - 0s 8ms/step - loss: 0.2495 - accuracy: 0.9062
Loss:  0.24952644109725952
Accuracy:  0.90625

测试集准确率高, 模型效果良好

4. 其他方法增强数据

这里是使数据变得模糊

import random def aug_img(image):seed = (random.randint(0, 9), 0)stateless_random_brightness = tf.image.stateless_random_contrast(image, lower=0.1, upper=1.0, seed=seed)return stateless_random_brightness
# 随机选取一张照片
image = tf.expand_dims(test_images[i] * 255, 0)   # 注意: 不乘255, 会出现黑色, 因为 像素在0 - 1中plt.figure(figsize=(8,8))
for i in range(9):image_show = aug_img(image)plt.subplot(3, 3, i + 1)plt.imshow(image_show[0].numpy().astype("uint8"))


在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/476849.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot多环境+docker集成企业微信会话存档sdk

SpringBoot多环境docker集成企业微信会话存档sdk 文章来自于 https://developer.work.weixin.qq.com/community/article/detail?content_id16529801754907176021 SpringBoot多环境docker集成企业微信会话存档sdk 对于现在基本流行的springboot环境,官方文档真是比…

VSCode快速生成vue组件模版

1&#xff0c;点击设置&#xff0c;找到代码片段 2&#xff0c;搜索vue&#xff0c;打开vue.json 3&#xff0c;添加模版 vue2模板 "vue2": {"prefix": "vue2","body": ["<template>"," <div>$0</di…

【爬虫】Firecrawl对京东热卖网信息爬取(仅供学习)

项目地址 GitHub - mendableai/firecrawl: &#x1f525; Turn entire websites into LLM-ready markdown or structured data. Scrape, crawl and extract with a single API. Firecrawl更多是使用在LLM大模型知识库的构建&#xff0c;是大模型数据准备中的一环&#xff08;在…

VXLAN说明

1. 什么是 VXLAN &#xff1f; VXLAN&#xff08;Virtual Extensible LAN&#xff0c;虚拟扩展局域网&#xff09;是一种网络虚拟化技术&#xff0c;旨在通过在现有的物理网络上实现虚拟网络扩展&#xff0c;从而克服传统 VLAN 的一些限制。 VXLAN 主要用于数据中心、云计算环…

RTL8211F 1000M以太网PHY指示灯

在RK3562 Linux5.10 SDK里面已支持该芯片kernel-5.10/drivers/net/phy/realtek.c&#xff0c;而默认是没有去修改到LED配置的&#xff0c;我们根据硬件设计修改相应的寄存器配置&#xff0c;该PHY有3个LED引脚&#xff0c;我们LED0不使用&#xff0c;LED1接绿灯&#xff08;数据…

主IP地址与从IP地址:深入解析与应用探讨

在互联网的浩瀚世界中&#xff0c;每台联网设备都需要一个独特的身份标识——IP地址。随着网络技术的不断发展&#xff0c;IP地址的角色日益重要&#xff0c;而“主IP地址”与“从IP地址”的概念也逐渐进入人们的视野。这两个术语虽然看似简单&#xff0c;实则蕴含着丰富的网络…

【Redis】基于Redis实现秒杀功能

业务的流程大概就是&#xff0c;先判断优惠卷是否过期&#xff0c;然后判断是否有库存&#xff0c;最好进行扣减库存&#xff0c;加入全局唯一id&#xff0c;然后生成订单。 一、超卖问题 真是的场景下可能会有超卖问题&#xff0c;比如开200个线程进行抢购&#xff0c;抢100个…

计算机网络socket编程(4)_TCP socket API 详解

个人主页&#xff1a;C忠实粉丝 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 C忠实粉丝 原创 计算机网络socket编程(4)_TCP socket API 详解 收录于专栏【计算机网络】 本专栏旨在分享学习计算机网络的一点学习笔记&#xff0c;欢迎大家在评论区交流讨论&…

Jmeter数据库压测之达梦数据库的配置方法

目录 1、概述 2、测试环境 3、数据库压测配置 3.1 安装jmeter 3.2 选择语言 3.3 新建测试计划 3.4 配置JDBC连接池 3.5 配置线程组 3.6 配置测试报告 3.7 执行测试 1、概述 Jmeter是Apache组织开发的基于Java的压力测试工具&#xff0c;用于对软件做压力测试。 它最…

RAG与微调:大模型落地的最佳路径选择(文末赠书)

一、大模型技术发展现状 自2022年底ChatGPT掀起AI革命以来&#xff0c;大语言模型&#xff08;LLM&#xff09;技术快速迭代发展&#xff0c;从GPT-4到Claude 2&#xff0c;从文心一言到通义千问&#xff0c;大模型技术以惊人的速度发展。然而&#xff0c;在企业实际应用场景中…

Web 入门

HTTP 一、概念 Hyper Text Transfer Protocol&#xff0c;超文本传输协议&#xff0c;规定了浏览器和服务器之间数据传输的规则。 二、特点 基于TCP协议&#xff1a;面向连接&#xff0c;安全。基于请求-响应模型的&#xff1a;一次请求对应一次响应。HTTP协议是无状态的协…

pinia是什么?pinia简介快速入门,创建pinia到vue3项目中

一&#xff0c;pinia就是Vuex&#xff0c;的替代工具&#xff0c;Vuex plus 如何将pinia引入到vue3项目中&#xff1f; 1.首先新建一个vue3项目 全填yes npm init vuelatest 2.安装好之后查阅官方文档 pinia使用文档 3.从而得知在项目中有俩种方式安装pinia 我的本地只有nod…

Java 基于SpringBoot+vue框架的老年医疗保健网站

大家好&#xff0c;我是Java徐师兄&#xff0c;今天为大家带来的是Java Java 基于SpringBootvue框架的老年医疗保健网站。该系统采用 Java 语言开发&#xff0c;SpringBoot 框架&#xff0c;MySql 作为数据库&#xff0c;系统功能完善 &#xff0c;实用性强 &#xff0c;可供大…

FPGA实现串口升级及MultiBoot(九)BPI FLASH相关实例演示

本文目录索引 区别一:启动流程的区别区别二:高位地址处理区别三:地址映射例程说明总结例程地址之前一直都是以SPI FLASH为例进行相关知识讲解,今天我们介绍另一款常用的配置FLASH-BPI FLASH。 今天的讲解以简洁为主,主打个能用一句话不说两句话。以和SPI区别为主,实例演…

VisionPro 机器视觉案例 之 彩色保险丝个数统计

第十四篇 机器视觉案例 之 彩色保险丝颜色识别个数统计 文章目录 第十四篇 机器视觉案例 之 彩色保险丝颜色识别个数统计1.案例要求2.实现思路2.1 方法一 颜色分离工具CogColorSegmenterTool将每一种颜色分离出来&#xff0c;得到对应的单独图像&#xff0c;使用斑点工具CogBlo…

实时数据研发 | Flink技术栈

下周要开始接触一些实时的内容了&#xff0c;想来是很幸运的&#xff0c;这是我在新人培训上提问过技术前辈的问题&#xff1a;“想学习实时相关技术&#xff0c;但是部门没有类似的需求&#xff0c;应该如何提升&#xff1f;”当时师姐说先用心去学&#xff0c;然后向主管证明…

Spring cloud 一.Consul服务注册与发现(4)

1.动态刷新案例步骤 1.问题 接着上一步,我们在consul的dev配置分支修改了内容马上访问,结果无效 会发现还是原来的内容&#xff0c;/(ㄒoㄒ)/~~ &#xff0c;没有做到及时响应和动态刷新 2.步骤 RefreshScope主启动类添加 package com.atguigu.cloud;import org.springfram…

石油化工调度台的外观如何设计更有科技感

在石油化工行业中&#xff0c;调度台作为生产运营的核心指挥中枢&#xff0c;其设计不仅关乎操作效率&#xff0c;更是企业形象和技术实力的体现。那么&#xff0c;到底如何在调度台的外观设计中融入科技感&#xff0c;以提升工作效率并彰显企业前沿形象&#xff0c;成为了一个…

【机器学习】——朴素贝叶斯模型

&#x1f4bb;博主现有专栏&#xff1a; C51单片机&#xff08;STC89C516&#xff09;&#xff0c;c语言&#xff0c;c&#xff0c;离散数学&#xff0c;算法设计与分析&#xff0c;数据结构&#xff0c;Python&#xff0c;Java基础&#xff0c;MySQL&#xff0c;linux&#xf…

如何使用Jest测试你的React组件

在本文中&#xff0c;我们将了解如何使用Jest&#xff08;Facebook 维护的一个测试框架&#xff09;来测试我们的React组件。我们将首先了解如何在纯 JavaScript 函数上使用 Jest&#xff0c;然后再了解它提供的一些开箱即用的功能&#xff0c;这些功能专门用于使测试 React 应…