目录
- 1. 步骤
- 2. 加载数据
- 2.1 继承Dataset
- 2.1.1 生成name2label
- 2.1.2 生成image path, label的文件
- 2.1.3 __len__
- 2.1.3 __getitem__
- 2.1.4 数据切分为train、val、test
- 3. 建立模型
- 4. 训练和测试
- 4. 完整代码
1. 步骤
- 加载数据
- 创建模型
- 训练和测试
- 迁移学习
2. 加载数据
这里以宝可梦动画图片为数据集
下载地址:
链接:https://pan.baidu.com/s/1TbXKNIBitXk_o-oVAAiX-A?pwd=py3r
提取码:py3r
数据集各分类情况和切分比例见下图:
2.1 继承Dataset
继承torch.utils.data.Dataset,实现__len__和__getitem__函数
__len__返回数据集中样本的总数
__getitem__返回数据集中指定index对应的image tensor和分类label
实现这两个函数的思路:
- 将数据集所有文件名加载到list(images)中,通过len(images)即可实现__len__
- 生成数据集中所有的image path和label,读取并预处理image,即可实现__getitem__
2.1.1 生成name2label
数据集文件结构是pokemon\bulbasaur\00000000.png,pokemon下的每个文件夹代表一个分类,因此就可以实现下面的代码生成一个name2label
# Map every class-folder name under `root` to an integer label.
# The names are visited in sorted order, so the mapping is the same on every run.
self.name2label = {}  # e.g. {"bulbasaur": 0, ...}
for name in sorted(os.listdir(root)):
    # Only sub-directories are classes; skip plain files that live in `root`.
    if not os.path.isdir(os.path.join(root, name)):
        continue
    # The next free label is simply the number of classes registered so far.
    self.name2label[name] = len(self.name2label)
2.1.2 生成image path, label的文件
获取pokemon目录下所有数据文件的路径放到images中,遍历images,通过每条数据文件路径中的分类文件夹名称从name2label获取到对应的label,然后写入到文件中。
代码如下:
# Build the "<image path>,<label>" index file only when it does not exist yet.
# Later runs reuse the file, so the shuffled order stays the same between runs.
csv_path = os.path.join(self.root, filename)
if not os.path.exists(csv_path):
    images = []
    for name in self.name2label:
        # e.g. 'pokemon\\mewtwo\\00001.png'
        for pattern in ('*.png', '*.jpg', '*.jpeg'):
            images += glob.glob(os.path.join(self.root, name, pattern))

    # e.g. 1167, ['pokemon\\bulbasaur\\00000000.png', ...]
    print(len(images), images)

    random.shuffle(images)
    # newline='' is what the csv module expects for files it writes.
    with open(csv_path, mode='w', newline='') as f:
        writer = csv.writer(f)
        for img in images:  # e.g. 'pokemon\\bulbasaur\\00000000.png'
            # The parent folder of an image is its class name.
            name = os.path.basename(os.path.dirname(img))
            label = self.name2label[name]
            # row e.g. 'pokemon\\bulbasaur\\00000000.png', 0
            writer.writerow([img, label])
        print('writen into csv file:', filename)
2.1.3 __len__
def __len__(self):
    """Return the number of entries held in ``self.images``."""
    count = len(self.images)
    return count
2.1.3 __getitem__
预处理包括Resize、RandomRotation、CenterCrop、ToTensor、Normalize等
def __getitem__(self, idx):
    """Load and preprocess the sample stored at position ``idx``.

    Args:
        idx: Position in ``self.images`` / ``self.labels``.

    Returns:
        A ``(img, label)`` pair: the output of the transform pipeline below
        and the label wrapped with ``torch.tensor``.
    """
    # path e.g. 'pokemon\\bulbasaur\\00000000.png', label e.g. 0
    path, label = self.images[idx], self.labels[idx]

    enlarged = int(self.resize * 1.25)
    tf = transforms.Compose([
        lambda x: Image.open(x).convert('RGB'),  # path string => RGB image
        # Resize to 1.25x the target so the centre crop below still yields a
        # `resize` x `resize` image after the rotation.
        transforms.Resize((enlarged, enlarged)),
        # NOTE(review): this random rotation runs for every sample, whatever
        # split the dataset represents, so repeated reads of one index differ.
        transforms.RandomRotation(15),
        transforms.CenterCrop(self.resize),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    img = tf(path)
    label = torch.tensor(label)
    return img, label
2.1.4 数据切分为train、val、test
数据切分比例6:2:2
class Pokemon(Dataset):
    """Folder-per-class image dataset cut into train / val / test parts.

    Only the constructor is shown here. It relies on a ``load_csv`` method
    that returns two parallel lists ``(images, labels)``.
    """

    def __init__(self, root, resize, mode):
        """Index the class folders and keep the slice selected by ``mode``.

        Args:
            root: Directory whose sub-directories are the classes.
            resize: Stored as ``self.resize`` for later use.
            mode: ``'train'`` keeps the first 60%, ``'val'`` the next 20%;
                any other value keeps the last 20%.
        """
        super(Pokemon, self).__init__()

        self.root = root
        self.resize = resize

        # Sorted order makes the name -> label mapping identical on every run.
        self.name2label = {}  # e.g. {"bulbasaur": 0, ...}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)

        # image paths and their labels, in the order stored in the CSV file
        self.images, self.labels = self.load_csv('images.csv')

        # Compute both cut points once, from the full list length, so the
        # three splits are disjoint and together cover every sample.
        total = len(self.images)
        train_end = int(0.6 * total)
        val_end = int(0.8 * total)

        if mode == 'train':  # 60%
            keep = slice(None, train_end)
        elif mode == 'val':  # 20% = 60% -> 80%
            keep = slice(train_end, val_end)
        else:  # 20% = 80% -> 100%
            keep = slice(val_end, None)

        self.images = self.images[keep]
        self.labels = self.labels[keep]
3. 建立模型
使用前面实现的ResNet18,参见:https://editor.csdn.net/md/?articleId=140032483
4. 训练和测试
- 数据加载
- 实例化模型、优化器、loss函数
- epoch循环、模型数据输入、计算loss、backward、优化器迭代
- validation模型、保存模型
- test模型
代码:
def evalute(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    Args:
        model: Module mapping a batch ``x`` to logits of shape [b, classes].
        loader: Iterable of ``(x, y)`` batches exposing ``loader.dataset``.

    Returns:
        Fraction of samples whose arg-max prediction equals the label, or
        ``0.0`` when the dataset is empty.
    """
    model.eval()
    total = len(loader.dataset)
    if total == 0:
        return 0.0

    # Run on whatever device the model lives on instead of a module global.
    device = next(model.parameters()).device

    correct = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += torch.eq(pred, y).sum().item()

    return correct / total


def ResNet18():
    """Build the ResNet-18 variant used by this tutorial."""
    # NOTE(review): the trailing 10 is presumably the number of output
    # classes - confirm against the ResNet definition.
    return ResNet(ResBlk, [2, 2, 2, 2], 10)


def main():
    """Train, keep the best checkpoint on validation, then test it.

    Uses the module-level ``device``, ``lr``, ``epochs``, ``viz`` and the
    three data loaders.
    """
    model = ResNet18().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    # Start below any reachable accuracy so the first validation pass always
    # writes a checkpoint, even if accuracy never rises above 0.
    best_acc, best_epoch = -1.0, -1
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        # evalute() switches the model to eval mode, so switch back here.
        model.train()
        for x, y in train_loader:
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)

            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        # Validate after every epoch and remember the best weights.
        val_acc = evalute(model, val_loader)
        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(), 'best.mdl')
        viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)

    # Only load a checkpoint that this run actually wrote (epochs may be 0).
    if best_epoch >= 0:
        model.load_state_dict(torch.load('best.mdl', map_location=device))
        print('loaded from ckpt!')

    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)
4. 完整代码
train.py
import torch
from torch import optim, nn
import visdom
import torchvision
from torch.utils.data import DataLoader

from pokemon import Pokemon
from resnet import ResNet18

batchsz = 32
lr = 1e-3
epochs = 10

# Fall back to the CPU so the script still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(1234)

train_db = Pokemon('pokemon', 224, mode='train')
val_db = Pokemon('pokemon', 224, mode='val')
test_db = Pokemon('pokemon', 224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                          num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)

viz = visdom.Visdom()


def evalute(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    Args:
        model: Module mapping a batch ``x`` to logits of shape [b, classes].
        loader: Iterable of ``(x, y)`` batches exposing ``loader.dataset``.

    Returns:
        Fraction of samples whose arg-max prediction equals the label, or
        ``0.0`` when the dataset is empty.
    """
    model.eval()
    total = len(loader.dataset)
    if total == 0:
        return 0.0

    # Run on whatever device the model lives on.
    model_device = next(model.parameters()).device

    correct = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(model_device), y.to(model_device)
            pred = model(x).argmax(dim=1)
            correct += torch.eq(pred, y).sum().item()

    return correct / total


def main():
    """Train ResNet18, keep the best checkpoint on validation, then test it."""
    model = ResNet18().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    # Start below any reachable accuracy so the first validation pass always
    # writes a checkpoint, even if accuracy never rises above 0.
    best_acc, best_epoch = -1.0, -1
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        # evalute() switches the model to eval mode, so switch back here.
        model.train()
        for x, y in train_loader:
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)

            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        # Validate after every epoch and remember the best weights.
        val_acc = evalute(model, val_loader)
        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(), 'best.mdl')
        viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)

    # Only load a checkpoint that this run actually wrote (epochs may be 0).
    if best_epoch >= 0:
        model.load_state_dict(torch.load('best.mdl', map_location=device))
        print('loaded from ckpt!')

    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)


if __name__ == '__main__':
    main()
pokemon.py
import torch
import os, glob
import random, csv

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image


class Pokemon(Dataset):
    """Folder-per-class image dataset with a fixed 60/20/20 split.

    Every sub-directory of ``root`` is one class. On first use an index file
    ``images.csv`` (shuffled "<path>,<label>" rows) is written into ``root``.
    Later instances read that file, so train / val / test always see the
    same, disjoint samples.
    """

    # Per-channel statistics shared by Normalize and denormalize (the values
    # commonly used for ImageNet-pretrained models).
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, root, resize, mode):
        """Index the class folders and keep the slice selected by ``mode``.

        Args:
            root: Directory whose sub-directories are the classes.
            resize: Side length, in pixels, of the square output image.
            mode: ``'train'`` keeps the first 60%, ``'val'`` the next 20%;
                any other value keeps the last 20%.
        """
        super(Pokemon, self).__init__()

        self.root = root
        self.resize = resize
        self.mode = mode

        # Sorted order makes the name -> label mapping identical on every run.
        self.name2label = {}  # e.g. {"bulbasaur": 0, ...}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)

        # image paths and their labels, in the order stored in the CSV file
        self.images, self.labels = self.load_csv('images.csv')

        # Compute both cut points once, from the full list length.
        total = len(self.images)
        train_end = int(0.6 * total)
        val_end = int(0.8 * total)

        if mode == 'train':  # 60%
            keep = slice(None, train_end)
        elif mode == 'val':  # 20% = 60% -> 80%
            keep = slice(train_end, val_end)
        else:  # 20% = 80% -> 100%
            keep = slice(val_end, None)

        self.images = self.images[keep]
        self.labels = self.labels[keep]

        # Built once here rather than on every __getitem__ call. It holds no
        # lambda, so the dataset stays picklable for DataLoader workers.
        self.transform = self._build_transform()

    def _build_transform(self):
        """Return the preprocessing pipeline for this split."""
        enlarged = int(self.resize * 1.25)
        # Resize to 1.25x the target so the centre crop below still yields a
        # `resize` x `resize` image after the rotation.
        steps = [transforms.Resize((enlarged, enlarged))]
        if self.mode == 'train':
            # Random augmentation belongs to training only; val / test must
            # be deterministic so their accuracies are comparable.
            steps.append(transforms.RandomRotation(15))
        steps += [
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ]
        return transforms.Compose(steps)

    def load_csv(self, filename):
        """Create the index file ``filename`` in ``root`` if needed, then read it.

        Args:
            filename: Name of the CSV file inside ``self.root``.

        Returns:
            ``(images, labels)``: parallel lists of image paths and int labels.
        """
        csv_path = os.path.join(self.root, filename)

        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                # e.g. 'pokemon\\mewtwo\\00001.png'
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))

            # e.g. 1167, ['pokemon\\bulbasaur\\00000000.png', ...]
            print(len(images), images)

            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:  # e.g. 'pokemon\\bulbasaur\\00000000.png'
                    # The parent folder of an image is its class name.
                    name = os.path.basename(os.path.dirname(img))
                    label = self.name2label[name]
                    # row e.g. 'pokemon\\bulbasaur\\00000000.png', 0
                    writer.writerow([img, label])
                print('writen into csv file:', filename)

        # read from csv file (newline='' as the csv module recommends)
        images, labels = [], []
        with open(csv_path, newline='') as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))

        return images, labels

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.images)

    def denormalize(self, x_hat):
        """Undo ``transforms.Normalize`` so a tensor can be displayed.

        Args:
            x_hat: Normalized image ``[c, h, w]`` or batch ``[b, c, h, w]``.

        Returns:
            ``x_hat * std + mean``, computed per channel.
        """
        # x_hat = (x - mean) / std   =>   x = x_hat * std + mean
        # mean / std: [3] => [3, 1, 1] so they broadcast over h and w; they
        # are created on x_hat's device so GPU tensors work too.
        mean = torch.tensor(self._MEAN, device=x_hat.device).view(-1, 1, 1)
        std = torch.tensor(self._STD, device=x_hat.device).view(-1, 1, 1)
        return x_hat * std + mean

    def __getitem__(self, idx):
        """Load and preprocess the sample stored at position ``idx``.

        Args:
            idx: Position in ``self.images`` / ``self.labels``.

        Returns:
            ``(img, label)``: the normalized image tensor and the label
            wrapped with ``torch.tensor``.
        """
        # path e.g. 'pokemon\\bulbasaur\\00000000.png', label e.g. 0
        path, label = self.images[idx], self.labels[idx]

        # convert() returns a new, fully loaded image, so the file handle can
        # be closed right away.
        with Image.open(path) as raw:
            img = raw.convert('RGB')

        img = self.transform(img)
        label = torch.tensor(label)
        return img, label


def main():
    """Visual smoke test: show one sample and then whole batches in visdom."""
    import visdom
    import time

    viz = visdom.Visdom()

    # Alternative: torchvision.datasets.ImageFolder(root='pokemon',
    # transform=...) loads the same folder-per-class layout, and its
    # `class_to_idx` plays the role of `name2label`.

    db = Pokemon('pokemon', 64, 'train')

    x, y = db[0]
    print('sample:', x.shape, y.shape, y)
    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))

    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
    for x, y in loader:
        viz.images(db.denormalize(x), nrow=8, win='batch',
                   opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
        # Leave each batch on screen for a while before showing the next.
        time.sleep(10)


if __name__ == '__main__':
    main()