【PyTorch简介】3.Loading and normalizing datasets 加载和规范化数据集

Loading and normalizing datasets 加载和规范化数据集

文章目录

  • Loading and normalizing datasets 加载和规范化数据集
  • Datasets & DataLoaders 数据集和数据加载器
  • Loading a Dataset 加载数据集
  • Iterating and Visualizing the Dataset 迭代和可视化数据集
  • Creating a Custom Dataset for your files 为您的文件创建自定义数据集
    • \__init__
    • \__len__
    • \__getitem__
  • Preparing your data for training with DataLoaders 使用 DataLoaders 准备数据以进行训练
  • Iterate through the DataLoader 遍历 DataLoader
  • Normalization 正则化
    • Transforms 转换
    • ToTensor()
    • Lambda transforms
  • 知识检查
  • Further Reading 进一步阅读
  • References 参考文献
  • Github

Datasets & DataLoaders 数据集和数据加载器

用于处理数据样本的代码可能会变得混乱且难以维护;理想情况下,我们希望数据集代码与模型训练代码分离,以获得更好的可读性和模块化性。PyTorch 提供了两个数据原语:torch.utils.data.DataLoadertorch.utils.data.Dataset ,允许您使用预加载的数据集以及您自己的数据。 Dataset存储样本及其相应的标签,并DataLoader围绕 Dataset进行迭代,以方便访问样本。

PyTorch 域库提供了许多预加载的数据集(例如 FashionMNIST)。这些数据集是torch.utils.data.Dataset的子类。并且,对于特定数据,实现特定的函数。它们可用于对您的模型进行原型设计和基准测试。您可以在这里找到它们:图像数据集、 文本数据集和 音频数据集

Loading a Dataset 加载数据集

以下是如何从 TorchVision 加载Fashion-MNIST数据集的示例。Fashion-MNIST 是 Zalando 论文的图像数据集。这个数据集由 60,000 个训练样本和 10,000 个测试样本组成。每个样本包含一个 28×28 灰度图像和来自 10 个类别之一的关联标签。

  • 每张图像的高度为 28 像素,宽度为 28 像素,总共 784 像素。
  • 这 10 个类别表示图像的类型,例如:T 恤/上衣、裤子、套头衫、连衣裙、包、踝靴等.
  • 灰度像素的值介于 0 到 255 之间,用于测量黑白图像的强度。强度值从白色增加到黑色。例如:白色为 0,黑色为 255。

我们使用以下参数,来加载FashionMNIST Dataset:

  • root 是存储训练/测试数据的路径,

  • train 指定训练或测试数据集,

  • download=True 如果root 上没有数据,则从 Internet 下载数据。

  • transformtarget_transform 指定特征和标签的转换。

%matplotlib inline
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)

Iterating and Visualizing the Dataset 迭代和可视化数据集

我们可以像列表一样手动索引Datasetstraining_data[index]。我们用matplotlib来可视化训练数据中的一些样本。

labels_map = {0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(training_data), size=(1,)).item()img, label = training_data[sample_idx]figure.add_subplot(rows, cols, i)plt.title(labels_map[label])plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

Out:
在这里插入图片描述

Creating a Custom Dataset for your files 为您的文件创建自定义数据集

自定义 Dataset 类必须实现三个函数:__init____len____getitem__。看看这个实现:FashionMNIST 图像存储在目录img_dir中,它们的标签单独存储在CSV 文件annotations_file中。

在接下来的部分中,我们将详细介绍每个函数中实现的功能。

import os
import pandas as pd
from torchvision.io import read_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

_init_

__init__ 函数在实例化 Dataset 对象时运行一次。我们初始化包含图像、注释文件和两种转换的目录(下一节将更详细地介绍)。

tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transform

_len_

__len__ 函数返回数据集中的样本数。

例子:

def __len__(self):return len(self.img_labels)

_getitem_

__getitem__ 函数从数据集加载并返回给定索引idx的的样本。基于索引,它识别图像在磁盘上的位置,使用read_image 将其转换为张量,从 self.img_labels中的 csv 数据中检索相应的标签,调用它们的转换函数(如果适用),并以元组方式返回张量图像和相应的标签。

def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

Preparing your data for training with DataLoaders 使用 DataLoaders 准备数据以进行训练

Dataset检索我们的数据集中一个样本的特征和标签。在训练模型时,我们通常希望以 “minibatches”方式传递样本,在每个epcoch重新整理数据以减少模型过度拟合,并使用 Python的multiprocessing来加速数据检索。

在机器学习中,您需要指定数据集中的特征和标签。输入特征,输出标签。我们训练特征,然后训练模型来预测标签。

  • 特征是图像像素中的图案
  • 标签是我们的 10 类类型:T 恤、凉鞋、连衣裙等

DataLoader是一个可迭代对象,它通过一个简单的 API 为我们抽象了这种复杂性。要使用 Dataloader,我们需要设置以下参数:

  • data 将用于训练模型的训练数据,以及评估模型的测试数据
  • batch size 每批中要处理的记录数
  • shuffle 按索引随机抽取数据样本
from torch.utils.data import DataLoadertrain_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

Iterate through the DataLoader 遍历 DataLoader

我们已将该数据集加载到 DataLoader 中,并且可以根据需要迭代数据集。下面的每次迭代都会返回一批train_featurestrain_labelsbatch_size=64分别包含特征和标签)。因为我们指定了shuffle=True,所以在迭代所有批次后,数据将被打乱(为了更细粒度地控制数据加载顺序,请查看Samplers)。

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
label_name = list(labels_map.values())[label]
print(f"Label: {label_name}")

在这里插入图片描述

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: Ankle Boot

Normalization 正则化

正则化是一种常见的数据预处理技术,用于缩放或转换数据,以确保每个特征的学习贡献相等。例如,灰度图像中的每个像素的值在0到255之间,这是特征。如果一个像素值为17,另一个像素为197。就会出现像素重要性分布不均匀的情况,因为较高的像素量会使学习发生偏差。正则化会改变数据的范围,而不会扭曲其特征之间的区别。进行这种预处理是为了避免:

  • 预测精度降低
  • 模型学习困难
  • 特征数据范围的不利分布

Transforms 转换

数据并不总是以训练机器学习算法所需的最终处理形式出现。我们使用transforms来操作数据并使其适合训练。

所有 TorchVision 数据集都有两个参数(transform 用于修改特征,target_transform 用于修改标签),它们接受包含转换逻辑的可调用对象。 torchvision.transforms 模块提供了几种开箱即用的常用转换。

FashionMNIST特征采用PIL图像格式,标签为整数。对于训练,我们需要将特征作为归一化张量,将标签作为单热编码张量。为了进行这些转换,我们将使用 ToTensorLambda

from torchvision import datasets
from torchvision.transforms import ToTensor, Lambdads = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

ToTensor()

ToTensor 将 PIL 图像或 NumPy ndarray 转换为 FloatTensor 并将图像的像素强度值缩放到 [0., 1.]范围。

Lambda transforms

Lambda transforms 应用任何用户定义的 lambda 函数。在这里,我们定义一个函数将整数转换为 one-hot 编码张量。它首先创建一个大小为 10(数据集中的标签数量)的零张量,并调用 scatter,它在标签 y 给定的索引上分配 value=1。您还可以使用 torch.nn.function.one_hot 作为另一个选项来执行此操作。

知识检查

1.PyTorch DataSet 和 PyTorch DataLoader 之间有什么区别

DataSet 按设计用于检索单个数据项,而 DataLoader 按设计用于处理批量数据。

2.PyTorch 中的转换旨在:

对数据执行某些操作,使其适用于训练。

Further Reading 进一步阅读

  • torch.utils.data API

References 参考文献

