我们之前用torchvision加载了pytorch的网络数据集,现在我们用Dataset加载自己的数据集,并且使用DataLoader做成训练数据集。
图像是从网上下载的,网址是 点这里,标签是图像文件夹名字。下载完成后作为自己的数据集。
1.加载自己的数据集的思路
1)要完成继承自Dataset的类的构建
由于Dataset是一个包含了虚函数的类,因此继承Dataset后,必须实现这些虚函数。
2)第一个要完成的是__init__的构建,一般的方法是在__init__(self,root_dir, label_dir)中设置数据集的根目录root_dir,和类别数据集label_dir,然后用os.listdir得到label_dir中的图像名字
3)第二个要完成的就是
__getitem__(self, item):
item就是所要取数据的索引,这个函数主要是返回一个训练数据(比如一个图像),和一个结果数据,比如(该图像的分类结果是一个ant),因此用到刚os.listdir所列出的文件名字,用os.path.join加入路径,得到图像的绝对路径,用PIL导入图像,并给label赋值,返回图像和;abel即可。
4)第三个要实现的就是数据集的长度
__len__(self):
可以直接len(os.listdir所列出的文件名的数组),就可以得到数据集的长度。
2.需要注意的问题
我在调试的时候发现
for imgs, labels in train_loader:
一直报错,查找原因,发现是该数据集中的图像存在两个问题,第一个是大小不一,第二个貌似通道个数也不一致。
大小不一
因此使用transform做了处理
transform=transforms.Compose([ transforms.Resize((320,320),interpolation=Image.BILINEAR),transforms.Grayscale(),transforms.ToTensor()])
3.代码如下:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import os
import torch
from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter("logs")
transform=transforms.Compose([ transforms.Resize((320,320),interpolation=Image.BILINEAR),transforms.Grayscale(),transforms.ToTensor()])class MyDataLoader(Dataset):def __init__(self,root_dir, label_dir):self.root_dir = root_dirself.label_dir = label_dirself.path = os.path.join(self.root_dir,self.label_dir)self.img_path = os.listdir(self.path)def __getitem__(self, item):img_name = self.img_path[item]img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)img = Image.open(img_item_path)img = transform(img)label = self.label_dirreturn img,labeldef __len__(self):return len(self.img_path)root_dir = "E:/TOOLE/slam_evo/pythonProject/data/hymenoptera_data/train"
ants_label_dir = "ants"
bees_label_dir = "bees"ants_dataset = MyDataLoader(root_dir,ants_label_dir)
bees_dataset = MyDataLoader(root_dir,bees_label_dir)
train_data = ants_dataset + bees_datasetimg0, label0 = train_data[12]
# img0.show()
img1, label1 = train_data[124]
# img1.show()
# 一次处理数据10个
BATCH_SIZE = 10
# 把数据集装载到DataLoader里
train_loader = DataLoader(train_data, shuffle=True, batch_size=BATCH_SIZE)A = len(train_loader)
num_iter = 0
for imgs, labels in train_loader:print(imgs.shape)print(labels)# print(train_data.classes)writer.add_images("ant-bees",imgs,num_iter)num_iter = num_iter +1writer.close()
用tensorboard显示,batch_size= 10,因此每次迭代有10张图像
标签为: