pytorch03:transforms常见数据增强操作

目录

  • 一、数据增强
  • 二、transforms--Crop裁剪
    • 2.1 transforms.CenterCrop
    • 2.2 transforms.RandomCrop
    • 2.3 RandomResizedCrop
    • 2.4 FiveCrop和TenCrop
  • 三、transforms—Flip翻转、旋转
    • 3.1RandomHorizontalFlip和RandomVerticalFlip
    • 3.2 RandomRotation
  • 四、transforms —图像变换
    • 4.1 transforms.Pad
    • 4.2 transforms.ColorJitter
    • 4.3 Grayscale和RandomGrayscale
    • 4.4 RandomAffine
    • 4.5 RandomErasing
  • 五、transforms的操作
    • 5.1 transforms.RandomChoice
    • 5.2 transforms.RandomApply
    • 5.3 transforms.RandomOrder
  • 六、自定义transforms
    • 6.1 自定义transforms要素
    • 6.2 通过类实现多参数传入
    • 6.3 椒盐噪声
    • 6.4 自定义transforms代码实现
  • 七、数据增强策略
    • 数据增强代码实现

一、数据增强

   数据增强又称为数据增广,数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力。如下是对一张图片常见的增强操作例如:旋转、裁剪、像素抖动。
在这里插入图片描述

二、transforms–Crop裁剪

2.1 transforms.CenterCrop

功能:从图像中心裁剪图片
• size:所需裁剪图片尺寸

2.2 transforms.RandomCrop

功能:从图片中随机裁剪出尺寸为size的图片
在这里插入图片描述

• size:所需裁剪图片尺寸
• padding:设置填充大小
  当为a时,上下左右均填充a个像素,
  当为(a, b)时,上下填充b个像素,左右填充a个像素,
  当为(a, b, c, d)时,左,上,右,下分别填充a, b, c, d
• pad_if_need:若图像小于设定size,则填充
• padding_mode:填充模式,有4种模式
  1、constant:像素值由fill设定
  2、edge:像素值由图像边缘像素决定
  3、reflect:镜像填充,最后一个像素不镜像,eg:[1,2,3,4] → [3,2,1,2,3,4,3,2]
  4、symmetric:镜像填充,最后一个像素镜像,eg:[1,2,3,4] → [2,1,1,2,3,4,4,3]
• fill:constant时,设置填充的像素值

2.3 RandomResizedCrop

功能:随机大小、长宽比裁剪图片
在这里插入图片描述

• size:所需裁剪图片尺寸
• scale:随机裁剪面积比例, 默认(0.08, 1)
• ratio:随机长宽比,默认(3/4, 4/3)
• interpolation:插值方法
PIL.Image.NEAREST
PIL.Image.BILINEAR
PIL.Image.BICUBIC

2.4 FiveCrop和TenCrop

  功能:在图像的上下左右以及中心裁剪出尺寸为size的5张图片,TenCrop对这5张图片进行水平或者垂直镜像获得10张图片
在这里插入图片描述

• size:所需裁剪图片尺寸
• vertical_flip:是否垂直翻转

三、transforms—Flip翻转、旋转

3.1RandomHorizontalFlip和RandomVerticalFlip

在这里插入图片描述

功能:依概率水平(左右)或垂直(上下)翻转图片
• p:翻转概率

3.2 RandomRotation

功能:随机旋转图片
在这里插入图片描述
在这里插入图片描述

• degrees:旋转角度
  当为a时,在(-a,a)之间选择旋转角度
  当为(a, b)时,在(a, b)之间选择旋转角度
• resample:重采样方法
• expand:是否扩大图片,以保持原图

四、transforms —图像变换

4.1 transforms.Pad

功能:对图片边缘进行填充
在这里插入图片描述
• padding:设置填充大小
  当为a时,上下左右均填充a个像素
  当为(a, b)时,上下填充b个像素,左右填充a个像素
  当为(a, b, c, d)时,左,上,右,下分别填充a, b, c, d
• padding_mode:填充模式,有4种模式,constant、edge、reflect和symmetric
• fill:constant时,设置填充的像素值,(R, G, B) or (Gray)

4.2 transforms.ColorJitter

功能:调整亮度、对比度、饱和度和色相
在这里插入图片描述

• brightness:亮度调整因子
  当为a时,从[max(0, 1-a), 1+a]中随机选择
  当为(a, b)时,从[a, b]中
• contrast:对比度参数,同brightness
• saturation:饱和度参数,同brightness
• hue:色相参数,当为a时,从[-a, a]中选择参数,注: 0<= a <= 0.5
        当为(a, b)时,从[a, b]中选择参数,注:-0.5 <= a <= b <= 0.5

4.3 Grayscale和RandomGrayscale

功能:依概率将图片转换为灰度图
在这里插入图片描述
• num_ouput_channels:输出通道数只能设1或3
• p:概率值,图像被转换为灰度图的概率

4.4 RandomAffine

功能:对图像进行仿射变换,仿射变换是二维的线性变换,由五种基本原子变换构成,分别是旋转、平移、缩放、错切和翻转
在这里插入图片描述
在这里插入图片描述
• degrees:旋转角度设置
• translate:平移区间设置,如(a, b), a设置宽(width),b设置高(height)
    图像在宽维度平移的区间为 -img_width * a < dx < img_width * a
• scale:缩放比例(以面积为单位)
• fill_color:填充颜色设置

4.5 RandomErasing

功能:对图像进行随机遮挡
在这里插入图片描述

• p:概率值,执行该操作的概率
• scale:遮挡区域的面积
• ratio:遮挡区域长宽比
• value:设置遮挡区域的像素值,(R, G, B) or (Gray)

五、transforms的操作

5.1 transforms.RandomChoice

功能:从一系列transforms方法中随机挑选一个

transforms.RandomChoice([transforms1, transforms2, transforms3])

5.2 transforms.RandomApply

功能:依据概率执行一组transforms操作

transforms.RandomApply([transforms1, transforms2, transforms3], p=0.5)

5.3 transforms.RandomOrder

功能:对一组transforms操作打乱顺序

transforms.RandomOrder([transforms1, transforms2, transforms3])

六、自定义transforms

6.1 自定义transforms要素

1.仅接收一个参数,返回一个参数
2.注意上下游的输出与输入
当前transforms的输入是上一个transforms的输出,所以要保证数据类型匹配:
在这里插入图片描述

6.2 通过类实现多参数传入

在这里插入图片描述

在Python中,__call__是一个特殊的方法,用于使一个对象可以像函数一样被调用。如果一个类定义了__call__方法,那么实例化的对象就可以被当作函数一样调用,而调用的实际上是__call__方法。

class CallableClass:def __init__(self):print("Initializing the CallableClass")def __call__(self, *args, **kwargs):print("Calling the CallableClass with arguments:", args, kwargs)# 实例化对象
obj = CallableClass()# 调用对象,实际上调用了__call__方法
obj(1, 2, 3, keyword_arg="hello")

上面的例子中,CallableClass定义了__call__方法,这意味着实例obj可以像函数一样被调用。当你调用obj(1, 2, 3, keyword_arg=“hello”)时,实际上是在调用obj.call(1, 2, 3, keyword_arg=“hello”)。

6.3 椒盐噪声

椒盐噪声又称为脉冲噪声,是一种随机出现的白点或者黑点, 白点称为盐噪声,黑色为椒噪声
信噪比(Signal-Noise Rate, SNR)是衡量噪声的比例,图像中为图像像素的占比,从下图可以看出,信噪比越小,图片丢失的像素越多。
在这里插入图片描述

6.4 自定义transforms代码实现

class AddPepperNoise(object):"""增加椒盐噪声Args:snr (float): Signal Noise Rate 信噪比p (float): 概率值,依概率执行该操作Attributes:snr (float): 信噪比p (float): 操作执行的概率"""def __init__(self, snr, p=0.9):# 确保传入的snr和p是float类型assert isinstance(snr, float) and isinstance(p, float)self.snr = snrself.p = pdef __call__(self, img):"""对图像应用椒盐噪声操作。Args:img (PIL Image): PIL Image对象Returns:PIL Image: 处理后的PIL Image对象"""# 根据概率决定是否执行噪声操作if random.uniform(0, 1) < self.p:img_ = np.array(img).copy()h, w, c = img_.shapesignal_pct = self.snrnoise_pct = (1 - self.snr)# 生成噪声掩码,表示每个像素是原始图像、盐噪声还是椒噪声mask = np.random.choice((0, 1, 2), size=(h, w, 1),p=[signal_pct, noise_pct / 2., noise_pct / 2.])mask = np.repeat(mask, c, axis=2)# 根据噪声类型修改图像像素值img_[mask == 1] = 255  # 盐噪声img_[mask == 2] = 0    # 椒噪声# 将NumPy数组转换回PIL Image对象,并确保数据类型为uint8,颜色通道为RGBreturn Image.fromarray(img_.astype('uint8')).convert('RGB')else:return img

在这里插入图片描述

七、数据增强策略

原则:让训练集与测试集更接近可以使用下面这些方法
• 空间位置:平移
• 色彩:灰度图,色彩抖动
• 形状:仿射变换
• 上下文场景:遮挡,填充

例如我们训练集白猫比较多,可以改变白猫色彩,让白猫的颜色接近黑猫。
在这里插入图片描述

数据增强代码实现

要求:使用第四套RMB进行训练,要求能对第5套RMB识别正确。

我们只进行普通的图片处理训练好的模型,发现将第五套100元都识别成一元,因为第四套人民币的1元和第五套人民的100元颜色相近,所以会导致识别错误:
在这里插入图片描述
解决方法,将所有训练集颜色都进行灰度处理,代码修改如下:

train_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.RandomCrop(32, padding=4),transforms.RandomGrayscale(p=0.9),  #图片灰度化transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])

修改后的预测结果如下:
在这里插入图片描述
训练完整代码如下:

# -*- coding: utf-8 -*-import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from lenet import LeNet
from my_dataset import RMBDataset
from common_tools import transform_invertdef set_seed(seed=1):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)set_seed()  # 设置随机种子
rmb_label = {"1": 0, "100": 1}# 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1# ============================ step 1/5 数据 ============================split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]train_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.RandomCrop(32, padding=4),transforms.RandomGrayscale(p=0.9),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])valid_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)# ============================ step 2/5 模型 ============================net = LeNet(classes=2)
net.initialize_weights()# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()                                                   # 选择损失函数# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)     # 设置学习率下降策略# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()for epoch in range(MAX_EPOCH):loss_mean = 0.correct = 0.total = 0.net.train()for i, data in enumerate(train_loader):# forwardinputs, labels = dataoutputs = net(inputs)# backwardoptimizer.zero_grad()loss = criterion(outputs, labels)loss.backward()# update weightsoptimizer.step()# 统计分类情况_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).squeeze().sum().numpy()# 打印训练信息loss_mean += loss.item()train_curve.append(loss.item())if (i+1) % log_interval == 0:loss_mean = loss_mean / log_intervalprint("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))loss_mean = 0.scheduler.step()  # 更新学习率# validate the modelif (epoch+1) % val_interval == 0:correct_val = 0.total_val = 0.loss_val = 0.net.eval()with torch.no_grad():for j, data in enumerate(valid_loader):inputs, labels = dataoutputs = net(inputs)loss = criterion(outputs, labels)_, predicted = torch.max(outputs.data, 1)total_val += labels.size(0)correct_val += (predicted == labels).squeeze().sum().numpy()loss_val += loss.item()valid_curve.append(loss_val)print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val, correct / total))train_x = range(len(train_curve))
train_y = train_curvetrain_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curveplt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()# ============================ inference ============================BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
valid_loader = DataLoader(dataset=test_data, batch_size=1)for i, data in enumerate(valid_loader):# forwardinputs, labels = dataoutputs = net(inputs)_, predicted = torch.max(outputs.data, 1)rmb = 1 if predicted.numpy()[0] == 0 else 100img_tensor = inputs[0, ...]  # C H Wimg = transform_invert(img_tensor, train_transform)plt.imshow(img)plt.title("LeNet got {} Yuan".format(rmb))plt.show()plt.pause(0.5)plt.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/228727.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[MySQL] MySQL 高级(进阶) SQL 语句

一、高效查询方式 1.1 指定指字段进行查看 事先准备好两张表 select 字段1&#xff0c;字段2 from 表名; 1.2 对字段进行去重查看 SELECT DISTINCT "字段" FROM "表名"; 1.3 where条件查询 SELECT "字段" FROM 表名" WHERE "条件…

两种方法求解平方根 -- 牛顿法、二分法

Leetcode相关题目&#xff1a; 69. x 的平方根 以求解 a a a 的平方根为例&#xff0c;可转换为求解方程 f ( x ) f(x) f(x)的根。 f ( x ) x 2 − a f(x)x^2-a f(x)x2−a 牛顿法迭代公式如下&#xff1a; x n 1 x n − f ( x n ) f ′ ( x n ) x_{n1} x_n - \frac {f…

华为交换机生成树STP配置案例

企业内部网络怎么防止网络出现环路&#xff1f;学会STP生成树技术就可以解决啦。 STP简介 在二层交换网络中&#xff0c;一旦存在环路就会造成报文在环路内不断循环和增生&#xff0c;产生广播风暴&#xff0c;从而占用所有的有效带宽&#xff0c;使网络变得无法正常通信。 在…

鸿蒙 DevEco Studio 3.1 入门指南

本文主要记录开发者入门&#xff0c;从软件安装到项目运行&#xff0c;以及后续的学习 1&#xff0c;配置开发环境 1.1 下载安装包 官网下载链接 点击立即下载找到对应版版本 下载完成&#xff0c;按照提示默认安装即可 1.2 下载SDK及工具链 运行已安装的DevEco Studio&…

odoo17 | 创建一个新应用程序

前言 本章的目的是为创建一个全新的Odoo模块奠定基础。 我们将从头开始&#xff0c;以使我们的模块被Odoo识别所需的最低限度。 在接下来的章节中&#xff0c;我们将逐步添加功能以构建一个真实的业务案例。 教程 假设我门需要在odoo上开发一个新app模块例如房地产广告模块。…

uniapp Vue3 日历 可签到 跳转

上干货 <template><view class"zong"><view><view class"top"><!-- 上个月 --><view class"sgy" click"sgy">◀</view><view class"nianyue">{{ year }}年{{ month 1 }}…

MD5算法

一、引言 MD5&#xff08;Message-Digest Algorithm 5&#xff09;是一种广泛应用的密码散列算法&#xff0c;由Ronald L. Rivest于1991年提出。MD5算法主要用于对任意长度的消息进行加密&#xff0c;将消息压缩成固定长度的摘要&#xff08;通常为128位&#xff09;。在密码学…

右键菜单“以notepad++打开”,在windows文件管理器中

notepad 添加到文件管理器的右键菜单中 找到安装包&#xff0c;重新安装一般即可。 这里有最新版&#xff1a;地址 密码:f0f1 方法 在安装的时候勾选 “Context Menu Entry” 即可 Notepad的右击打开文件功能 默认已勾选 其作用是添加右键快捷键。即&#xff0c;对于任何…

黑马程序员SSM框架-SpringBoot

视频连接&#xff1a;SpringBoot-01-SpringBoot工程入门案例开发步骤_哔哩哔哩_bilibili SpringBoot简介 入门程序 也可以基于官网创建项目。 SpringBoot项目快速启动 下面的插件将项目运行所需的依赖jar包全部加入到了最终运行的jar包中&#xff0c;并将入口程序指定。 Spri…

2023/12/30 c++ work

定义一个Person类&#xff0c;私有成员int age&#xff0c;string &name&#xff0c;定义一个Stu类&#xff0c;包含私有成员double *score&#xff0c;写出两个类的构造函数、析构函数、拷贝构造和拷贝赋值函数&#xff0c;完成对Person的运算符重载(算术运算符、条件运算…

JavaScript 工具库 | PrefixFree给CSS自动添加浏览器前缀

新版的CSS拥有多个新属性&#xff0c;而标准有没有统一&#xff0c;有的浏览器厂商为了吸引更多的开发者和用户&#xff0c;已经加入了最新的CSS属性支持&#xff0c;这其中包含了很多炫酷的功能&#xff0c;但是我们在使用的时候&#xff0c;不得不在属性前面添加这些浏览器的…

lv14 注册字符设备 3

1 注册字符设备 1.1 结构体介绍 struct cdev {struct kobject kobj;//表示该类型实体是一种内核对象struct module *owner;//填THIS_MODULE&#xff0c;表示该字符设备从属于哪个内核模块const struct file_operations *ops;//指向空间存放着针对该设备的各种操作函数地址str…

VMware安装RHEL9.0版本Linux系统

最近在学习Linux&#xff0c;安装了Red Hat Enterprise Linux 的 9.0版本&#xff0c;简称RHEL9.0。RHEL9.0是Red Hat公司发布的面向企业用户的Linux操作系统的最新版本。我把它安装在虚拟机VMware里来减少电脑性能占用&#xff0c;也防止系统炸搞得我后面要重装。安装RHEL9.0还…

2023海内外零知识证明学习资料汇总(二)(深入理解零知识证明篇)

工欲善其事,必先利其器 Web3开发中&#xff0c;各种工具、教程、社区、语言框架.。。。 种类繁多&#xff0c;是否有一个包罗万象的工具专注与Web3开发和相关资讯能毕其功于一役&#xff1f; 参见另一篇博文&#x1f449; 2024最全面且有知识深度的web3开发工具、web3学习项目…

Springboot整合JSP-修订版本(Springboot3.1.6+IDEA2022版本)

1、问题概述&#xff1f; Springboot对Thymeleaf支持的要更好一些&#xff0c;Springboot内嵌的Tomcat默认是没有JSP引擎&#xff0c;不支持直接使用JSP模板引擎。这个时候我们需要自己配置使用。 2、Springboot整合使用JSP过程 现在很多的IDEA版本即使创建的项目类型是WAR工…

中科亿海微UART协议

引言 在现代数字系统设计中&#xff0c;通信是一个至关重要的方面。而UART&#xff08;通用异步接收器/发送器&#xff09;协议作为一种常见的串行通信协议&#xff0c;被广泛应用于各种数字系统中。FPGA&#xff08;现场可编程门阵列&#xff09;作为一种灵活可编程的硬件平台…

王道考研计算机网络——应用层

如何为用户提供服务&#xff1f; CS/P2P 提高域名解析的速度&#xff1a;local name server高速缓存&#xff1a;直接地址映射/低级的域名服务器的地址 本机也有告诉缓存&#xff1a;本机开机的时候从本地域名服务器当中下载域名和地址的对应数据库&#xff0c;放到本地的高…

Python编程新技能:如何优雅地实现水仙花数?

水仙花数&#xff08;Narcissistic number&#xff09;也被称为阿姆斯特朗数&#xff08;Armstrong number&#xff09;或自恋数等&#xff0c;它是一个非负整数&#xff0c;其特性是该数的每个位上的数字的n次幂之和等于它本身&#xff0c;其中n是该数的位数。简单来说&#x…

【HarmonyOS开发】案例-记账本开发

OpenHarmony最近一段时间&#xff0c;简直火的一塌糊度&#xff0c;学习OpenHarmony相关的技术栈也有一段时间了&#xff0c;做个记账本小应用&#xff0c;将所学知识点融合记录一下。 1、记账本涉及知识点 基础组件&#xff08;Button、Select、Text、Span、Divider、Image&am…

SpringCloud(H版alibaba)框架开发教程,使用eureka,zookeeper,consul,nacos做注册中心——附源码(1)

源码地址&#xff1a;https://gitee.com/jackXUYY/springboot-example 创建订单服务&#xff0c;支付服务&#xff0c;公共api服务&#xff08;共用的实体&#xff09;&#xff0c;eureka服务 1.cloud-consumer-order80 2.cloud-provider-payment8001 3.cloud-api-commons 4.…