Pytorch公共数据集、tensorboard、DataLoader使用

本文将主要介绍torchvision.datasets的使用,并以CIFAR-10为例进行介绍,对可视化工具tensorboard进行介绍,包括安装,使用,可视化过程等,最后介绍DataLoader的使用。希望对你有帮助

Pytorch公共数据集

torchvision.datasets.*
在这里插入图片描述
torchvision是pytorch的一个图形库,torchvision包由流行的数据集、模型架构和计算机视觉的通用图像转换组成。例如tensorboard、transfroms

在这里将主要介绍torchvision.datasets.*

在这里插入图片描述

在datasets中包含了许多公共的应用于图像领域的数据集。包含:图像分类、图像检测或分割、光流法、立体声匹配等

在本章当中,将以图像分类领域的CIFAR10数据集作为torchvision.datasets的例子进行介绍,因为他比较小,下载比较快。

CIFAR-10是一个更接近普适物体的彩色图像数据集。CIFAR-10 是由Hinton 的学生Alex Krizhevsky 和Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含10 个类别的RGB 彩色图片。

每个图片的尺寸为32 × 32 ,每个类别有6000个图像,数据集中一共有50000 张训练图片和10000 张测试图片。

下面是数据集中的类,以及每个类的10张随机图像

在这里插入图片描述

参数介绍

这些数据集的参数也是大同小异,由于CIFAR10数据集较小,下载就快。大家可以触类旁通

在这里插入图片描述

  • root :即指定数据集要下载在哪一个文件夹里面
  • train(bool):如果True即为训练集,否则False则为测试集
  • transform :进行图像变换的各种操作,如Resize、RandomCrop、Compose
  • target_transform :对于标签进行transform 操作
  • download :是否下载数据集,建议设置为True即可
import torch
import torchvision
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
#transform属性
trans_tool = torchvision.transforms.Compose([torchvision.transforms.ToTensor()  # 转为Tensor类型# torchvision.transforms.Resize((5, 5))  # 进行大小裁剪
])# 数据集划分
tran_dataset = torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=trans_tool,download=True)
test_dataset = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=trans_tool,download=True)
print(tran_dataset[0])  
#Tensorboard
writer = SummaryWriter("logs")
for i in range(10):#显示测试集前10的图片img, label = tran_dataset[i]writer.add_image("CIFAR10",img,i)
writer.close()
Files already downloaded and verified
Files already downloaded and verified
(tensor([[[0.2314, 0.1686, 0.1961,  ..., 0.6196, 0.5961, 0.5804],[0.0627, 0.0000, 0.0706,  ..., 0.4824, 0.4667, 0.4784],[0.0980, 0.0627, 0.1922,  ..., 0.4627, 0.4706, 0.4275],...,[0.8157, 0.7882, 0.7765,  ..., 0.6275, 0.2196, 0.2078],[0.7059, 0.6784, 0.7294,  ..., 0.7216, 0.3804, 0.3255],[0.6941, 0.6588, 0.7020,  ..., 0.8471, 0.5922, 0.4824]],[[0.2431, 0.1804, 0.1882,  ..., 0.5176, 0.4902, 0.4863],[0.0784, 0.0000, 0.0314,  ..., 0.3451, 0.3255, 0.3412],[0.0941, 0.0275, 0.1059,  ..., 0.3294, 0.3294, 0.2863],...,[0.6667, 0.6000, 0.6314,  ..., 0.5216, 0.1216, 0.1333],[0.5451, 0.4824, 0.5647,  ..., 0.5804, 0.2431, 0.2078],[0.5647, 0.5059, 0.5569,  ..., 0.7216, 0.4627, 0.3608]],[[0.2471, 0.1765, 0.1686,  ..., 0.4235, 0.4000, 0.4039],[0.0784, 0.0000, 0.0000,  ..., 0.2157, 0.1961, 0.2235],[0.0824, 0.0000, 0.0314,  ..., 0.1961, 0.1961, 0.1647],...,[0.3765, 0.1333, 0.1020,  ..., 0.2745, 0.0275, 0.0784],[0.3765, 0.1647, 0.1176,  ..., 0.3686, 0.1333, 0.1333],[0.4549, 0.3686, 0.3412,  ..., 0.5490, 0.3294, 0.2824]]]), 6)

利用tensorboard查看,在控制台输入即可:

tensorboard --logdir 目录

在这里插入图片描述

关于torchvision.datasets.CIFAR10介绍已经讲解完毕,后续内容为扩展内容,包括:tensorboard、DataLoader的使用

tensorboard可视化工具

torch.utils.tensorboard

在Pytorch发布后,网络及训练过程的可视化工具也相应地被开发出来,方便用户监督所建立模型的结构和训练过程

深度学习网络通常具有很深的层次结构,而且层与层之间通常会有并联、串联等连接方式,利用有效的工具将建立的深度学习网络结构有层次化的展示,这就需要使用相关的深度学习网络结构可视化库。

从Pytorch1.1之后,加入了tensorboard

一般安装新版的pytorch会自动安装,如果没安装,则在终端命令行下使用下面命令即可安装

pip install tensorboard
  • add_image()添加图片

  • add_scalar()添加标量数据

