Pytorch建立MyDataLoader过程详解

简介

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device=‘’)

详细:DataLoader

自己基于DataLoader实现各个模块

代码实现

MyDataset基于torch中的Data实现对个人数据集的载入,例如图像和标签载入
SingleSampler基于torch中的Sampler实现对于数据的batch个数图像的载入,例如,Batch_Size=4,实现对所有数据中选取4个索引作为一组,然后在MyDataset中基于__getitem__根据图像索引去进行图像操作
MyBathcSampler基于torch的BatchSampler实现自己对于batch_size数据的处理。需要基于SingleSampler实现Sampler的处理,更为灵活。MyBatchSampler的存在会自动覆盖DataLoader中的batch_size参数
注:Sampler的实现,将会与shuffer冲突,shuffer是在没有实现sampler前提下去自动判断选择的sampler类型
collate_fn是实现将batch_size的图像数据进行打包,遍历过程中就可以实现batch_size的images和labels对应
在这里插入图片描述

sampler

from typing import Iterator, List
import torch
from torch.utils.data import BatchSampler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Samplerclass MyDataset(Dataset):def __init__(self) -> None:self.data = torch.arange(20)def __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]@staticmethoddef collate_fn(batch):return torch.stack(batch, 0)class MyBatchSampler(BatchSampler):def __init__(self, sampler: Sampler[int], batch_size: int) -> None:self._sampler = samplerself._batch_size = batch_sizedef __iter__(self) -> Iterator[List[int]]:batch = []for idx in self._sampler:batch.append(idx)if len(batch) == self._batch_size:yield batchbatch = []yield batchdef __len__(self):return len(self._sampler) // self._batch_sizeclass SingleSampler(Sampler):def __init__(self, data_source) -> None:self._data = data_sourceself.num_samples = len(self._data)def __iter__(self):# 顺序采样# indices = range(len(self._data))# 随机采样indices = torch.randperm(self.num_samples).tolist()return iter(indices)def __len__(self):return self.num_samplestrain_set = MyDataset()
single_sampler = SingleSampler(train_set)
batch_sampler = MyBatchSampler(single_sampler, 8)
train_loader = DataLoader(train_set, batch_size=4, sampler=single_sampler, pin_memory=True, collate_fn=MyDataset.collate_fn)
for data in train_loader:print(data)

batch_sampler

from typing import Iterator, List
import torch
from torch.utils.data import BatchSampler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Samplerclass MyDataset(Dataset):def __init__(self) -> None:self.data = torch.arange(20)def __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]@staticmethoddef collate_fn(batch):return torch.stack(batch, 0)class MyBatchSampler(BatchSampler):def __init__(self, sampler: Sampler[int], batch_size: int) -> None:self._sampler = samplerself._batch_size = batch_sizedef __iter__(self) -> Iterator[List[int]]:batch = []for idx in self._sampler:batch.append(idx)if len(batch) == self._batch_size:yield batchbatch = []yield batchdef __len__(self):return len(self._sampler) // self._batch_sizeclass SingleSampler(Sampler):def __init__(self, data_source) -> None:self._data = data_sourceself.num_samples = len(self._data)def __iter__(self):# 顺序采样# indices = range(len(self._data))# 随机采样indices = torch.randperm(self.num_samples).tolist()return iter(indices)def __len__(self):return self.num_samplestrain_set = MyDataset()
single_sampler = SingleSampler(train_set)
batch_sampler = MyBatchSampler(single_sampler, 8)
train_loader = DataLoader(train_set, batch_sampler=batch_sampler, pin_memory=True, collate_fn=MyDataset.collate_fn)
for data in train_loader:print(data)

参考

Sampler:https://blog.csdn.net/lidc1004/article/details/115005612

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/106012.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

简述docker的网络模式

Docker 提供了多种网络模式,用于控制容器之间以及容器与主机之间的网络通信。以下是 Docker 的一些常见网络模式 briage模式: docker容器启动时默认就是该模式,在该模式下,docker容器会连接到一个名为docker0的虚拟以太网桥上,通…

1.6 服务器处理客户端请求

客户端进程向服务器进程发送一段文本(MySQL语句),服务器进程处理后再向客户端进程发送一段文本(处理结果)。 从图中我们可以看出,服务器程序处理来自客户端的查询请求大致需要经过三个部分,分别…

类与对象(下)

