神经网络的数据集处理

离不开这个库torch.utils.data,这个库有两个类一个Dataset和Dataloader

Dataset(对单个样本处理)

Dataset 是一个非常重要的概念,它主要用于管理和组织数据,方便后续的数据加载和处理。以下以 PyTorch 为例,详细介绍 Dataset 相关内容。

概述

在 PyTorch 里,torch.utils.data.Dataset 是一个抽象类,所有自定义的数据集都应该继承这个类,并且至少要实现 __len__ 和 __getitem__ 这两个方法。

  • __len__ 方法:返回数据集的样本数量。
  • __getitem__ 方法:根据给定的索引返回对应的样本和标签。(自定义)

内置的Dataset

这是用于加载手写数字 MNIST 数据集的类,常见参数如下:

  • root:指定数据集存储的根目录。如果数据不存在,下载的数据将保存到该目录下。
  • train:一个布尔值,True 表示加载训练集,False 表示加载测试集。
  • transform:用于对图像数据进行预处理的转换操作,例如将图像转换为张量、归一化等。
  • target_transform:用于对标签数据进行预处理的转换操作。
  • download:一个布尔值,True 表示如果数据集不存在则自动下载,False 表示不进行下载
from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])# 加载 MNIST 训练集
train_dataset = datasets.MNIST(root='./data', train=True,download=True, transform=transform)

自定义数据集(重点)

讲一下什么标注文件 就拿视频而言 我所见的到 一般都是两者 要么就是 你自己分好帧的文件路径 帧数 以及标签(这个标签一般是数字),要么就是视频路径(未分帧数)加上标签

今天拿简单的举例(引用X-CLIP源码)

1.父类初始化


class BaseDataset(Dataset, metaclass=ABCMeta):def __init__(self,ann_file,#标注文件pipeline,#管道操作repeat = 1,#数据重复次数data_prefix=None,#数据文件的前缀路径,用于指定视频数据存储的目录test_mode=False,#一个布尔值,指示是否处于测试模式,默认为 Falsemulti_class=False,#一个布尔值,指示是否为多分类任务,默认为 Falsenum_classes=None,#分类的类别数量,在多分类任务中需要指定,默认为 Nonestart_index=1,#视频帧的起始索引,默认为 1modality='RGB',#视频数据的模态,如 RGB 等,默认为 'RGB'sample_by_class=False,#个布尔值,指示是否按类别采样数据,默认为 Falsepower=0,#用于调整按类别采样的概率的幂次,默认为 0dynamic_length=False,):#个布尔值,指示是否使用动态长度的数据,默认为 False。super().__init__()#处理一下这里判断 data_prefix 中是否包含 .tar 字符串,如果包含则将 self.use_tar_format 设置为 True,表示使用 .tar 格式存储数据。#同时,将 data_prefix 中的 .tar 字符串替换为空,方便后续处理self.use_tar_format = True if ".tar" in data_prefix else Falsedata_prefix = data_prefix.replace(".tar", "")self.ann_file = ann_fileself.repeat = repeatself.data_prefix = osp.realpath(data_prefix) if data_prefix is not None and osp.isdir(data_prefix) else data_prefixself.test_mode = test_modeself.multi_class = multi_classself.num_classes = num_classesself.start_index = start_indexself.modality = modalityself.sample_by_class = sample_by_classself.power = powerself.dynamic_length = dynamic_lengthassert not (self.multi_class and self.sample_by_class)self.pipeline = Compose(pipeline)self.video_infos = self.load_annotations()if self.sample_by_class:self.video_infos_by_class = self.parse_by_class()class_prob = []for _, samples in self.video_infos_by_class.items():class_prob.append(len(samples) / len(self.video_infos))class_prob = [x**self.power for x in class_prob]summ = sum(class_prob)class_prob = [x / summ for x in class_prob]self.class_prob = dict(zip(self.video_infos_by_class, class_prob))

1.在 BaseDataset 类中,repeat 参数的主要作用是控制数据的重复使用次数,通常用于数据增强或者调整数据集的有效规模

2.sample_by_class:在一些数据集里,不同类别的样本数量可能存在较大差异,即存在类别不平衡问题。按类别采样可以确保每个类别在训练过程中都有足够的样本被使用,避免模型过度偏向样本数量多的类别,有助于提高模型对各个类别的分类性能。

power:作用:当 sample_by_class 为 True 时,不同类别的采样概率最初是根据该类别样本数量占总样本数量的比例来计算的。使用 power 参数可以对这些概率进行调整。如果 power 大于 1,会增大样本数量少的类别的采样概率,使得这些类别在采样中更有可能被选中;如果 power 小于 1 且大于 0,会减小样本数量少的类别的采样概率;当 power 等于 0 时,所有类别的采样概率相等。

4dynamic_length:在一些数据集中,数据样本的长度可能是不同的。例如,在处理文本数据时,不同句子的长度可能不一样;在处理视频数据时,不同视频的帧数也可能不同。使用动态长度的数据可以更灵活地处理这些情况,避免对数据进行不必要的截断或填充操作,从而保留更多的数据信息。但同时,动态长度的数据处理起来相对复杂,需要特殊的处理机制。

2.成员函数

1.处理标注文件 返回一个列表

@abstractmethoddef load_annotations(self):"""Load the annotation according to ann_file into video_infos."""# json annotations already looks like video_infos, so for each dataset,# this func should be the samevideo_infos = []with open(self.ann_file, 'r') as fin:for line in fin:line_split = line.strip().split()if self.multi_class:assert self.num_classes is not Nonefilename, label = line_split[0], line_split[1:]label = list(map(int, label))else:filename, label = line_splitlabel = int(label)if self.data_prefix is not None:filename = osp.join(self.data_prefix, filename)video_infos.append(dict(filename=filename, label=label, tar=self.use_tar_format))return video_infos
 [{'filename': 'videos/sample_01.mp4', 'label': 0, 'tar': False},{'filename': 'videos/sample_02.mp4', 'label': 1, 'tar': True}
]

 2.parse_by_class

将 video_infos 中的数据按类别进行分组,返回一个字典,键为类别标签,值为该类别对应的视频信息列表

def parse_by_class(self):video_infos_by_class = defaultdict(list)#defaultdict 是 Python 中 collections 模块提供的一个特殊字典,当访问一个不存在的键时,它会自动创建一个默认值。这里指定默认值为一个空列表 list。for item in self.video_infos:label = item['label']video_infos_by_class[label].append(item)return video_infos_by_class

 3.def label2array(num, label) one-hot数据

    def label2array(num, label):arr = np.zeros(num, dtype=np.float32)arr[label] = 1.return arr

4.dump_results 储存数据

 @staticmethoddef dump_results(results, out):"""Dump data to json/yaml/pickle strings or files."""return mmcv.dump(results, out)

5.准备训练帧( self.video_infos = self.load_annotations())他是一个列表

 def prepare_train_frames(self, idx):"""Prepare the frames for training given the index."""results = copy.deepcopy(self.video_infos[idx])results['modality'] = self.modalityresults['start_index'] = self.start_index# prepare tensor in getitem# If HVU, type(results['label']) is dictif self.multi_class and isinstance(results['label'], list):onehot = torch.zeros(self.num_classes)onehot[results['label']] = 1.results['label'] = onehotaug1 = self.pipeline(results)if self.repeat > 1:aug2 = self.pipeline(results)ret = {"imgs": torch.cat((aug1['imgs'], aug2['imgs']), 0),"label": aug1['label'].repeat(2),}return retelse:return aug1def prepare_test_frames(self, idx):"""Prepare the frames for testing given the index."""results = copy.deepcopy(self.video_infos[idx])results['modality'] = self.modalityresults['start_index'] = self.start_index# prepare tensor in getitem# If HVU, type(results['label']) is dictif self.multi_class and isinstance(results['label'], list):onehot = torch.zeros(self.num_classes)onehot[results['label']] = 1.results['label'] = onehotreturn self.pipeline(results)

子类:处理video

class VideoDataset(BaseDataset):def __init__(self, ann_file, pipeline, labels_file, start_index=0, **kwargs):super().__init__(ann_file, pipeline, start_index=start_index, **kwargs)self.labels_file = labels_file@propertydef classes(self):classes_all = pd.read_csv(self.labels_file)return classes_all.values.tolist()def load_annotations(self):"""Load annotation file to get video information."""if self.ann_file.endswith('.json'):return self.load_json_annotations()video_infos = []with open(self.ann_file, 'r') as fin:for line in fin:line_split = line.strip().split()if self.multi_class:assert self.num_classes is not Nonefilename, label = line_split[0], line_split[1:]label = list(map(int, label))else:filename, label = line_splitlabel = int(label)if self.data_prefix is not None:filename = osp.join(self.data_prefix, filename)video_infos.append(dict(filename=filename, label=label, tar=self.use_tar_format))return video_infos

最终:

train_data = VideoDataset(ann_file=config.DATA.TRAIN_FILE, data_prefix=config.DATA.ROOT,labels_file=config.DATA.LABEL_LIST, pipeline=train_pipeline)

怎么看数据集返回什么? 看子类实现了了__getitem__(video没有实现,去上面找发现 

返回了管道操作(等下数据增强的时候讲)之后的数据

 def __getitem__(self, idx):"""Get the sample for either training or testing given index."""if self.test_mode:return self.prepare_test_frames(idx)return self.prepare_train_frames(idx)
->def prepare_test_frames(self, idx):"""Prepare the frames for testing given the index."""results = copy.deepcopy(self.video_infos[idx])results['modality'] = self.modalityresults['start_index'] = self.start_index# prepare tensor in getitem# If HVU, type(results['label']) is dictif self.multi_class and isinstance(results['label'], list):onehot = torch.zeros(self.num_classes)onehot[results['label']] = 1.results['label'] = onehotreturn self.pipeline(results)->self.pipeline = Compose(pipeline)
->train_pipeline = [dict(type='DecordInit'),dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=config.DATA.NUM_FRAMES),dict(type='DecordDecode'),dict(type='Resize', scale=(-1, scale_resize)),dict(type='MultiScaleCrop',input_size=config.DATA.INPUT_SIZE,scales=(1, 0.875, 0.75, 0.66),random_crop=False,max_wh_scale_gap=1),dict(type='Resize', scale=(config.DATA.INPUT_SIZE, config.DATA.INPUT_SIZE), keep_ratio=False),dict(type='Flip', flip_ratio=0.5),dict(type='ColorJitter', p=config.AUG.COLOR_JITTER),dict(type='GrayScale', p=config.AUG.GRAY_SCALE),dict(type='Normalize', **img_norm_cfg),dict(type='FormatShape', input_format='NCHW'),dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),dict(type='ToTensor', keys=['imgs', 'label']),]->train_data = VideoDataset(ann_file=config.DATA.TRAIN_FILE, data_prefix=config.DATA.ROOT,labels_file=config.DATA.LABEL_LIST, pipeline=train_pipeline)

 结果返回了:返回一个字典:images:包含一个[N  C H W]为一个采样帧数,labels:为一个独热张量

Dataloader(批次)

   train_loader = DataLoader(train_data, sampler=sampler_train,batch_size=config.TRAIN.BATCH_SIZE,num_workers=16,pin_memory=True,drop_last=True,collate_fn=partial(mmcv_collate, samples_per_gpu=config.TRAIN.BATCH_SIZE),)

参数解析:

dataset
类型:torch.utils.data.Dataset 子类的实例
作用:指定要加载的数据集,例如前面提到的自定义数据集类的实例。
batch_size
类型:int
作用:每个批次加载的样本数量,默认为 1。例如,如果 batch_size = 32,则每次从数据集中加载 32 个样本。
num_workers
类型:int
作用:使用的子进程数量来加载数据。设置为 0 表示数据将在主进程中加载,
大于 0 则使用多进程并行加载数据,提高数据加载速度,默认为 0
sampler
类型:torch.utils.data.Sampler 子类的实例
作用:自定义样本采样策略。如果指定了 sampler,则 shuffle 参数将被忽略。例如,可以使用 WeightedRandomSampler 进行加权随机采样。
collate_fn
类型:callable
作用:自定义批量数据整理函数,用于将多个样本组合成一个批次。
例如前面提到的 mmcv_collate 函数。如果不指定,将使用默认的整理函数。
drop_last
类型:bool
作用:如果数据集的样本数量不能被 batch_size 整除,是否丢弃最后一个不完整的批次。设置为 True 则丢弃,
设置为 False 则保留,默认为 False

返回值:

DataLoader 是一个可迭代对象,当使用 for 循环遍历 DataLoader 时,每次迭代会返回一个批次的数据。返回的数据格式取决于 dataset 的 __getitem__ 方法和 collate_fn 函数。

def mmcv_collate(batch, samples_per_gpu=1): if not isinstance(batch, Sequence):raise TypeError(f'{batch.dtype} is not supported.')if isinstance(batch[0], Sequence):transposed = zip(*batch)return [collate(samples, samples_per_gpu) for samples in transposed]elif isinstance(batch[0], Mapping):return {key: mmcv_collate([d[key] for d in batch], samples_per_gpu)for key in batch[0]}else:return default_collate(batch)

 mmcv_collate 函数将多个样本组合成一个批次,最终输出的 batch 是一个字典,包含 'img' 和 'label' 两个键,对应的值分别是批量处理后的图像数据和标签数据

常见的返回形式有元组(包含输入数据和标签)

这个是返回:字典{images:[B,T,C,H,W],labels:}

使用方法

字典:

  for idx, batch_data in enumerate(train_loader):images = batch_data["imgs"].cuda(non_blocking=True)label_id = batch_data["label"].cuda(non_blocking=True)

元组:

   for iii, (image, class_id) in enumerate(tqdm(val_loader)):

图像处理->视频

 images = images.view((-1,config.DATA.NUM_FRAMES,3)+images.size()[-2:])

由NCHW ->N T C H W 记住这个!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/35571.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

蓝桥杯每日一题

丢失的雨伞 题目思路代码演示 题目 今天晚上本来想练习一下前缀和与差分 结果给我搜出来这题(几乎没啥关系),我看半天有点思路但又下不了手哈哈,难受一批 在图书馆直接红温了 题目链接 思路 题目要求找到两个不重叠的区间&…

校园安全用电怎么保障?防触电装置来帮您

引言 随着教育设施的不断升级和校园用电需求的日益增长,校园电力系统的安全性和可靠性成为了学校管理的重要课题。三相智能安全配电装置作为一种电力管理设备,其在校园中的应用不仅能够提高电力系统的安全性,还能有效保障师生的用电安全&am…

Matlab 汽车二自由度转弯模型

1、内容简介 Matlab 187-汽车二自由度转弯模型 可以交流、咨询、答疑 2、内容说明 略 摘 要 本文前一部分提出了侧偏角和横摆角速度作为参数。描述了车辆运动的运动状态,其中文中使用的参考模型是二自由度汽车模型。汽车速度被认为是建立基于H.B.Pacejka的轮胎模…

OpenCV计算摄影学(20)非真实感渲染之增强图像的细节函数detailEnhance()

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 此滤波器增强特定图像的细节。 cv::detailEnhance用于增强图像的细节,通过结合空间域和频率域的处理,提升图像中特定细节…

Java面试八股—Redis篇

一、Redis的使用场景 (一)缓存 1.Redis使用场景缓存 场景:缓存热点数据(如用户信息、商品详情),减少数据库访问压力,提升响应速度。 2.缓存穿透 正常的访问是:根据ID查询文章&…

2025-03-17 Unity 网络基础1——网络基本概念

文章目录 1 网络1.1 局域网1.2 以太网1.3 城域网1.4 广域网1.5 互联网(因特网)1.6 万维网1.7 小结 2 IP 地址2.1 IP 地址2.2 端口号2.3 Mac 地址2.4 小结 3 客户端与服务端3.1 客户端3.2 服务端3.3 网络游戏中的客户端与服务端 1 网络 ​ 在没有网络之前…

【工业现场总线】控制网络的主要特点是?OSI参考模型的分层是?

目录 1、控制网络的主要特点? 2、网络拓扑结构的主要类型?其各自主要特点是什么? 3、网络的传输介质主要有什么? 4、网络传输介质的访问控制方式主要有哪些?其各自主要特点是什么? 5、OSI参考模型的分…

微软开源神器OmniParser V2.0 介绍

微软开源的OmniParser V2.0是一款基于纯视觉技术的GUI智能体解析工具,旨在将用户界面(UI)截图转换为结构化数据,从而实现对计算机屏幕上的可交互元素的高效识别和操控。这一工具通过结合先进的视觉解析技术和大型语言模型&#xf…

