离不开这个库torch.utils.data,这个库有两个类一个Dataset和Dataloader
Dataset(对单个样本处理)
Dataset
是一个非常重要的概念,它主要用于管理和组织数据,方便后续的数据加载和处理。以下以 PyTorch 为例,详细介绍 Dataset
相关内容。
概述
在 PyTorch 里,torch.utils.data.Dataset
是一个抽象类,所有自定义的数据集都应该继承这个类,并且至少要实现 __len__
和 __getitem__
这两个方法。
__len__
方法:返回数据集的样本数量。__getitem__
方法:根据给定的索引返回对应的样本和标签。(自定义)
内置的Dataset
这是用于加载手写数字 MNIST 数据集的类,常见参数如下:
root
:指定数据集存储的根目录。如果数据不存在,下载的数据将保存到该目录下。train
:一个布尔值,True
表示加载训练集,False
表示加载测试集。transform
:用于对图像数据进行预处理的转换操作,例如将图像转换为张量、归一化等。target_transform
:用于对标签数据进行预处理的转换操作。download
:一个布尔值,True
表示如果数据集不存在则自动下载,False
表示不进行下载
from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])# 加载 MNIST 训练集
train_dataset = datasets.MNIST(root='./data', train=True,download=True, transform=transform)
自定义数据集(重点)
讲一下什么标注文件 就拿视频而言 我所见的到 一般都是两者 要么就是 你自己分好帧的文件路径 帧数 以及标签(这个标签一般是数字),要么就是视频路径(未分帧数)加上标签
今天拿简单的举例(引用X-CLIP源码)
1.父类初始化
class BaseDataset(Dataset, metaclass=ABCMeta):def __init__(self,ann_file,#标注文件pipeline,#管道操作repeat = 1,#数据重复次数data_prefix=None,#数据文件的前缀路径,用于指定视频数据存储的目录test_mode=False,#一个布尔值,指示是否处于测试模式,默认为 Falsemulti_class=False,#一个布尔值,指示是否为多分类任务,默认为 Falsenum_classes=None,#分类的类别数量,在多分类任务中需要指定,默认为 Nonestart_index=1,#视频帧的起始索引,默认为 1modality='RGB',#视频数据的模态,如 RGB 等,默认为 'RGB'sample_by_class=False,#个布尔值,指示是否按类别采样数据,默认为 Falsepower=0,#用于调整按类别采样的概率的幂次,默认为 0dynamic_length=False,):#个布尔值,指示是否使用动态长度的数据,默认为 False。super().__init__()#处理一下这里判断 data_prefix 中是否包含 .tar 字符串,如果包含则将 self.use_tar_format 设置为 True,表示使用 .tar 格式存储数据。#同时,将 data_prefix 中的 .tar 字符串替换为空,方便后续处理self.use_tar_format = True if ".tar" in data_prefix else Falsedata_prefix = data_prefix.replace(".tar", "")self.ann_file = ann_fileself.repeat = repeatself.data_prefix = osp.realpath(data_prefix) if data_prefix is not None and osp.isdir(data_prefix) else data_prefixself.test_mode = test_modeself.multi_class = multi_classself.num_classes = num_classesself.start_index = start_indexself.modality = modalityself.sample_by_class = sample_by_classself.power = powerself.dynamic_length = dynamic_lengthassert not (self.multi_class and self.sample_by_class)self.pipeline = Compose(pipeline)self.video_infos = self.load_annotations()if self.sample_by_class:self.video_infos_by_class = self.parse_by_class()class_prob = []for _, samples in self.video_infos_by_class.items():class_prob.append(len(samples) / len(self.video_infos))class_prob = [x**self.power for x in class_prob]summ = sum(class_prob)class_prob = [x / summ for x in class_prob]self.class_prob = dict(zip(self.video_infos_by_class, class_prob))
1.在
BaseDataset
类中,repeat
参数的主要作用是控制数据的重复使用次数,通常用于数据增强或者调整数据集的有效规模2.
sample_by_class:
在一些数据集里,不同类别的样本数量可能存在较大差异,即存在类别不平衡问题。按类别采样可以确保每个类别在训练过程中都有足够的样本被使用,避免模型过度偏向样本数量多的类别,有助于提高模型对各个类别的分类性能。3
power
:作用:当sample_by_class
为True
时,不同类别的采样概率最初是根据该类别样本数量占总样本数量的比例来计算的。使用power
参数可以对这些概率进行调整。如果power
大于 1,会增大样本数量少的类别的采样概率,使得这些类别在采样中更有可能被选中;如果power
小于 1 且大于 0,会减小样本数量少的类别的采样概率;当power
等于 0 时,所有类别的采样概率相等。4
dynamic_length:
在一些数据集中,数据样本的长度可能是不同的。例如,在处理文本数据时,不同句子的长度可能不一样;在处理视频数据时,不同视频的帧数也可能不同。使用动态长度的数据可以更灵活地处理这些情况,避免对数据进行不必要的截断或填充操作,从而保留更多的数据信息。但同时,动态长度的数据处理起来相对复杂,需要特殊的处理机制。
2.成员函数
1.处理标注文件 返回一个列表
@abstractmethoddef load_annotations(self):"""Load the annotation according to ann_file into video_infos."""# json annotations already looks like video_infos, so for each dataset,# this func should be the samevideo_infos = []with open(self.ann_file, 'r') as fin:for line in fin:line_split = line.strip().split()if self.multi_class:assert self.num_classes is not Nonefilename, label = line_split[0], line_split[1:]label = list(map(int, label))else:filename, label = line_splitlabel = int(label)if self.data_prefix is not None:filename = osp.join(self.data_prefix, filename)video_infos.append(dict(filename=filename, label=label, tar=self.use_tar_format))return video_infos
[{'filename': 'videos/sample_01.mp4', 'label': 0, 'tar': False},{'filename': 'videos/sample_02.mp4', 'label': 1, 'tar': True}
]
2.parse_by_class
将 video_infos
中的数据按类别进行分组,返回一个字典,键为类别标签,值为该类别对应的视频信息列表
def parse_by_class(self):video_infos_by_class = defaultdict(list)#defaultdict 是 Python 中 collections 模块提供的一个特殊字典,当访问一个不存在的键时,它会自动创建一个默认值。这里指定默认值为一个空列表 list。for item in self.video_infos:label = item['label']video_infos_by_class[label].append(item)return video_infos_by_class
3.def label2array(num, label) one-hot数据
def label2array(num, label):arr = np.zeros(num, dtype=np.float32)arr[label] = 1.return arr
4.dump_results 储存数据
@staticmethoddef dump_results(results, out):"""Dump data to json/yaml/pickle strings or files."""return mmcv.dump(results, out)
5.准备训练帧( self.video_infos = self.load_annotations())他是一个列表
def prepare_train_frames(self, idx):"""Prepare the frames for training given the index."""results = copy.deepcopy(self.video_infos[idx])results['modality'] = self.modalityresults['start_index'] = self.start_index# prepare tensor in getitem# If HVU, type(results['label']) is dictif self.multi_class and isinstance(results['label'], list):onehot = torch.zeros(self.num_classes)onehot[results['label']] = 1.results['label'] = onehotaug1 = self.pipeline(results)if self.repeat > 1:aug2 = self.pipeline(results)ret = {"imgs": torch.cat((aug1['imgs'], aug2['imgs']), 0),"label": aug1['label'].repeat(2),}return retelse:return aug1def prepare_test_frames(self, idx):"""Prepare the frames for testing given the index."""results = copy.deepcopy(self.video_infos[idx])results['modality'] = self.modalityresults['start_index'] = self.start_index# prepare tensor in getitem# If HVU, type(results['label']) is dictif self.multi_class and isinstance(results['label'], list):onehot = torch.zeros(self.num_classes)onehot[results['label']] = 1.results['label'] = onehotreturn self.pipeline(results)
子类:处理video
class VideoDataset(BaseDataset):def __init__(self, ann_file, pipeline, labels_file, start_index=0, **kwargs):super().__init__(ann_file, pipeline, start_index=start_index, **kwargs)self.labels_file = labels_file@propertydef classes(self):classes_all = pd.read_csv(self.labels_file)return classes_all.values.tolist()def load_annotations(self):"""Load annotation file to get video information."""if self.ann_file.endswith('.json'):return self.load_json_annotations()video_infos = []with open(self.ann_file, 'r') as fin:for line in fin:line_split = line.strip().split()if self.multi_class:assert self.num_classes is not Nonefilename, label = line_split[0], line_split[1:]label = list(map(int, label))else:filename, label = line_splitlabel = int(label)if self.data_prefix is not None:filename = osp.join(self.data_prefix, filename)video_infos.append(dict(filename=filename, label=label, tar=self.use_tar_format))return video_infos
最终:
train_data = VideoDataset(ann_file=config.DATA.TRAIN_FILE, data_prefix=config.DATA.ROOT,labels_file=config.DATA.LABEL_LIST, pipeline=train_pipeline)
怎么看数据集返回什么? 看子类实现了了__getitem__(video没有实现,去上面找发现
返回了管道操作(等下数据增强的时候讲)之后的数据
def __getitem__(self, idx):"""Get the sample for either training or testing given index."""if self.test_mode:return self.prepare_test_frames(idx)return self.prepare_train_frames(idx)
->def prepare_test_frames(self, idx):"""Prepare the frames for testing given the index."""results = copy.deepcopy(self.video_infos[idx])results['modality'] = self.modalityresults['start_index'] = self.start_index# prepare tensor in getitem# If HVU, type(results['label']) is dictif self.multi_class and isinstance(results['label'], list):onehot = torch.zeros(self.num_classes)onehot[results['label']] = 1.results['label'] = onehotreturn self.pipeline(results)->self.pipeline = Compose(pipeline)
->train_pipeline = [dict(type='DecordInit'),dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=config.DATA.NUM_FRAMES),dict(type='DecordDecode'),dict(type='Resize', scale=(-1, scale_resize)),dict(type='MultiScaleCrop',input_size=config.DATA.INPUT_SIZE,scales=(1, 0.875, 0.75, 0.66),random_crop=False,max_wh_scale_gap=1),dict(type='Resize', scale=(config.DATA.INPUT_SIZE, config.DATA.INPUT_SIZE), keep_ratio=False),dict(type='Flip', flip_ratio=0.5),dict(type='ColorJitter', p=config.AUG.COLOR_JITTER),dict(type='GrayScale', p=config.AUG.GRAY_SCALE),dict(type='Normalize', **img_norm_cfg),dict(type='FormatShape', input_format='NCHW'),dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),dict(type='ToTensor', keys=['imgs', 'label']),]->train_data = VideoDataset(ann_file=config.DATA.TRAIN_FILE, data_prefix=config.DATA.ROOT,labels_file=config.DATA.LABEL_LIST, pipeline=train_pipeline)
结果返回了:返回一个字典:images:包含一个[N C H W]为一个采样帧数,labels:为一个独热张量
Dataloader(批次)
train_loader = DataLoader(train_data, sampler=sampler_train,batch_size=config.TRAIN.BATCH_SIZE,num_workers=16,pin_memory=True,drop_last=True,collate_fn=partial(mmcv_collate, samples_per_gpu=config.TRAIN.BATCH_SIZE),)
参数解析:
dataset
类型:torch.utils.data.Dataset 子类的实例
作用:指定要加载的数据集,例如前面提到的自定义数据集类的实例。
batch_size
类型:int
作用:每个批次加载的样本数量,默认为 1。例如,如果 batch_size = 32,则每次从数据集中加载 32 个样本。
num_workers
类型:int
作用:使用的子进程数量来加载数据。设置为 0 表示数据将在主进程中加载,
大于 0 则使用多进程并行加载数据,提高数据加载速度,默认为 0
sampler
类型:torch.utils.data.Sampler 子类的实例
作用:自定义样本采样策略。如果指定了 sampler,则 shuffle 参数将被忽略。例如,可以使用 WeightedRandomSampler 进行加权随机采样。
collate_fn
类型:callable
作用:自定义批量数据整理函数,用于将多个样本组合成一个批次。
例如前面提到的 mmcv_collate 函数。如果不指定,将使用默认的整理函数。
drop_last
类型:bool
作用:如果数据集的样本数量不能被 batch_size 整除,是否丢弃最后一个不完整的批次。设置为 True 则丢弃,
设置为 False 则保留,默认为 False
返回值:
DataLoader
是一个可迭代对象,当使用 for
循环遍历 DataLoader
时,每次迭代会返回一个批次的数据。返回的数据格式取决于 dataset
的 __getitem__
方法和 collate_fn
函数。
def mmcv_collate(batch, samples_per_gpu=1): if not isinstance(batch, Sequence):raise TypeError(f'{batch.dtype} is not supported.')if isinstance(batch[0], Sequence):transposed = zip(*batch)return [collate(samples, samples_per_gpu) for samples in transposed]elif isinstance(batch[0], Mapping):return {key: mmcv_collate([d[key] for d in batch], samples_per_gpu)for key in batch[0]}else:return default_collate(batch)
mmcv_collate
函数将多个样本组合成一个批次,最终输出的 batch
是一个字典,包含 'img'
和 'label'
两个键,对应的值分别是批量处理后的图像数据和标签数据
常见的返回形式有元组(包含输入数据和标签)
这个是返回:字典{images:[B,T,C,H,W],labels:}
使用方法
字典:
for idx, batch_data in enumerate(train_loader):images = batch_data["imgs"].cuda(non_blocking=True)label_id = batch_data["label"].cuda(non_blocking=True)
元组:
for iii, (image, class_id) in enumerate(tqdm(val_loader)):
图像处理->视频
images = images.view((-1,config.DATA.NUM_FRAMES,3)+images.size()[-2:])
由NCHW ->N T C H W 记住这个!!!