1 get_dataset
2 list_dataset.py/ListDataset
from torch.utils.data import Datasetclass ListDataset(Dataset):def __init__(self, data):"""data: 必须是一个 list"""self.data = datadef __getitem__(self, index):return self.data[index]def __len__(self):return len(self.data)
__init__(self, data)
:构造函数接收一个参数data
,这个参数必须是一个列表。在这个构造函数中,传入的列表被保存在实例变量self.data
中。-
__getitem__(self, index)
:这是一个特殊方法,允许类的实例像列表一样通过索引访问。该方法返回self.data
中指定索引index
的元素。这是 PyTorchDataset
接口的一部分,使得每个元素可以通过索引直接访问,非常适合于训练过程中获取单个训练样本。 -
__len__(self)
:这也是一个特殊方法,它返回self.data
列表的长度,即数据集中的元素总数。这个方法提供了数据集的大小,是DataLoader
在数据加载过程中进行批处理、采样等操作时需要用到的信息。
3 generate_data
pytorch笔记:Dataloader_tensor如何装进dataloader里-CSDN博客