用python代码将excel中的数据批量写入Json中的某个字段,生成新的Json文件

需求 需求: 1.将execl文件中的A列赋值给json中的TrackId,B列赋值给json中的OId 要求 execl的每一行,对应json中的每一个OId json 如下: {"List": [{"BatchNumber": "181-{{var}}",// "Bat…

实验篇| Nginx环境搭建-安全配置

在前面的文章里,阿祥详细介绍了在 Windows 系统中安装 Nginx 服务器的具体操作步骤,感兴趣的朋友可以参考:实验篇 | Nginx 反向代理 - 7 层代理 。完成 Nginx 的安装只是搭建 Web 服务的第一步,为了保障服务器的稳定运行以及数据安…

理解我们单片机拥有的资源

目录 为什么要查询单片机拥有的资源 所以,去哪些地方可以找数据手册 一个例子:STM32F103C8T6 前言 本文章隶属于项目: Charliechen114514/BetterATK: This is a repo that helps rewrite STM32 Common Repositorieshttps://github.com/C…

从零开始 | C语言基础刷题DAY3

❤个人主页&#xff1a;折枝寄北的博客 目录 1.打印3的倍数的数2.从大到小输出3. 打印素数4.打印闰年5.最大公约数 1.打印3的倍数的数 题目&#xff1a; 写一个代码打印1-100之间所有3的倍数的数字 代码&#xff1a; int main(){int i 0;for (i 1; i < 100; i){if (i % …

Blender材质 - 层权重

层权重 混合着色器 可以让 面朝向的一面显示一种材质 另一面显示另一种材质 就能实现挺不错的材质效果 移动视角 材质会跟着变化 有点类似虚幻的视差节点BumpOffset

3个 Vue $set 的应用场景

大家好&#xff0c;我是大澈&#xff01;一个喜欢结交朋友、喜欢编程技术和科技前沿的老程序员&#x1f468;&#x1f3fb;‍&#x1f4bb;&#xff0c;关注我&#xff0c;科技未来或许我能帮到你&#xff01; 在 Vue2 中&#xff0c;由于 Object.defineProperty 的限制&#…

Flutter_学习记录_ ImagePicker拍照、录制视频、相册选择照片和视频、上传文件

插件地址&#xff1a;https://pub.dev/packages/image_picker 添加插件 添加配置 android无需配置开箱即用&#xff0c;ios还需要配置info.plist <key>NSPhotoLibraryUsageDescription</key> <string>应用需要访问相册读取文件</string> <key>N…

LeetCode 解题思路 19(Hot 100)

解题思路&#xff08;递归&#xff09;&#xff1a; 终止条件&#xff1a; 若节点为空&#xff0c;返回深度0。递归步骤&#xff1a; 分别计算左子树和右子树的最大深度&#xff0c;取较大者并加1&#xff08;当前节点&#xff09;。 Java代码&#xff1a; class Solution {…

如何启用 HTTPS 并配置免费的 SSL 证书

引言 HTTPS 已成为现代网站安全性的基础要求。通过 SSL/TLS 证书对数据进行加密&#xff0c;不仅可以保护用户隐私&#xff0c;还能提升搜索引擎排名并增强用户信任。本指南将详细介绍如何通过 Lets Encrypt&#xff08;免费、自动化的证书颁发机构&#xff09;为您的网站启用…

element-plus中Popconfirm气泡确认框组件的使用

1、基本使用 从element-plus官网复制代码&#xff1a; <template><el-popconfirm title"Are you sure to delete this?"><template #reference><el-button>Delete</el-button></template></el-popconfirm> </template…

软件需求分类、需求获取(高软46)

系列文章目录 软件需求分类&#xff0c;需求获取 文章目录 系列文章目录前言一、软件需求二、获取需求三、真题总结 前言 本节讲明软件需求分类、需求获取的相关知识。 一、软件需求 二、获取需求 三、真题 总结 就是高软笔记&#xff0c;大佬请略过&#xff01;

10、基于osg引擎生成热力图高度图实现3D热力图可视化、3D热力图实时更新(带过渡效果)

1、结果 2、完整C代码 #include <sstream> #include <iomanip> #include <iostream> #include <vector> #include <random> #include <cmath> #include <functional> #include <osgViewer/viewer> #include <osgDB/Read…