主要代码如下

from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter("logs")  # 创建SummaryWriter,将运行结果存logs文件夹中
for i in range(100):writer.add_scalar("y=2x",2*i,i)  # 第一个参数相当于标题,第二个参数就相当于纵坐标的值,第三个参数就相当于横坐标的值
writer.close()

可视化操作

在终端输入:tensorboard --logdir 目录
在这里插入图片描述

访问:http://localhost:6006即可

在这里插入图片描述

writer.add_image的例子

from PIL import Image
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriterimg_path = r"./pic.png"
# 打开一张图片
img = Image.open(img_path)
# 使用transforms对图像进行变换
# 实例化totensor对象
to_tens = transforms.ToTensor()
# 将pic变成Tensor类型的图片
tens_img = to_tens(img) # 自动调用call函数
#print(tens_img)# 使用上一篇文章中tensorboard进行查看
writer = SummaryWriter("transforms_logs")
writer.add_image("test_transforms",tens_img) # 标题,图像类型
writer.close()

DataLoader的使用

from torch.utils.data import DataLoader

torch的DataLoader主要是用来装载数据,就是给定已知的数据集,把数据集装载进DataLoaer,然后送入深度学习网络进行训练。

在torch.utils.data.DataLoader()参数中,只有dataset为必填参数,其他的均有默认值,下文介绍几个重要的参数

在这里插入图片描述

  • dataset:表示要读取的数据集

  • batch_size:表示每次从数据集中取多少个数据

  • shuffle:表示是否为乱序取出,True表示前后不一样

  • num_workers :表示是否多进程读取数据(默认为0);

  • drop_last : 表示当样本数不能被batchsize整除时(即总数据集/batch_size 不能除尽,有余数时),最后一批数据(余数)是否舍弃(default:
    False)

  • pin_memory: 如果为True会将数据放置到GPU上去(默认为false)

还是以上文的CIFAR10的测试集为例

from torch.utils.data import DataLoader
import torchvision
test_set = torchvision.datasets.CIFAR10("dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
Files already downloaded and verified
# 创建DataLoader实例
test_loader = DataLoader(dataset=test_set, # 引入数据集batch_size=4, # 每次取4个数据shuffle=True, # 打乱顺序num_workers=0, # 非多进程drop_last=False # 最后数据(余数)不舍弃
)

利用DataLoader的完整代码如下

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter# 准备测试集
test_set = torchvision.datasets.CIFAR10("dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)# 创建test_loader实例
test_loader = DataLoader(dataset=test_set, # 引入数据集batch_size=4, # 每次取4个数据shuffle=True, # 打乱顺序num_workers=0, # 非多进程drop_last=False # 最后数据(余数)不舍弃
)img,index = test_set[0]
print(img.shape) # 查看图片大小 torch.Size([3, 32, 32]) C h w,即三通道 32*32
print(index) # 查看图片标签
# 遍历test_loader
for data in test_loader:img,target = dataprint(img.shape) # 查看图片信息torch.Size([4, 3, 32, 32])表示一次4张图片,图片为3通道RGB,大小为32*32print(target)  # tensor([4, 9, 8, 8])表示4张图片的target
# 在tensorboard 中显示
writer = SummaryWriter("logs")
step = 0
for data in test_loader:img, target = datawriter.add_images("test_loader",img,step)step = step+1
writer.close()

tensorboard显示如下

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/170604.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vscode安装可以打开docx文件的插件

文章目录 vscode安装可以打开docx文件的插件去插件商城搜索并安装安装后打开一个word vscode安装可以打开docx文件的插件 去插件商城搜索并安装 安装后 打开一个word

脏牛提权 liunx

使用方法 Liunx 普通用户 内核版本 在版本里 我直接脏牛提权 有脚本查看内核版本 上传c脚本 编译 直接执行 获取高权限 提权 Liunx https://github.com/InteliSecureLabs/Linux Exploit Suggester 运行这个脚本 上传到客户端 https://github…

Unity Hub报错:No valid Unity Editor license found. Please activate your license.

最近 遇到一个问题,打开高版本时Hub抛出异常:No valid Unity Editor license found. Please activate your license. 首先你必须排除是否登录Unity Hub,并且激活许可证。 方法一:禁用网络(这个可能无效) …

如何通过卖虚拟资料月入10万?看这几个卖资料案例

我微信好友里,有近4000个是做创业博主的同行。 你可能会好奇,其中60%的人都通过卖虚拟资料起家,这到底说明了什么呢? 嗯,事实上,这就意味着这些人选择了网络赚钱的首选项目,那就是销售各种资料…

python selenium如何带cookie访问网站

python selenium如何带cookie访问网站 要使用Python的Selenium库带有cookie访问网站,你可以按照以下步骤进行操作: 一、流程介绍 安装Selenium库(如果尚未安装): pip install selenium导入Selenium库并启动一个浏览…

导致爬虫无法使用的原因有哪些?

随着互联网的普及和发展,爬虫技术也越来越多地被应用到各个领域。然而,在实际使用中,爬虫可能会遇到各种问题导致无法正常工作。本文将探讨导致爬虫无法使用的原因,并给出相应的解决方法。 一、目标网站反爬虫机制 许多网站为了…

11 个最值得推荐的 Windows 数据恢复软件

您可能已经尝试过许多免费的恢复程序,但它们都不起作用,对吧?这就是您正在寻找最好的数据恢复软件的原因。 个人去过那里。根据个人的经验,大多数免费软件并不能解决这个问题。有时,当个人在 PC 上运行恢复程序时&…

【API篇】七、Flink窗口

文章目录 1、窗口2、分类3、窗口API概览4、窗口分配器 在批处理统计中,可以等待一批数据都到齐后,统一处理。但是在无界流的实时处理统计中,是来一条就得处理一条,那么如何统计最近一段时间内的数据呢? ⇒ 窗口的概念&…

选择合适的项目管理系统来支持专业产品研发团队

专业产品研发团队的公司离不开其严谨的管理和高效的研发流程,为了进一步提升研发效率和管理水平,产研团队需要一个全流程的项目管理系统来支持其研发团队的协同合作。 一、系统需求 IT行业的研发工作涵盖了从立项、项目变更到项目的进程计划等多个环节。…

B-tree(PostgreSQL 14 Internals翻译版)

概览 B树(作为B树访问方法实现)是一种数据结构,它使您能够通过从树的根向下查找树的叶节点中所需的元素。为了明确地标识搜索路径,必须对所有树元素进行排序。B树是为有序数据类型设计的,这些数据类型的值可以进行比较和排序。 下面的机场代…

OpenCV视频车流量识别详解与实践

视频车流量识别基本思想是使用背景消去算法将运动物体从图片中提取出来,消除噪声识别运动物体轮廓,最后,在固定区域统计筛选出来符合条件的轮廓。 基于统计背景模型的视频运动目标检测技术: 背景获取:需要在场景存在…

基于PHP的线上购物商城,MySQL数据库,PHPstudy,原生PHP,前台用户+后台管理,完美运行,有一万五千字论文。

目录 演示视频 基本介绍 论文截图 功能结构 系统截图 演示视频 基本介绍 基于PHP的线上购物商城,MySQL数据库,PHPstudy,原生PHP,前台用户后台管理,完美运行,有一万五千字论文。 现如今,购物网站是商业…

【七】SpringBoot为什么可以打成 jar包启动

SpringBoot为什么可以打成 jar包启动 简介:庆幸的是夜跑的习惯一直都在坚持,正如现在坚持写博客一样。最开始刚接触springboot的时候就觉得很神奇,当时也去研究了一番,今晚夜跑又想起来了这茬事,于是想着应该可以记录一…

Python单元测试

import unittest #必须要导入单元测试的包class Student(object):def __init__(self, name, score):self.name nameself.score scoredef get_grade(self):if self.score > 100:#返回错误不能用return,应该用raise raise ValueError("成绩不能大于100"…

MySQL进阶(日志)——MySQL的日志 bin log (归档日志) 事务日志redo log(重做日志) undo log(回滚日志)

前言 MySQL最为最流行的开源数据库,其重要性不言而喻,也是大多数程序员接触的第一款数据库,深入认识和理解MySQL也比较重要。 本篇博客阐述MySQL的日志,介绍重要的bin log (归档日志) 、 事务日志redo log(重做日志) 、 undo lo…

【iOS逆向与安全】某音App直播间自动发666 和 懒人自动看视频

1.目标 由于看直播的时候主播叫我发 666,支持他,我肯定支持他呀,就一直发,可是后来发现太浪费时间了,能不能做一个直播间自动发 666 呢?于是就花了几分钟做了一个。 2.操作环境 越狱iPhone一台 frida m…

Mybatis应用场景之动态传参、两字段查询、用户存在性的判断

目录 一、动态传参 1、场景描述 2、实现过程 3、代码测试 二、两字段查询 1、场景描述 2、实现过程 3、代码测试 4、注意点 三、用户存在性的判断 1、场景描述 2、实现过程 3、代码测试 一、动态传参 1、场景描述 在进行数据库查询的时候,需要动态传入…

Linux友人帐之日志与备份

一、日志 1.1概述 日志文件是重要的系统信息文件,其中记录了许多重要的系统事件,包括用户的登录信息、系统的启动信息、系统的安全信息、邮件相关信息、各种服务相关信息等。日志对于安全来说也很重要,它记录了系统每天发生的各种事情&#…

openGauss学习笔记-108 openGauss 数据库管理-管理用户及权限-用户

文章目录 openGauss学习笔记-108 openGauss 数据库管理-管理用户及权限-用户108.1 创建、修改和删除用户108.2 私有用户108.3 永久用户108.4 用户认证优先规则 openGauss学习笔记-108 openGauss 数据库管理-管理用户及权限-用户 使用CREATE USER和ALTER USER可以创建和管理数据…

期中考核复现

web 1z_php ?0o0[]1A&OoO[]2023a include "flag.php":尝试包含名为 "flag.php" 的文件。这意味着它会尝试引入一个名为 "flag.php" 的脚本文件,其中可能包含一些敏感信息或标志。 error_reporting(0):…