类与对象(下) 一、初始化列表1、构造函数与初始化2、使用初始化列表的形式3、注意点4、代码5、类需初始化列表但没使用初始化列表时报的错误6、成员变量的初始化顺序(1)顺序(2)测试代码(3&#…

【中危】Apache XML Graphics Batik<1.17 存在SSRF漏洞 (CVE-2022-44729)

zhi.oscs1024.com​​​​​ 漏洞类型SSRF发现时间2023-08-23漏洞等级中危MPS编号MPS-2022-63578CVE编号CVE-2022-44729漏洞影响广度极小 漏洞危害 OSCS 描述Apache XML Graphics Batik 是一个开源的、用于处理可缩放矢量图形(SVG)格式图像的工具库。 受影响版本中&#xff0…

开源网安受邀参加软件供应链安全沙龙,推动企业提升安全治理能力

​8月23日下午,合肥软件行业软件供应链安全沙龙在中安创谷科技园举办。此次沙龙由合肥软件产业公共服务中心联合中安创谷科技园公司共同主办,开源网安软件供应链安全专家王晓龙、尹杰受邀参会并带来软件供应链安全方面的精彩内容分享,共同探讨…

Hightopo 使用心得(6)- 3D场景环境配置(天空球,雾化,辉光,景深)

在前一篇文章《Hightopo 使用心得(5)- 动画的实现》中,我们将一个直升机模型放到了3D场景中。同时,还利用动画实现了让该直升机围绕山体巡逻。在这篇文章中,我们将对上一篇的场景进行一些环境上的丰富与美化。让场景更…

Sentinel dashboard无法查询到应用的限流配置问题以及解决

一。问题引入 使用sentinle-dashboard控制台 项目整体升级后,发现控制台上无法看到流控规则了 之前的问题是无法注册上来 现在是注册上来了。结果看不到流控规则配置了。 关于注册不上来的问题,可以看另一篇文章 https://blog.csdn.net/a15835774652/…

Java之ApI之Math类详解

1 Math类 1.1 概述 tips:了解内容 查看API文档,我们可以看到API文档中关于Math类的定义如下: Math类所在包为java.lang包,因此在使用的时候不需要进行导包。并且Math类被final修饰了,因此该类是不能被继承的。 Math类…

ARM开发,stm32mp157a-A7核PWM实验(驱动蜂鸣器,风扇,马达工作)

1.分析框图; 2.比较捕获寄存器(产生PWM方波); 工作原理: 1、系统提供一个时钟源209MHZ,需要通过分频器进行分频,设置分频器值为209分频; 2、当定时器启动之后,自动重载…

编码基础一:侵入式链表

一、简介概述 1、普通链表数据结构 每个节点的next指针指向下一个节点的首地址。这样会有如下的限制: 一条链表上的所有节点的数据类型需要完全一致。对某条链表的操作如插入,删除等只能对这种类型的链表进行操作,如果链表的类型换了&#…

巨人互动|Facebook海外户Facebook游戏全球发布实用策略

Facebook是全球最大的社交媒体平台之一,拥有庞大的用户基数和广阔的市场。对于游戏开发商而言,利用Facebook进行全球发布是一项重要的策略。下面小编将介绍一些实用的策略帮助开发商在Facebook上进行游戏全球发布。 巨人互动|Facebook海外户&Faceboo…

visual studio 2022.NET Core 3.1 未显示在目标框架下拉列表中

问题描述 在Visual Studio 2022我已经安装了 .NET core 3.1 并验证可以运行 .NET core 3.1 应用程序,但当创建一个新项目时,目标框架的下拉列表只允许 .NET 6.0和7.0。而我在之前用的 Visual Studio 2019,可以正确地添加 .NET 核心项目。 …

【【萌新的STM32学习-17 中断的基本概念2】】

萌新的STM32学习-17 中断的基本概念2 STM32中断优先级的基本概念 抢占优先级: 高抢占优先级可以打断正在执行的低抢占优先级中断 响应优先级: 这个也叫子优先级 抢占优先级相同,响应优先级高的中断不能打断响应优先级低的中断。还有一种情况…

华为OD机试 - 字符串筛选排序 - 数组(Java 2022 Q4 100分)

目录 专栏导读一、题目描述二、输入描述三、输出描述四、解题思路五、Java算法源码六、效果展示1、输入2、输出3、说明 华为OD机试 2023B卷题库疯狂收录中,刷题点这里 专栏导读 本专栏收录于《华为OD机试(JAVA)真题(A卷B卷&#…

深度学习优化入门:Momentum、RMSProp 和 Adam

目录 深度学习优化入门:Momentum、RMSProp 和 Adam 病态曲率 1牛顿法 2 Momentum:动量 3Adam 深度学习优化入门:Momentum、RMSProp 和 Adam 本文,我们讨论一个困扰神经网络训练的问题,病态曲率。 虽然局部极小值和鞍点会阻碍…

【深度学习】实验02 鸢尾花数据集分析

文章目录 鸢尾花数据集分析决策树K-means 鸢尾花数据集分析 决策树 # 导入机器学习相关库 from sklearn import datasets from sklearn import treeimport matplotlib.pyplot as plt import numpy as np# Iris数据集是常用的分类实验数据集, # 由Fisher, 1936收集…

NSSCTF——Web题目2

目录 一、[HNCTF 2022 Week1]2048 二、[HNCTF 2022 Week1]What is Web 三、[LitCTF 2023]1zjs 四、[NCTF 2018]签到题 五、[SWPUCTF 2021 新生赛]gift_F12 一、[HNCTF 2022 Week1]2048 知识点:源代码审计 解题思路: 1、打开控制台,查看…

Android中的APK打包与安全

aapt2命令行实现apk打包 apk文件结构 classes.dex:Dex,即Android Dalvik执行文件 AndroidManifest.xml:工程中AndroidManifest.xml编译后得到的二进制xml文件 META-INF:主要保存各个资源文件的SHA1 hash值,用于校验…

【Linux】【驱动】驱动挂载的时候给驱动传递参数

【Linux】【驱动】驱动挂载的时候给驱动传递参数 绪论1.什么是驱动传参驱动传参就是传递参数给我们的驱动举例:2.驱动传参数有什么作用呢?3. 传递单个参数使用如下的数组4. 传递数组使用以下函数: 传递数字值代码指令 传递数组代码传递数组指令 绪论 1.什么是驱动…

服务器中了mkp勒索病毒该怎么办?勒索病毒解密,数据恢复

mkp勒索病毒算的上是一种比较常见的勒索病毒类型了。它的感染数量上也常年排在前几名的位置。所以接下来就由云天数据恢复中心的技术工程师来对mkp勒索病毒做一个分析,以及中招以后应该怎么办。 一,中了mkp勒索病毒的表现 桌面以及多个文件夹当中都有一封…