blog
torch.utils.data.Dataset
- create a custom dataset by subclassing torch.utils.data.Dataset
import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Minimal map-style dataset that wraps an indexable sequence.

    Args:
        data: Any object supporting ``len()`` and integer indexing
            (e.g. a list). Samples are returned exactly as stored.
    """

    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        # Fetch a single sample by its index.
        return self.data[index]

    def __len__(self):
        # Report the number of samples in the dataset.
        return len(self.data)


# Create the dataset object.
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)
# Fetch a sample by index.
sample = dataset[2]
print(sample)
torchvision.datasets
- load data from a classic built-in dataset (e.g. MNIST)
import torch
from torchvision import datasets, transforms

# Define the data transform. Each step sits on its own line so that a
# trailing comment cannot swallow the next transform.
transform = transforms.Compose([
    transforms.ToTensor(),                 # convert the image to a tensor
    transforms.Normalize((0.5,), (0.5,)),  # normalize the single channel: (x - 0.5) / 0.5
])

# Load the MNIST dataset (download=True fetches it into ./data if missing).
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
2. Load data from an ImageFolder with a transform
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# transforms.Compose chains several image transforms together and applies
# them, in order, to each input image.
# NOTE(review): `fixed_image_standardization` and `training.collate_pil` are not
# defined in this snippet; they look like facenet_pytorch helpers (e.g.
# `from facenet_pytorch import fixed_image_standardization, training`) — confirm.
# `data_dir`, `workers` and `batch_size` must also be defined beforehand.
trans = transforms.Compose([
    np.float32,  # presumably casts the PIL image to a float32 array before ToTensor — verify
    transforms.ToTensor(),
    fixed_image_standardization,
])
dataset = datasets.ImageFolder(data_dir, transform=trans)
loader = DataLoader(
    dataset,
    num_workers=workers,
    batch_size=batch_size,
    collate_fn=training.collate_pil,
)
3. Introduction to ImageFolder
# Build the data pipeline for the input images.
mytransform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
# Images must be converted to tensors by the transform; otherwise a DataLoader
# over this dataset cannot be iterated as `for input, label in train_loader`
# (the default collate function cannot batch PIL images).
dataset = datasets.ImageFolder(data_dir, transform=mytransform)

# Inspect the dataset: summary, sizes, class names and (path, class) samples.
class_names = dataset.classes
samples = dataset.imgs
print(dataset)
print(len(dataset))
print(len(samples))
print(len(class_names))
print(class_names[-1])
print(class_names)
print(samples)
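Expected directory layout (one sub-folder per class):
root/cls1/img1.png
root/cls1/img2.png
root/cls2/img1.png
root/cls2/img2.png
root/cls3/img1.png
root/cls3/img2.png
Each entry of dataset.imgs is an (img, cls) tuple: the image file path and its integer class index.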
# img_list_1=[img for (img,idx) in dataset.imgs]
# with open("img_list_1.pkl","wb") as file:
# pickle.dump(img_list_1,file)
DataLoader
- load data from an instance of torch.utils.data.Dataset
import torch
from torchvision import datasets, transforms

# Create the data loaders. `train_dataset` / `test_dataset` are assumed to be
# defined earlier (e.g. by the MNIST example).
# NOTE: with num_workers > 0, platforms that spawn worker processes
# (Windows, macOS) require this code to run under `if __name__ == "__main__":`.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

# Iterate over batches of samples with the data loader.
for images, labels in train_loader:
    # Model training code goes here...
    pass
num_workers
Number of subprocesses used to load batches in parallel; 0 means data is loaded in the main process.
torchvision.transforms
from torchvision import transforms

# Define the image preprocessing pipeline. Each step sits on its own line so
# that a trailing comment cannot swallow the next transform.
transform = transforms.Compose([
    transforms.Resize((256, 256)),        # resize the image to 256x256
    transforms.RandomCrop((224, 224)),    # randomly crop a 224x224 patch
    transforms.RandomHorizontalFlip(),    # randomly flip horizontally (p=0.5 by default)
    transforms.ToTensor(),                # convert the image to a tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # normalize each of the 3 channels
])

# Preprocess the image.
# NOTE(review): `image` must already be defined — presumably a PIL RGB image,
# e.g. `Image.open(path).convert("RGB")`; it is not created in this snippet.
image = transform(image)