Tensorboard、Dataset、Transforms、Dataloader
该文档主要参考【土堆】的视频教程:pytorch入门教程–土堆
一、Tensorboard
安装tensorboard
:pip install tensorboard
使用步骤:
- 引入相关库:
from torch.utils.tensorboard import SummaryWriter
- 构建
SummaryWriter
对象:writer = SummaryWriter(log_dir="logs")
- 在工程目录下创建一个名为
logs
的文件夹,用于存放Tensorboard
绘图所用的文件
- 在工程目录下创建一个名为
- 打开
tensorboard
- 命令行执行:
tensorboard --logdir=logs
- 如果有错误,使用
logs
的绝对地址
- 如果有错误,使用
- 点击链接,即可查看
- 命令行执行:
[常用函数]
add_scalar
、add_images
、add_graph
1.1、add_scalar
功能:添加标量数据(例如记录训练epoch及对应的loss)
常用参数:
- 标题:
tag
- 标量数据(Y轴):
scalar_value
- 计步数据(X轴):
global_step
from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter(log_dir="logs")
for i in range(100):writer.add_scalar(tag="y=3x", scalar_value=3 * i, global_step=i)
注意:使用同一个SummaryWriter
对象且tag
相同时,会绘制在同一幅图上,为避免该情况可以删除logs
中的内容,或者每次都新建文件夹(log_dir="logs_1"
)
图片如下(左图为只绘制一次的结果y=3x
,右图为在同一幅图上分别绘制y=2x
以及y=10x
的结果)
1.2、add_image
功能:添加图片数据(例如记录每个batch的输入图片)
常用参数:
标题:tag
图片数据:img_tensor
(一般是torch.Tensor
或者numpy.array
类型)
计步数据:global_step
图片格式:dataformats(CHW, HWC, HW, WH等,默认为'CHW')
from torch.utils.tensorboard import SummaryWriter
import os
import cv2project_path = os.getcwd()
file_name_1 = 'dog_1.png'
file_name_2 = 'dog_2.png'file_path = os.path.join(project_path, r'data\dog')
full_file_path_1 = os.path.join(file_path, file_name_1)
full_file_path_2 = os.path.join(file_path, file_name_2)
# 使用cv2(即OpenCV库)读取图片时,图片通常是以HWC(高度、宽度、通道)格式存储的,并且每个像素的颜色值(对于RGB图像)都是0到255之间的整数
image_data_1 = cv2.imread(full_file_path_1)
image_data_2 = cv2.imread(full_file_path_2)
# 将图片从BGR转换为RGB(因为cv2默认通道顺序为BGR,使用PIL读取图片的通道顺序为RGB):如果不进行调整,则图片颜色会失真
image_data_1 = cv2.cvtColor(image_data_1, cv2.COLOR_BGR2RGB)
image_data_2 = cv2.cvtColor(image_data_2, cv2.COLOR_BGR2RGB)writer = SummaryWriter(log_dir="logs")
writer.add_image(tag='dog', img_tensor=image_data_1, global_step=0, dataformats='HWC')
writer.add_image(tag='dog', img_tensor=image_data_2, global_step=1, dataformats='HWC')
writer.close()
图片如下(通过拖动进度条可以查看不同step对应的图片)
1.3、add_graph
功能:添加模型结构数据(例如记录神经网络的结构)
常用参数:
模型:model
(可以构建自己的模型或者使用公开的经典模型,例如VGG-16
)
模型输入:input_to_model
(要求图片数据是Tensor
类型)
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
import cv2
from torchvision import transformsif __name__ == '__main__':# 使用内置的分类模型VGG-16(后续会更新如何搭建模型的文章)vgg_16 = torchvision.models.vgg16(progress=False)print(vgg_16)image = cv2.imread('data\\dog\\dog_1.png')# 图片类型转换为Tensor并调整尺寸为224*224(vgg16的标准输入尺寸)image = transforms.Compose([transforms.ToTensor(), transforms.Resize([224, 224])])(image)# 添加一个批次维度(模型通常期望输入具有批次维度),使得形状从(C, H, W)变为(1, C, H, W),其中1表示批次大小。image = torch.unsqueeze(image, 0)writer = SummaryWriter('model')# image提供模型在前向传播过程中所需的输入数据,TensorBoard据此生成模型的计算图writer.add_graph(model=vgg_16, input_to_model=image)writer.close()
二、Dataset
2.1、使用公开数据集
-
常用的数据集:
MNIST
、CIFAR10
等 -
这些封装好的数据集都继承了
torch.utils.data
中的Dataset
类,该类有两个重要的方法:getitem()
、len()
; -
可以通过参数
transform
以及target_transform
在加载数据时进行实时的数据增强操作(如旋转、裁剪、缩放等);- 对图像数据的增强操作详见章节三
Transforms
- 对图像数据的增强操作详见章节三
-
可以通过继承
Dataset
类并重写getitem()
、len()
方法创建自己的数据集类(使用自己的数据)
from torchvision import datasetsif __name__ == '__main__':# 指定数据集路径(下载好的数据集会自动解压到路径下)data_path = 'common_dataset'# train=True表示为训练集,download=True表示下载数据集(若已经下载好则自动加载本地数据集)# 若在线下载速度慢,可进入CIFAR10类中,直接通过数据集的下载链接下载(下载好放在data_path下即可)train_data = datasets.CIFAR10(root=data_path, train=True, download=True)test_data = datasets.CIFAR10(root=data_path, train=False, download=True)# 打印数据集所包含的数据个数print(len(train_data), len(test_data))# 获取第一个数据的图片(PIL Image类型)及标签(类别)img, label = train_data[0]# 打印类别索引及真实的类别print(label, train_data.classes[label])img.show()
2.2、使用自己的数据
from torch.utils.data import Dataset
import cv2
import osclass MyDataset(Dataset):def __init__(self, data_path, label):# super.__init__()self.data_path = data_pathself.label = labelself.full_path = os.path.join(self.data_path, self.label)self.images_name = os.listdir(self.full_path)def __getitem__(self, item):image_data = cv2.imread(os.path.join(self.full_path, self.images_name[item]))# BGR转换为RGB,不然会失真image_data = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)return image_data, self.labeldef __len__(self):return len(self.images_name)if __name__ == '__main__':data_path = os.path.join(os.getcwd(), 'data')label = 'dog'dataset_instance = MyDataset(data_path, label)print(len(dataset_instance))image, label = dataset_instance[0]print(image.shape, label)print(type(image))
三、Transforms
Transforms
是用于处理图片的库,内置的类基本可以满足图片处理的需求,例如图片类型转换(PIL Image
、ndarray
、tensor
)、尺寸调整、裁剪等
-
若没有
torchvision
则需要先安装:pip install torchvision
-
[常用功能(类)]
ToTensor
、Normalize
、Resize
、RandomCrop
、Compose
-
使用方式:根据需求选择类,创建类的实例,使用类的实例完成图片处理
3.1、ToTensor
-
功能:将
PIL Image
、ndarray
类型的图片转换为Tensor
类型(Convert a PIL Image or ndarray to tensor and scale the values accordingly) -
输入:
PIL Image
、ndarray
类型的图片(PIL Image or numpy.ndarray (H x W x C) in the range [0, 255]) -
输出:
Tensor
类型,shape
为CHW
,每个元素均为[0.0, 1.0]
之间的数(torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0])
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import cv2
import os# 创建Dataset的子类
class MyDataset(Dataset):def __init__(self, data_path, label):# super.__init__()self.data_path = data_pathself.label = labelself.full_path = os.path.join(self.data_path, self.label)self.images_name = os.listdir(self.full_path)def __getitem__(self, item):image_data = cv2.imread(os.path.join(self.full_path, self.images_name[item]))# BGR转换为RGBimage_data = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)return image_data, self.labeldef __len__(self):return len(self.images_name)if __name__ == '__main__':data_path = os.path.join(os.getcwd(), 'data')label = 'dog'# 创建dataset子类实例,用于读取图片dataset_instance = MyDataset(data_path, label)# 输出图片数量print(len(dataset_instance))image, label = dataset_instance[0]# 根据索引获取图片print(image.shape, label)writer = SummaryWriter(log_dir='transforms_logs')# 使用ToTensorto_tensor = transforms.ToTensor()image_tensor = to_tensor(image)writer.add_image(tag='dog', img_tensor=image_tensor, global_step=0)
3.2、Normalize
-
功能:对每一个通道(
channel
)分别根据其均值、标准差进行标准化(Normalize a tensor image with mean and standard deviation) -
输入:
Tensor
类型的图片(This transform does not support PIL Image) -
输出:标准化后的
Tensor
类型的图片,output[channel] = (input[channel] - mean[channel]) / std[channel]
# 使用Normalize
# 创建对象的时候给定均值、标准差
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
image_tensor = normalize(image_tensor)
writer.add_image(tag='dog', img_tensor=image_tensor, global_step=1)
3.3、Resize
-
功能:调整图片
H
、W
(Resize the input image to the given size)-
若
size
为序列(例如size=[500, 800]
),则调整后的图片H=500
,W=800
; -
若
size
为整数(size=500
),则根据H
、W
中较小的值确定调整后的尺寸- 例如
H=600
,W=1200
,则调整后的H=500
,调整后的W
为 1200 / 600 ∗ 500 = 1000 1200/600*500=1000 1200/600∗500=1000
- 例如
-
-
输入:
PIL Image
、Tensor
类型的图片 -
输出:与输入的类型相同
# 使用Resize
resize = transforms.Resize(size=(500, 1000))
image_tensor = resize(image_tensor)
writer.add_image(tag='dog', img_tensor=image_tensor, global_step=2)
3.4、RandomCrop
- 功能:对图片进行随机裁剪(Crop the given image at a random location)
- 若
size
为序列(例如size=[500, 800]
),则裁剪后的图片H=500
,W=800
; - 若
size
为整数(size=500
),则裁剪后的图片H=500
,W=500
; - 裁剪后的图片H、W均不大于原有图片的H、W,否则会报错
- 若
- 输入:
PIL Image
、Tensor
类型的图片 - 输出:与输入的类型相同
# 使用RandomCrop
random_crop = transforms.RandomCrop(size=(300, 800))
image_tensor = random_crop(image_tensor)
writer.add_image(tag='dog', img_tensor=image_tensor, global_step=3)
3.5、结果展示
从上到下分别是Normalize
、Resize
、RandomCrop
顺序执行后的结果
3.6、Compose
-
功能:指定一系列图片处理步骤,对图片进行流式处理(Composes several transforms together)
- 使用列表指定需要对图片进行的处理,
[transform_1, transform_2,...]
- 使用列表指定需要对图片进行的处理,
-
输入:
PIL Image
、ndarray
、Tensor
类型的图片 -
输出:由图片类型、指定的处理步骤决定
# 使用
compose = transforms.Compose([to_tensor, normalize, resize, random_crop])
image_tensor = compose(image_tensor)
writer.add_image(tag='dog_compose', img_tensor=image_tensor, global_step=0)
四、Dataloader
Dataloader
用于批量加载和处理数据,能数据集分成小批量,并在训练过程中按需加载这些小批量数据,以提高训练效率并节省内存。
- 批量加载数据:参数
batch_size
,每次加载batch_size
个数据,而不是一次性加载整个数据集; - 数据“洗牌”:参数
shuffle
,在每个训练周期开始时随机打乱数据顺序,防止模型过拟合; - 并行处理:参数
num_workers
,利用多个线程或进程加快数据加载过程;
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterif __name__ == '__main__':data_path = 'common_dataset'# 为了方便使用tensorboard进行展示,使用transform=transforms.ToTensor()将图片由PIL类型转换为Tensor类型train_data = datasets.CIFAR10(root=data_path, train=True, transform=transforms.ToTensor(), download=True)test_data = datasets.CIFAR10(root=data_path, train=False, transform=transforms.ToTensor(), download=True)# 打印数据集所包含的数据个数print(len(train_data), len(test_data))# 获取第一个数据的图片及标签(类别)img, label = train_data[0]# 打印类别索引及真实的类别print(label, train_data.classes[label])# img.show()writer = SummaryWriter(log_dir='CIFAR10_logs')# dataloader示例# drop_last=True可以舍弃最后的不足一批(batch_size)的图片data_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, drop_last=False)# 指定训练的最大epochmax_epoch = 1for epoch in range(max_epoch):i = 0for images, labels in data_loader:writer.add_images(tag=f'CIFAR10_{epoch}', img_tensor=images, global_step=i)i += 1writer.close()
tensorboard
中部分batch的图片如下: