pytorch基础模块:Tensorboard、Dataset、Transforms、Dataloader

Tensorboard、Dataset、Transforms、Dataloader

该文档主要参考【土堆】的视频教程:pytorch入门教程–土堆

一、Tensorboard

安装tensorboardpip install tensorboard

使用步骤

  • 引入相关库:from torch.utils.tensorboard import SummaryWriter
  • 构建SummaryWriter对象:writer = SummaryWriter(log_dir="logs")
    • 在工程目录下创建一个名为logs的文件夹,用于存放Tensorboard绘图所用的文件
  • 打开tensorboard
    • 命令行执行:tensorboard --logdir=logs
      • 如果有错误,使用logs的绝对地址
    • 点击链接,即可查看

[常用函数]

add_scalaradd_imagesadd_graph

1.1、add_scalar

功能:添加标量数据(例如记录训练epoch及对应的loss)

常用参数:

  • 标题:tag
  • 标量数据(Y轴):scalar_value
  • 计步数据(X轴):global_step
from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter(log_dir="logs")
for i in range(100):writer.add_scalar(tag="y=3x", scalar_value=3 * i, global_step=i)

注意:使用同一个SummaryWriter对象且tag相同时,会绘制在同一幅图上,为避免该情况可以删除logs中的内容,或者每次都新建文件夹(log_dir="logs_1"

图片如下(左图为只绘制一次的结果y=3x,右图为在同一幅图上分别绘制y=2x以及y=10x的结果)
在这里插入图片描述

在这里插入图片描述

1.2、add_image

功能:添加图片数据(例如记录每个batch的输入图片)

常用参数:

标题:tag

图片数据:img_tensor(一般是torch.Tensor或者numpy.array类型)

计步数据:global_step

图片格式:dataformats(CHW, HWC, HW, WH等,默认为'CHW')

from torch.utils.tensorboard import SummaryWriter
import os
import cv2project_path = os.getcwd()
file_name_1 = 'dog_1.png'
file_name_2 = 'dog_2.png'file_path = os.path.join(project_path, r'data\dog')
full_file_path_1 = os.path.join(file_path, file_name_1)
full_file_path_2 = os.path.join(file_path, file_name_2)
# 使用cv2(即OpenCV库)读取图片时,图片通常是以HWC(高度、宽度、通道)格式存储的,并且每个像素的颜色值(对于RGB图像)都是0到255之间的整数
image_data_1 = cv2.imread(full_file_path_1)
image_data_2 = cv2.imread(full_file_path_2)
# 将图片从BGR转换为RGB(因为cv2默认通道顺序为BGR,使用PIL读取图片的通道顺序为RGB):如果不进行调整,则图片颜色会失真
image_data_1 = cv2.cvtColor(image_data_1, cv2.COLOR_BGR2RGB)
image_data_2 = cv2.cvtColor(image_data_2, cv2.COLOR_BGR2RGB)writer = SummaryWriter(log_dir="logs")
writer.add_image(tag='dog', img_tensor=image_data_1, global_step=0, dataformats='HWC')
writer.add_image(tag='dog', img_tensor=image_data_2, global_step=1, dataformats='HWC')
writer.close()

图片如下(通过拖动进度条可以查看不同step对应的图片)
在这里插入图片描述
在这里插入图片描述

1.3、add_graph

功能:添加模型结构数据(例如记录神经网络的结构)

常用参数:

模型:model(可以构建自己的模型或者使用公开的经典模型,例如VGG-16

模型输入:input_to_model(要求图片数据是Tensor类型)

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
import cv2
from torchvision import transformsif __name__ == '__main__':# 使用内置的分类模型VGG-16(后续会更新如何搭建模型的文章)vgg_16 = torchvision.models.vgg16(progress=False)print(vgg_16)image = cv2.imread('data\\dog\\dog_1.png')# 图片类型转换为Tensor并调整尺寸为224*224(vgg16的标准输入尺寸)image = transforms.Compose([transforms.ToTensor(), transforms.Resize([224, 224])])(image)# 添加一个批次维度(模型通常期望输入具有批次维度),使得形状从(C, H, W)变为(1, C, H, W),其中1表示批次大小。image = torch.unsqueeze(image, 0)writer = SummaryWriter('model')# image提供模型在前向传播过程中所需的输入数据,TensorBoard据此生成模型的计算图writer.add_graph(model=vgg_16, input_to_model=image)writer.close()

在这里插入图片描述
在这里插入图片描述

二、Dataset

2.1、使用公开数据集

  • 常用的数据集:MNISTCIFAR10

  • 这些封装好的数据集都继承了torch.utils.data中的Dataset类,该类有两个重要的方法:getitem()len()

  • 可以通过参数transform以及target_transform在加载数据时进行实时的数据增强操作(如旋转、裁剪、缩放等);

    • 对图像数据的增强操作详见章节三Transforms
  • 可以通过继承Dataset类并重写getitem()len()方法创建自己的数据集类(使用自己的数据)

from torchvision import datasetsif __name__ == '__main__':# 指定数据集路径(下载好的数据集会自动解压到路径下)data_path = 'common_dataset'# train=True表示为训练集,download=True表示下载数据集(若已经下载好则自动加载本地数据集)# 若在线下载速度慢,可进入CIFAR10类中,直接通过数据集的下载链接下载(下载好放在data_path下即可)train_data = datasets.CIFAR10(root=data_path, train=True, download=True)test_data = datasets.CIFAR10(root=data_path, train=False, download=True)# 打印数据集所包含的数据个数print(len(train_data), len(test_data))# 获取第一个数据的图片(PIL Image类型)及标签(类别)img, label = train_data[0]# 打印类别索引及真实的类别print(label, train_data.classes[label])img.show()

2.2、使用自己的数据

from torch.utils.data import Dataset
import cv2
import osclass MyDataset(Dataset):def __init__(self, data_path, label):# super.__init__()self.data_path = data_pathself.label = labelself.full_path = os.path.join(self.data_path, self.label)self.images_name = os.listdir(self.full_path)def __getitem__(self, item):image_data = cv2.imread(os.path.join(self.full_path, self.images_name[item]))# BGR转换为RGB,不然会失真image_data = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)return image_data, self.labeldef __len__(self):return len(self.images_name)if __name__ == '__main__':data_path = os.path.join(os.getcwd(), 'data')label = 'dog'dataset_instance = MyDataset(data_path, label)print(len(dataset_instance))image, label = dataset_instance[0]print(image.shape, label)print(type(image))

三、Transforms

Transforms是用于处理图片的库,内置的类基本可以满足图片处理的需求,例如图片类型转换(PIL Imagendarraytensor)、尺寸调整、裁剪等

  • 若没有torchvision则需要先安装:pip install torchvision

  • [常用功能(类)]

    ToTensorNormalizeResizeRandomCropCompose

  • 使用方式:根据需求选择类,创建类的实例,使用类的实例完成图片处理

3.1、ToTensor

  • 功能:将PIL Imagendarray类型的图片转换为Tensor类型(Convert a PIL Image or ndarray to tensor and scale the values accordingly

  • 输入:PIL Imagendarray类型的图片(PIL Image or numpy.ndarray (H x W x C) in the range [0, 255]

  • 输出:Tensor类型,shapeCHW,每个元素均为[0.0, 1.0]之间的数(torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]

from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import cv2
import os# 创建Dataset的子类
class MyDataset(Dataset):def __init__(self, data_path, label):# super.__init__()self.data_path = data_pathself.label = labelself.full_path = os.path.join(self.data_path, self.label)self.images_name = os.listdir(self.full_path)def __getitem__(self, item):image_data = cv2.imread(os.path.join(self.full_path, self.images_name[item]))# BGR转换为RGBimage_data = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)return image_data, self.labeldef __len__(self):return len(self.images_name)if __name__ == '__main__':data_path = os.path.join(os.getcwd(), 'data')label = 'dog'# 创建dataset子类实例,用于读取图片dataset_instance = MyDataset(data_path, label)# 输出图片数量print(len(dataset_instance))image, label = dataset_instance[0]# 根据索引获取图片print(image.shape, label)writer = SummaryWriter(log_dir='transforms_logs')# 使用ToTensorto_tensor = transforms.ToTensor()image_tensor = to_tensor(image)writer.add_image(tag='dog', img_tensor=image_tensor, global_step=0)

3.2、Normalize

  • 功能:对每一个通道(channel)分别根据其均值、标准差进行标准化(Normalize a tensor image with mean and standard deviation

  • 输入:Tensor类型的图片(This transform does not support PIL Image

  • 输出:标准化后的Tensor类型的图片,output[channel] = (input[channel] - mean[channel]) / std[channel]

# 使用Normalize
# 创建对象的时候给定均值、标准差
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
image_tensor = normalize(image_tensor)
writer.add_image(tag='dog', img_tensor=image_tensor, global_step=1)

3.3、Resize

  • 功能:调整图片HWResize the input image to the given size

    • size为序列(例如size=[500, 800]),则调整后的图片H=500W=800

    • size为整数(size=500),则根据HW中较小的值确定调整后的尺寸

      • 例如H=600W=1200,则调整后的H=500,调整后的W 1200 / 600 ∗ 500 = 1000 1200/600*500=1000 1200/600500=1000
  • 输入:PIL ImageTensor类型的图片

  • 输出:与输入的类型相同

# 使用Resize
resize = transforms.Resize(size=(500, 1000))
image_tensor = resize(image_tensor)
writer.add_image(tag='dog', img_tensor=image_tensor, global_step=2)

3.4、RandomCrop

  • 功能:对图片进行随机裁剪(Crop the given image at a random location
    • size为序列(例如size=[500, 800]),则裁剪后的图片H=500W=800
    • size为整数(size=500),则裁剪后的图片H=500W=500
    • 裁剪后的图片H、W均不大于原有图片的H、W,否则会报错
  • 输入:PIL ImageTensor类型的图片
  • 输出:与输入的类型相同
# 使用RandomCrop
random_crop = transforms.RandomCrop(size=(300, 800))
image_tensor = random_crop(image_tensor)
writer.add_image(tag='dog', img_tensor=image_tensor, global_step=3)

3.5、结果展示

从上到下分别是NormalizeResizeRandomCrop顺序执行后的结果
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.6、Compose

  • 功能:指定一系列图片处理步骤,对图片进行流式处理(Composes several transforms together

    • 使用列表指定需要对图片进行的处理,[transform_1, transform_2,...]
  • 输入:PIL ImagendarrayTensor类型的图片

  • 输出:由图片类型、指定的处理步骤决定

# 使用
compose = transforms.Compose([to_tensor, normalize, resize, random_crop])
image_tensor = compose(image_tensor)
writer.add_image(tag='dog_compose', img_tensor=image_tensor, global_step=0)

四、Dataloader

Dataloader用于批量加载和处理数据,能数据集分成小批量,并在训练过程中按需加载这些小批量数据,以提高训练效率并节省内存。

  • 批量加载数据:参数batch_size,每次加载batch_size个数据,而不是一次性加载整个数据集;
  • 数据“洗牌”:参数shuffle,在每个训练周期开始时随机打乱数据顺序,防止模型过拟合;
  • 并行处理:参数num_workers,利用多个线程或进程加快数据加载过程;
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterif __name__ == '__main__':data_path = 'common_dataset'# 为了方便使用tensorboard进行展示,使用transform=transforms.ToTensor()将图片由PIL类型转换为Tensor类型train_data = datasets.CIFAR10(root=data_path, train=True, transform=transforms.ToTensor(), download=True)test_data = datasets.CIFAR10(root=data_path, train=False, transform=transforms.ToTensor(), download=True)# 打印数据集所包含的数据个数print(len(train_data), len(test_data))# 获取第一个数据的图片及标签(类别)img, label = train_data[0]# 打印类别索引及真实的类别print(label, train_data.classes[label])# img.show()writer = SummaryWriter(log_dir='CIFAR10_logs')# dataloader示例# drop_last=True可以舍弃最后的不足一批(batch_size)的图片data_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, drop_last=False)# 指定训练的最大epochmax_epoch = 1for epoch in range(max_epoch):i = 0for images, labels in data_loader:writer.add_images(tag=f'CIFAR10_{epoch}', img_tensor=images, global_step=i)i += 1writer.close()

tensorboard中部分batch的图片如下:
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/390776.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LinkedList接口源码解读

LinkedList 接口源码解读(一) 前言 因为追求质量,所以写的较慢。大概在接下来的三天内会把LinkedList源码解析出完。大概还有两篇文章。废话不多说,正片开始! 大家都知道,LinkedList是在Java底层中是由 …

手机上音乐如何转换成MP3格式?分享5款音频格式转换APP

手机上音乐如何转换成MP3格式?相信很多外出办公或者不经常使用电脑的工作人士,学生党,媒体从业者都有这样的疑惑和需求。不同设备和应用可能支持不同的音频格式,导致某些情况下需要将音乐文件转换为MP3格式以确保兼容性。下面&…

操作系统|day4.Linux、Linux内核、Linux负载、Linux文件存储

文章目录 LinuxLinux内核定义功能态 Linux负载定义 Linux文件存储链接分类区别使用场景 拷贝 Linux Linux内核 定义 内核是操作系统的核心,具有很多最基本功能,它负责管理系统的进程、内存、设备驱动程序、文件和网络系统,决定着系统的性能…

.NET周刊【7月第4期 2024-07-28】

国内文章 .NET 高性能缓冲队列实现 BufferQueue https://mp.weixin.qq.com/s/fUhJpyPqwcmb3whuV3CDyg BufferQueue 是一个用 .NET 编写的高性能的缓冲队列实现,支持多线程并发操作。 项目地址:https://github.com/eventhorizon-cli/BufferQueue 项目…

【虚拟仿真】Unity3D中实现2DUI显示在3D物体旁边

推荐阅读 CSDN主页GitHub开源地址Unity3D插件分享简书地址QQ群:398291828大家好,我是佛系工程师☆恬静的小魔龙☆,不定时更新Unity开发技巧,觉得有用记得一键三连哦。 一、前言 这篇文章来实现2DUI显示在3D物体旁边,当我们需要在3D模型旁边显示2DUI的时候,比如人物的对…

grep工具的使用

grep [options]…… pattern [file]…… 工作方式: grep 在一个或者多个文件中搜索字符串模板,如果模板中包括空格,需要使用引号引起来,模 板后的所有字符串会被看作是文件名。 工作结果:如果模板搜索成功&#xf…

Springboot+Vue在线水果(销售)商城管理系统-附源码与配套论文

1.1 研究背景 互联网概念的产生到如今的蓬勃发展,用了短短的几十年时间就风靡全球,使得全球各个行业都进行了互联网的改造升级,标志着互联网浪潮的来临。在这个新的时代,各行各业都充分考虑互联网是否能与本行业进行结合&#xf…

Java:进程和线程

文章目录 进程线程的概念和区别总结如何创建线程1.继承Thread重写run2.实现Runnable重写run3.继承Thread重写run,通过匿名内部类来实现4. 实现Runnable重写run,通过匿名内部类来实现5.基于lambda表达式来创建 虚拟线程 并发编程: 通过写特殊的代码,把多个CPU核心都利…

Shell编程 --基础语法(1)

文章目录 Shell编程基础语法变量定义变量使用变量命令的使用只读变量删除变量 传递参数字符串获取字符串长度字符串截取 数组定义方式关联数组获取数组的长度 总结 Shell编程 Shell是一种程序设计语言。作为命令语言,它交互式解释和执行用户输入的命令或者自动地解…

【人工智能基础三】卷积神经网络基础(CNN)

文章目录 1. 卷积神经网络结构2. 卷积神经网络计算2.1. 卷积层计算2.2. 池化层计算2.3. 全连接层计算 3. 典型卷积神经网络3.1. AlexNet3.2. VGGnet 卷积神经网络(Convolutional Neural Network,CNN)是一类包含卷积计算且具有深度结构的前馈神经网络(Feedforward Ne…

计算机毕业设计Python+Tensorflow股票推荐系统 股票预测系统 股票可视化 股票数据分析 量化交易系统 股票爬虫 股票K线图 大数据毕业设计 AI

1、用pycharm打开项目,一定要打开包含manage.py文件所在文件夹 2、配置解释器:建议使用Anaconda(Python 3.8(base)),低于3.8版本的,页面会不兼容 3、安装依赖库:打开pycharm的终端,输入: pip in…

Docker-学习笔记(借助宝塔面板)

ubuntu环境 一、安装 可以参考官网进行或其他博客进行安装 1.进入宝塔面板 进图Docker菜单,查看是否提示安装。 2.查看是否安装 查看版本 docker -v 证明已经安装 二、常用命令 1.查看版本 docker -v 2.启动、停止、重启docker systemctl start docker…

自制安卓车机软件(含APP)

本软件使用APPinventor2编程软件,耗时5天和3天调试,具有高德导航,视频播放,网易云音乐,酷狗,抖音,(需下载车机版软件)和自定义添加软件,网页有哔哩哔哩&#…

无人机工程师技术高级证书详解

随着无人机技术的飞速发展,其在航拍、农业、测绘、救援、物流等多个领域的应用日益广泛,对无人机工程师的专业技能与综合素质提出了更高要求。无人机工程师技术高级证书,作为对无人机领域高级工程师专业技能的权威认证,不仅是对个…

简单搭建dns服务器

目录 一.安装服务 二.编写子配置文件 三.编写主配置文件 四.编写文件 五.测试 一.安装服务 [rootnode1 ~]# dnf install bind -y 二.编写子配置文件 [rootnode1 ~]# vim /etc/named.rfc1912.zones 三.编写主配置文件 [rootnode1 ~]# vim /etc/named.conf 四.编写文件 …

【Python】Numpy概述安装及使用

文章目录 Numpy概述Numpy开发环境搭建Numpy使用创建数组创建一维数组创建二维数组创建三维数组,array()函数ndmin参数的使用array()函数dtype参数的使用随机数创建 Numpy概述 Numpy是科学计算基础库,提供大量科学计算相关功能,比如数据统计&…

GuLi商城-商品服务-API-新增商品-调试会员等级相关接口

在网关服务中配置路由: 代码: nacos这些服务都要启动: 如果有不是一个命名空间中的,要改成同一个命名空间中 启动商品product服务遇到循环依赖问题,解决:

Leetcode 第 135 场双周赛题解

Leetcode 第 135 场双周赛题解 Leetcode 第 135 场双周赛题解题目1:3222. 求出硬币游戏的赢家思路代码复杂度分析 题目2:3223. 操作后字符串的最短长度思路代码复杂度分析 题目3:3224. 使差值相等的最少数组改动次数思路代码复杂度分析 题目4…

classical Chinese

classical Chinese 中型娃娃暑假作业背诵 文言文《伯牙鼓琴》 1)拿到文言文,先看一遍 2)用白话文(现代文)翻译一次 3)用白话文对照回去文言文(白话文中那些需要替换回文言文呢) 虽…

神奇海洋养鱼小程序游戏广告联盟流量主休闲小游戏源码

在海洋养鱼小程序中,饲料、任务系统、系统操作日志、签到、看广告、完成喂养、每日签到、系统公告、积分商城、界面设计、拼手气大转盘抽奖以及我的好友等功能共同构建了一个丰富而互动的游戏体验。以下是对这些功能的进一步扩展介绍: 饲料 任务奖励&a…