pytorch-训练自定义数据集实战

目录

  • 1. 步骤
  • 2. 加载数据
    • 2.1 继承Dataset
      • 2.1.1 生成name2label
      • 2.1.2 生成image path, label的文件
      • 2.1.3 __len__
      • 2.1.3 __getitem__
      • 2.1.4 数据切分为train、val、test
  • 3. 建立模型
  • 4. 训练和测试
  • 4. 完整代码

1. 步骤

  • 加载数据
  • 创建模型
  • 训练和测试
  • 迁移学习

2. 加载数据

这里以宝可梦动画图片为数据集
在这里插入图片描述
下载地址:
链接:https://pan.baidu.com/s/1TbXKNIBitXk_o-oVAAiX-A?pwd=py3r
提取码:py3r

数据集各分类情况和切分比例见下图:
在这里插入图片描述

2.1 继承Dataset

继承torch.utils.data.Dataset,实现__len__和__getitem__函数
__len__是获取所有数据集的数量
__getitem__获取数据集中指定index的image tensor和对应的分类label
实现这两个函数的思路:

  • 将数据集所有文件名加载到list中,通过len([images]),即可实现__len__
  • 生成数据集中所有的image path和label,读取并预处理image,即可实现__getitem__

2.1.1 生成name2label

数据集文件结构是pokemon\bulbasaur\00000000.png,pokemon下的每个文件夹代表一个分类,因此就可以实现下面的代码生成一个name2label

 self.name2label = {} # "sq...":0for name in sorted(os.listdir(os.path.join(root))):if not os.path.isdir(os.path.join(root, name)):continueself.name2label[name] = len(self.name2label.keys())

2.1.2 生成image path, label的文件

获取pokemon目前下所有数据文件的路径放到images中,遍历images,通过每条数据文件路径中的分类文件夹名称从name2label获取到对应的label,然后写入到文件中。
代码如下:

       if not os.path.exists(os.path.join(self.root, filename)):images = []for name in self.name2label.keys():# 'pokemon\\mewtwo\\00001.pngimages += glob.glob(os.path.join(self.root, name, '*.png'))images += glob.glob(os.path.join(self.root, name, '*.jpg'))images += glob.glob(os.path.join(self.root, name, '*.jpeg'))# 1167, 'pokemon\\bulbasaur\\00000000.png'print(len(images), images)random.shuffle(images)with open(os.path.join(self.root, filename), mode='w', newline='') as f:writer = csv.writer(f)for img in images: # 'pokemon\\bulbasaur\\00000000.png'name = img.split(os.sep)[-2]label = self.name2label[name]# 'pokemon\\bulbasaur\\00000000.png', 0writer.writerow([img, label])print('writen into csv file:', filename)

2.1.3 len

    def __len__(self):return len(self.images)

2.1.3 getitem

预处理包括resize、randomRotation、ToTensor、Normalize等

def __getitem__(self, idx):# idx~[0~len(images)]# self.images, self.labels# img: 'pokemon\\bulbasaur\\00000000.png'# label: 0img, label = self.images[idx], self.labels[idx]tf = transforms.Compose([lambda x:Image.open(x).convert('RGB'), # string path= > image datatransforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),transforms.RandomRotation(15),transforms.CenterCrop(self.resize),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])img = tf(img)label = torch.tensor(label)return img, label

2.1.4 数据切分为train、val、test

数据切分比例6:2:2

class Pokemon(Dataset):def __init__(self, root, resize, mode):super(Pokemon, self).__init__()self.root = rootself.resize = resizeself.name2label = {} # "sq...":0for name in sorted(os.listdir(os.path.join(root))):if not os.path.isdir(os.path.join(root, name)):continueself.name2label[name] = len(self.name2label.keys())# print(self.name2label)# image, labelself.images, self.labels = self.load_csv('images.csv')if mode=='train': # 60%self.images = self.images[:int(0.6*len(self.images))]self.labels = self.labels[:int(0.6*len(self.labels))]elif mode=='val': # 20% = 60%->80%self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]else: # 20% = 80%->100%self.images = self.images[int(0.8*len(self.images)):]self.labels = self.labels[int(0.8*len(self.labels)):]

3. 建立模型

使用前面实现的Resnet18https://editor.csdn.net/md/?articleId=140032483

4. 训练和测试

  • 数据加载
  • 实例化模型、优化器、loss函数
  • epoch循环、模型数据输入、计算loss、backward、优化器迭代
  • validation模型、保存模型
  • test模型
    代码:
def evalute(model, loader):model.eval()correct = 0total = len(loader.dataset)for x,y in loader:x,y = x.to(device), y.to(device)with torch.no_grad():logits = model(x)pred = logits.argmax(dim=1)correct += torch.eq(pred, y).sum().float().item()return correct / totaldef ResNet18():return ResNet(ResBlk, [2, 2, 2, 2], 10)def main():model = ResNet18().to(device)optimizer = optim.Adam(model.parameters(), lr=lr)criteon = nn.CrossEntropyLoss()best_acc, best_epoch = 0, 0global_step = 0viz.line([0], [-1], win='loss', opts=dict(title='loss'))viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))for epoch in range(epochs):for step, (x,y) in enumerate(train_loader):# x: [b, 3, 224, 224], y: [b]x, y = x.to(device), y.to(device)model.train()logits = model(x)loss = criteon(logits, y)optimizer.zero_grad()loss.backward()optimizer.step()viz.line([loss.item()], [global_step], win='loss', update='append')global_step += 1if epoch % 1 == 0:val_acc = evalute(model, val_loader)if val_acc> best_acc:best_epoch = epochbest_acc = val_acctorch.save(model.state_dict(), 'best.mdl')viz.line([val_acc], [global_step], win='val_acc', update='append')print('best acc:', best_acc, 'best epoch:', best_epoch)model.load_state_dict(torch.load('best.mdl'))print('loaded from ckpt!')test_acc = evalute(model, test_loader)print('test acc:', test_acc)

4. 完整代码

train.py

import  torch
from    torch import optim, nn
import  visdom
import  torchvision
from    torch.utils.data import DataLoaderfrom    pokemon import Pokemon
from    resnet import ResNet18batchsz = 32
lr = 1e-3
epochs = 10device = torch.device('cuda')
torch.manual_seed(1234)train_db = Pokemon('pokemon', 224, mode='train')
val_db = Pokemon('pokemon', 224, mode='val')
test_db = Pokemon('pokemon', 224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)viz = visdom.Visdom()def evalute(model, loader):model.eval()correct = 0total = len(loader.dataset)for x,y in loader:x,y = x.to(device), y.to(device)with torch.no_grad():logits = model(x)pred = logits.argmax(dim=1)correct += torch.eq(pred, y).sum().float().item()return correct / totaldef main():model = ResNet18().to(device)optimizer = optim.Adam(model.parameters(), lr=lr)criteon = nn.CrossEntropyLoss()best_acc, best_epoch = 0, 0global_step = 0viz.line([0], [-1], win='loss', opts=dict(title='loss'))viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))for epoch in range(epochs):for step, (x,y) in enumerate(train_loader):# x: [b, 3, 224, 224], y: [b]x, y = x.to(device), y.to(device)model.train()logits = model(x)loss = criteon(logits, y)optimizer.zero_grad()loss.backward()optimizer.step()viz.line([loss.item()], [global_step], win='loss', update='append')global_step += 1if epoch % 1 == 0:val_acc = evalute(model, val_loader)if val_acc> best_acc:best_epoch = epochbest_acc = val_acctorch.save(model.state_dict(), 'best.mdl')viz.line([val_acc], [global_step], win='val_acc', update='append')print('best acc:', best_acc, 'best epoch:', best_epoch)model.load_state_dict(torch.load('best.mdl'))print('loaded from ckpt!')test_acc = evalute(model, test_loader)print('test acc:', test_acc)if __name__ == '__main__':main()

pokemon.py

import  torch
import  os, glob
import  random, csvfrom    torch.utils.data import Dataset, DataLoaderfrom    torchvision import transforms
from    PIL import Imageclass Pokemon(Dataset):def __init__(self, root, resize, mode):super(Pokemon, self).__init__()self.root = rootself.resize = resizeself.name2label = {} # "sq...":0for name in sorted(os.listdir(os.path.join(root))):if not os.path.isdir(os.path.join(root, name)):continueself.name2label[name] = len(self.name2label.keys())# print(self.name2label)# image, labelself.images, self.labels = self.load_csv('images.csv')if mode=='train': # 60%self.images = self.images[:int(0.6*len(self.images))]self.labels = self.labels[:int(0.6*len(self.labels))]elif mode=='val': # 20% = 60%->80%self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]else: # 20% = 80%->100%self.images = self.images[int(0.8*len(self.images)):]self.labels = self.labels[int(0.8*len(self.labels)):]def load_csv(self, filename):if not os.path.exists(os.path.join(self.root, filename)):images = []for name in self.name2label.keys():# 'pokemon\\mewtwo\\00001.pngimages += glob.glob(os.path.join(self.root, name, '*.png'))images += glob.glob(os.path.join(self.root, name, '*.jpg'))images += glob.glob(os.path.join(self.root, name, '*.jpeg'))# 1167, 'pokemon\\bulbasaur\\00000000.png'print(len(images), images)random.shuffle(images)with open(os.path.join(self.root, filename), mode='w', newline='') as f:writer = csv.writer(f)for img in images: # 'pokemon\\bulbasaur\\00000000.png'name = img.split(os.sep)[-2]label = self.name2label[name]# 'pokemon\\bulbasaur\\00000000.png', 0writer.writerow([img, label])print('writen into csv file:', filename)# read from csv fileimages, labels = [], []with open(os.path.join(self.root, filename)) as f:reader = csv.reader(f)for row in reader:# 'pokemon\\bulbasaur\\00000000.png', 0img, label = rowlabel = int(label)images.append(img)labels.append(label)assert len(images) == len(labels)return images, labelsdef __len__(self):return len(self.images)def denormalize(self, x_hat):mean = [0.485, 0.456, 0.406]std = [0.229, 0.224, 0.225]# x_hat = (x-mean)/std# x = x_hat*std = mean# x: [c, h, w]# mean: [3] => [3, 1, 1]mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)std = torch.tensor(std).unsqueeze(1).unsqueeze(1)# print(mean.shape, std.shape)x = x_hat * std + meanreturn xdef __getitem__(self, idx):# idx~[0~len(images)]# self.images, self.labels# img: 'pokemon\\bulbasaur\\00000000.png'# label: 0img, label = self.images[idx], self.labels[idx]tf = transforms.Compose([lambda x:Image.open(x).convert('RGB'), # string path= > image datatransforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),transforms.RandomRotation(15),transforms.CenterCrop(self.resize),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])img = tf(img)label = torch.tensor(label)return img, labeldef main():import  visdomimport  timeimport  torchvisionviz = visdom.Visdom()# tf = transforms.Compose([#                 transforms.Resize((64,64)),#                 transforms.ToTensor(),# ])# db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)# loader = DataLoader(db, batch_size=32, shuffle=True)## print(db.class_to_idx)## for x,y in loader:#     viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))#     viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))##     time.sleep(10)db = Pokemon('pokemon', 64, 'train')x,y = next(iter(db))print('sample:', x.shape, y.shape, y)viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)for x,y in loader:viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))time.sleep(10)if __name__ == '__main__':main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/385422.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Minos 多主机分布式 docker-compose 集群部署

参考 docker-compose搭建多主机分布式minio - 会bk的鱼 - 博客园 (cnblogs.com) 【运维】docker-compose安装minio集群-CSDN博客 Minio 是个基于 Golang 编写的开源对象存储套件,虽然轻量,却拥有着不错的性能 中文地址:MinIO | 用于AI的S3 …

自学JavaScript(放假在家自学第一天)

目录 JavaScript介绍分为以下几点 1.1 JavaScript 是什么 1.2JavaScript书写位置 1.3 Javascript注释 1.4 Javascript结束符 1.5 Javascript输入输出语法 JavaScript(是什么?) 是一种运行在客户端(浏览器)的编程语言,实现人机交互效果。 2.作用(做什么?)网…

PCL-基于超体聚类的LCCP点云分割

目录 一、LCCP方法二、代码实现三、实验结果四、总结五、相关链接 一、LCCP方法 LCCP指的是Local Convexity-Constrained Patch,即局部凸约束补丁的意思。LCCP方法的基本思想是在图像中找到局部区域内的凸结构,并将这些结构用于分割图像或提取特征。这种…

入门 PyQt6 看过来(案例)13~ 制作一个颜色调节器

本文给大家带来一个利用pyqt制作的颜色调节器,通过拨动滚动条或者旋钮就可以调整rgb三色进行颜色的微调,效果如下: 本文实现的是不同的UI设计,实现的相同的功能,我们先分析以下思路: 首先进行UI页面设计分析…

SSL/TLS和SSL VPN

1、SSL/TLS SSL安全套接字层:是一种加密协议,用于在网络通信中建立安全连接。它在应用层和传输层(TCP/IP)之间提供数据加密、服务器身份验证以及信息完整性验证 SSL只保护TCP流量,不保护UDP协议 TLS:传输层…

VulnHub:cengbox1

靶机下载地址,下载完成后,用VirtualBox打开靶机并修改网络为桥接即可搭建成功。 信息收集 主机发现和端口扫描 扫描攻击机(192.168.31.218)同网段存活主机确认目标机ip,并对目标机进行全面扫描。 nmap 192.168.31.…

【VS2019安装+QT配置】

【VS2019安装QT配置】 1. 前言2. 下载visual studio20193. visual studio2019安装4. 环境配置4.1 系统环境变量配置4.2 qt插件开发 5. Visual Studio导入QT项目6. 总结 1. 前言 前期安装了qt,发现creator编辑器并不好用,一点都不时髦。在李大师的指导下&…

[网鼎杯 2020 朱雀组]Nmap(详细解读版)

这道题考察nmap的一些用法,以及escapeshellarg和escapeshellcmd两个函数的绕过,可以看这里PHP escapeshellarg()escapeshellcmd() 之殇 (seebug.org) 两种解题方法: 第一种通过nmap的-iL参数读取扫描一个文件到指定文件中第二种是利用nmap的参数写入we…

昇思25天学习打卡营第1天|快速入门-构建基于MNIST数据集的手写数字识别模型

非常感谢华为昇思大模型平台和CSDN邀请体验昇思大模型!从今天起,我将以打卡的方式,结合原文搬运和个人思考,分享25天的学习内容与成果。为了提升文章质量和阅读体验,我会将思考部分放在最后,供大家探索讨论…

java-数据结构与算法-02-数据结构-05-栈

文章目录 1. 栈1. 概述2. 链表实现3. 数组实现4. 应用 2. 习题E01. 有效的括号-Leetcode 20E02. 后缀表达式求值-Leetcode 120E03. 中缀表达式转后缀E04. 双栈模拟队列-Leetcode 232E05. 单队列模拟栈-Leetcode 225 1. 栈 1. 概述 计算机科学中,stack 是一种线性的…

[python游戏开发]用Python代码制作中国象棋游戏,适合新手小白练手

Pygame 做的中国象棋,一直以来喜欢下象棋,写了 python 就拿来做一个试试,水平有限,希望源码能帮助大家更好的学习 python。总共分为四个文件,chinachess.py 为主文件,constants.py 数据常量,pie…

新版海螺影视主题模板M3.1全解密版本多功能苹果CMSv10后台自适应主题

苹果CMS2022新版海螺影视主题M3.1版本,这个主题我挺喜欢的,之前也有朋友给我提供过原版主题,一直想要破解但是后来找了几个SG11解密的大哥都表示解密需要大几百大洋,所以一直被搁置了。这个版本是完全解密的,无需SG11加…

前端模块化CommonJS、AMD、CMD、ES6

在前端开发中,模块化是一种重要的代码组织方式,它有助于将复杂的代码拆分成可管理的小块,提高代码的可维护性和可重用性。CommonJS、AMD(异步模块定义)和CMD(通用模块定义)是三种不同的模块规范…

1、hadoop环境搭建

1、环境配置 ip(/etc/sysconfig/network-scripts) # 网卡1 DEVICEeht0 TYPEEthernet ONBOOTyes NM_CONTROLLEDyes BOOTPROTOstatic IPADDR192.168.59.11 GATEWAY192.168.59.1 NETMASK 255.255.255.0 # 网卡2 DEVICEeht0 TYPEEthernet ONBOOTyes NM_CONTROLLEDyes BOOTPROTOdh…

【React1】React概述、基本使用、脚手架、JSX、组件

文章目录 1. React基础1.1 React 概述1.1.1 什么是React1.1.2 React 的特点声明式基于组件学习一次,随处使用1.2 React 的基本使用1.2.1 React的安装1.2.2 React的使用1.2.3 React常用方法说明React.createElement()ReactDOM.render()1.3 React 脚手架的使用1.3.1 React 脚手架…

基于tkinter的学生信息管理系统之登录界面和主界面菜单设计

目录 一、tkinter的介绍 二、登陆界面的设计 1、登陆界面完整代码 2、部分代码讲解 3、登录的数据模型设计 4、效果展示 三、学生主界面菜单设计 1、学生主界面菜单设计完整代码 2、 部分代码讲解 3、效果展示 四、数据库的模型设计 欢迎大家进来学习和支持&#xff01…

从食堂采购系统源码到成品:打造供应链采购管理平台实战详解

本篇文章,笔者将详细介绍如何从食堂采购系统的源码开始,逐步打造一个完备的供应链采购管理平台,帮助企业实现采购流程的智能化和高效化。 一、需求分析与规划 一般来说,食堂采购系统需要具备以下基本功能: 1.供应商…

第15周 Zookeeper分布式锁与变种多级缓存

1. Zookeeper介绍 1.1 介绍 1.2 应用场景简介 1.3 zookeeper工作原理 1.4 zookeeper特点

AI的欺骗游戏:揭示多模态大型语言模型的易受骗性

人工智能咨询培训老师叶梓 转载标明出处 多模态大型语言模型(MLLMs)在处理包含欺骗性信息的提示时容易生成幻觉式响应。尤其是在生成长响应时,仍然是一个未被充分研究的问题。来自 Apple 公司的研究团队提出了MAD-Bench,一个包含8…

DLMS/COSEM中公开密钥算法的使用_椭圆曲线加密法

1.概述 椭圆曲线密码涉及有限域上的椭圆曲线上的算术运算。椭圆曲线可以定义在任何数字域上(实数、整数、复数),但在密码学中,椭圆曲线最常用于有限素数域。 素数域上的椭圆曲线由一组实数(x, y)组成,满足以下等式: 方程的所有解的集合构成…