使用 PyTorch 进行机器学习的简介 - Training | Microsoft Learn

使用 PyTorch 进行机器学习的简介 - Training | Microsoft Learn

Github

storm-ice/PyTorch_Fundamentals

storm-ice/PyTorch_Fundamentals

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/238063.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

memory泄露分析方法(java篇)

#memory泄露主要分为java和native 2种,本文主要介绍java# 测试每天从monkey中筛选出内存超标的app,提单流转到我 首先,辨别内存泄露类型(java,还是native) 从采到的dumpsys_meminfo_pid看java heap&…

网络技术基础入门全套实验-厦门微思网络CCNA实验手册

知识改变命运,技术就是要分享,有问题随时联系,免费答疑,欢迎联系! 微思简介(https://www.xmws.cn) 微思成立于2002年,是一个诚信敬业、积极向上、充满活力、专注技术服务的企业。 微思获得了八…

LeetCode讲解篇之2280. 表示一个折线图的最少线段数

文章目录 题目描述题解思路题解代码 题目描述 题解思路 折线图中如果连续的线段共线,那么我们可以可以将其合并成一条线段 首先将坐标点按照横坐标升序排序 然后遍历数组 我们可以通过计算前一个线段的斜率和当前线段的斜率来判断是否共线 如果二者相等&#x…

Open3D 两片点云的最小/最大距离(23)

Open3D 两片点云的最小/最大距离(23) 一、效果展示二、使用步骤1.代码三、cloudcompare量距小工具一、效果展示 算法与实际量测的结果保持一致,输出最近距离和对应点 二、使用步骤 1.代码 import open3d as o3d import numpy as np# 读取点云数据 cloud_2 = o3d.io.re…

硬盘无法写入文件的解决方法 在Mac中的特殊符号如何打 tuxera ntfs for Mac 磁盘读写工具

今天将为大家介绍一下怎么在Mac中输入特殊符号,希望能够给大家带来帮助。 图:Mac中的特殊符号 苹果符号 按下ShiftOptionK就可以插入Apple logo了,不过要注意的是,在Windows可能直接显示为一个框框,而Linux系统则有可…

uni微信小程序强制用户更新版本

强制更新的代码参考官方文档 uni.getUpdateManager() | uni-app官网 我这边的如下: //检查版本更新const updateManager uni.getUpdateManager();updateManager.onCheckForUpdate(function (res) {// 请求完新版本信息的回调console.log(res.hasUpdate, "是…

基于java的SSM框架实现在线投稿网站系统项目【项目源码+论文说明】计算机毕业设计

基于java的SSM框架Vue实现在线投稿网站系统演示 摘要 随着计算机技术的飞速发展,稿件也已进入信息化时代。为了使稿件管理更高效、更科学,决定开发投稿审稿系统。 本文采用自顶向下的结构化的系统分析方法,阐述了一个功能全面的投稿审稿系统…

uniapp微信小程序投票系统实战 (SpringBoot2+vue3.2+element plus ) -全局异常统一处理实现

锋哥原创的uniapp微信小程序投票系统实战: uniapp微信小程序投票系统实战课程 (SpringBoot2vue3.2element plus ) ( 火爆连载更新中... )_哔哩哔哩_bilibiliuniapp微信小程序投票系统实战课程 (SpringBoot2vue3.2element plus ) ( 火爆连载更新中... )共计21条视频…

GPT实战系列-简单聊聊LangChain搭建本地知识库准备

GPT实战系列-简单聊聊LangChain搭建本地知识库准备 LangChain 是一个开发由语言模型驱动的应用程序的框架,除了和应用程序通过 API 调用, 还会: 数据感知 : 将语言模型连接到其他数据源 具有代理性质 : 允许语言模型与其环境交互 LLM大模型…

Linux-命名管道

文章目录 前言一、命名管道接口函数介绍二、使用步骤 前言 上章内容,我们介绍与使用了管道。上章内容所讲的,是通过pipe接口函数让操作系统给我们申请匿名管道进行进程间通信。 并且这种进程间通信一般只适用于父子进程之间,那么对于两个没有…

创建一个郭德纲相声GPTs

前言 在这篇文章中,我将分享如何利用ChatGPT 4.0辅助论文写作的技巧,并根据网上的资料和最新的研究补充更多好用的咒语技巧。 GPT4的官方售价是每月20美元,很多人并不是天天用GPT,只是偶尔用一下。 如果调用官方的GPT4接口&…

打造VR数字乡村文旅新品牌,VR全景技术助力乡村振兴

新年伊始,各地乡村特色产业都在蓬勃发展,让冬日里的乡村重新焕发了新的活力。并且在这个冬季,各地还依托生态资源优势,打造智慧乡村文旅新品牌,激活乡村消费活力,例如有些乡村利用空心村,打造多…

看完这篇带你了解大学生必考安全证书NISP详解

NISP证书详解 NISP证书介绍:NISP证书等级:NISP(一级)报名:NISP(一级)课程大纲:NISP(二级)报名NISP(二级)课程大纲NISP二级置换CISP指南…

NLP论文阅读记录 - 2021 | WOS 使用深度强化学习及其他技术进行自动文本摘要

文章目录 前言0、论文摘要一、Introduction1.1目标问题1.2相关的尝试1.3本文贡献 二.相关工作2.1. Seq2seq 模型2.2.强化学习和序列生成2.3.自动文本摘要 三.本文方法四 实验效果4.1数据集4.2 对比模型4.3实施细节4.4评估指标4.5 实验结果4.6 细粒度分析 五 总结思考 前言 Auto…

IPv6组播--SSM Mapping

概念 SSM(Source-Specific Multicast)称为指定源组播,要求路由器能了解成员主机加入组播组时所指定的组播源。 如果成员主机上运行MLDv2,可以在MLDv2报告报文中直接指定组播源地址。但是某些情况下,成员主机只能运行MLDv1,为了使其也能够使用SSM服务,组播路由器上需要提…

【野火i.MX6ULL开发板】开发板连接网络(WiFi)与 SSH 登录、上电自动登录、设置静态IP、板子默认参数

0、前言 参考之前自己写的: http://t.csdnimg.cn/g60P8 参考资料: [野火]《Linux基础与应用开发实战指南——基于i.MX6ULL开发板》_20230323 从野火官网下载 参考博客: http://t.csdnimg.cn/8uh4O 参考官方文档: https://doc.…

使用pygame实现简单的烟花效果

import pygame import sys import random import math# 初始化 Pygame pygame.init()# 设置窗口大小 width, height 800, 600 screen pygame.display.set_mode((width, height)) pygame.display.set_caption("Fireworks Explosion")# 定义颜色 black (0, 0, 0) wh…

2023.1.13 关于在 Spring 中操作 Redis 服务器

目录 引言 前置工作 前置知识 实例演示 String 类型 List 类型 Set 类型 Hash 类型 ZSet 类型 引言 进行下述操作的前提是 你的云服务器已经配置好了 ssh 端口转发即已经将云服务器的 Redis 端口映射到本地主机 注意: 此处我们配置的端口号为 8888 可点击下…

【数据结构与算法】之数组系列-20240115

这里写目录标题 一、599. 两个列表的最小索引总和二、724. 寻找数组的中心下标三、面试题 16.11. 跳水板四、35. 搜索插入位置 一、599. 两个列表的最小索引总和 简单 假设 Andy 和 Doris 想在晚餐时选择一家餐厅,并且他们都有一个表示最喜爱餐厅的列表&#xff0c…

网络分流规则

现在的网络是越来越复杂。 有必要进行分流。 有一些geosite.dat是已经整理好的,包含许多的网站的分类: 分流规则: route规则 主要是: {"type": "field","outboundTag": "direct","domain&quo…