[PyTorch][chapter 50][创建自己的数据集 2]

前言:

      这里主要针对图像数据进行预处理.定义了一个 class Pokemon(Dataset) 类,实现

图像数据集加载,划分的基本方法.
 


目录:

  1.      整体框架
  2.       __init__ 
  3.      load_images
  4.       save_csv
  5.      divide_data
  6.      __len__
  7.      denormalize
  8.     __getitem__
  9.    main
  10.    ImageFolder 

     


一  整体框架

       我们需要创建一个自定义的数据集类,该类必须继承自Dataset类,

      重点实现以下三个方法:

      __init__

   __len__()

       __getitem__()


二  __init__ 

      实现了图像数据集的加载

      根据mode 进行划分

    def __init__(self, root, resize, mode,fileName):#初始化函数super(Pokemon, self).__init__()self.root = rootself.resize = resizeself.name2label ={}#遍历目录path = os.path.join(root)#用子目录文件夹名字作为分类keyfor name in sorted(os.listdir(path)):subDir = os.path.join(root, name)if not os.path.isdir(subDir):continueelse:self.name2label[name] = len(self.name2label.keys())csv_path = os.path.join(self.root, fileName)print("\n csv_path:  ",csv_path)if not os.path.exists(csv_path):images = self.load_images()self.save_csv(fileName, images)self.images, self.labels = self.load_csv(fileName)self.divide_data(mode)


三 load_images

    加载指定目录下面的图片,

   把图片路径保存到列表里面

  def load_images(self):images =[]for name in self.name2label.keys():#pokeon\\newtwoo\\00001.png#返回所有匹配的文件路径列表。它只有一个参数pathname,定义了文件路径匹配规则,这里可以是绝对路径,也可以是相对路径。下面是使用glob.glob的例子:pngPath = os.path.join(self.root, name,'*.png')jpgPath = os.path.join(self.root, name,'*.jpg')jpegPath = os.path.join(self.root, name,'*.jpeg')png = glob.glob(pngPath)jpg =glob.glob(jpgPath)jpeg = glob.glob(jpegPath)images +=jpgimages +=jpegimages +=pngprint("\n images ",len(images))random.shuffle(images)return images

四    save_csv

       图片路径,标签保存到csv 文件里面

   

       #image, labeldef save_csv(self, fileName, images):path = os.path.join(self.root, fileName)csvfile = open(path,mode='w',newline='')writer = csv.writer(csvfile)for img in images:name = img.split(os.sep)[-2]label = self.name2label[name]writer.writerow([img, label])csvfile.close()


四  load_csv

    加载 csv 文件

    def load_csv(self, fileName):path = os.path.join(self.root, fileName)csvfile = open(path,mode='r',newline='')reader = csv.reader(csvfile)images =[]labels =[]for row in reader:img, label = rowlabel = int(label)images.append(img)labels.append(label)m = len(images)n = len(labels)print("\n number images: %d number labels: %d"%(m,n))return  images,labels

五  divide_data

   数据集划分

    训练集: 60%

    验证集: 20%

    测试机:20%

    def divide_data(self,mode):N = len(self.images)if 'train' == mode: #0->60%start = 0end = int(0.6*N)elif 'val' == mode:#60%->80%start = int(0.6*N)end = int(0.8*N)else:#80%->100%start = int(0.8*N)end = Nself.images = self.images[start:end]self.labels = self.labels[start:end]m = len(self.images )print("\n number divide images: %d "%(m))

六      __len__

    返回数据集大小

    def __len__(self):#总的数据N = len(self.images)return N

七  denormalize

   图像数据 标准后,当需要显示原图片的时候,需要反标准化

   def denormalize(self,x_hat):#x_hat =(x-mean)/std#x = x_hat*std+mean#x: [c,h,w]#mean: [3]=>[3,1,1]mean=[0.485, 0.456, 0.406]std=[0.229, 0.224, 0.225]mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)std =  torch.tensor(std).unsqueeze(1).unsqueeze(1)x =x_hat*std+meanreturn x

八  __getitem__

   根据指定的索引获取对应的图片,以及标签值

        def __getitem__(self, index):#返回当前index 对应的图片数据#self.images, self.labels#idx ~[0,N]img_path = self.images[index] #图片路径label = self.labels[index] #图片标签#print("\n img_path",img_path)tf = transforms.Compose([  lambda x:Image.open(x).convert('RGB'),transforms.Resize((int(self.resize*1.25) , int(self.resize*1.25))), transforms.RandomRotation(15), transforms.ToTensor(),transforms.CenterCrop(self.resize),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])img =  tf(img_path)label = torch.tensor(label)#print("\n index ",index, "\t img ",img.shape,"\t label ",label)return img, label

九  main 

1 先定义一个class Pokemon(Dataset): 类,并实现上面的方法

2    数据集的迭代加载,以及通过visdom 工具加载显示

def main():root ='pokemon'resize =224mode = 'test' #数据集分为三种 tain,val,testcsvfile ='data.csv'db = Pokemon(root, resize, mode,csvfile)viz = visdom.Visdom()# datetime转字符串time.time() #显示当前的时间戳curtime = time.strftime('%H:%M:%S') #结构化输出当前的时间BATCH_SIZE = 32loader = DataLoader(dataset = db, batch_size = BATCH_SIZE,shuffle = True)for step, (batchX, batchY) in enumerate(loader):print( '| Step: ', step, '| batch x: ',batchX.shape, '| batch y: ', batchY.shape)viz.images(db.denormalize(batchX),nrow=8, win='batchX',opts=dict(title=curtime))viz.text(str(batchY.numpy()),win='batchY',opts=dict(title='label'))time.sleep(10)if __name__ == "__main__" :main()

十  ImageFolder 

  自己的图像数据集如果有规律的话,可以直接用PyTorch API 函数实现 Pokemon

类的功能

from torchvision.datasets import ImageFolder
from torchvision import transformsimgMean =[0.485, 0.456, 0.406]
imgStd = [0.229, 0.224, 0.225]
normalize=transforms.Normalize(mean=imgMean,std=imgStd)
transform=transforms.Compose([transforms.RandomCrop(180),transforms.RandomHorizontalFlip(),transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]normalize
])dataset=ImageFolder('./data/train',transform=transform)

参考:

torchvision.datasets.ImageFolder使用详解_☞源仔的博客-CSDN博客

课时102 自定义数据集实战-5_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/90339.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构——堆

数据结构——堆 堆堆简介堆的分类 二叉堆过程插入操作 删除操作向下调整: 增加某个点的权值实现参考代码:建堆方法一:使用 decreasekey(即,向上调整)方法二:使用向下调整 应用对顶堆 其他&#…

dirsearch_暴力扫描网页结构

python3 dirsearch 暴力扫描网页结构(包括网页中的目录和文件) 下载地址:https://gitee.com/xiaozhu2022/dirsearch/repository/archive/master.zip 下载解压后,在dirsearch.py文件窗口,打开终端(任务栏…

深入理解索引B+树的基本原理

目录 1. 引言 2. 为什么要使用索引? 3. 索引的概述 4. 索引的优点是什么? 4.1 降低数据库的IO成本,提高数据查找效率 4.2 保证数据库每一行数据的唯一性 4.3 加速表与表之间的连接 4.4 减少查询中分组与排序的执行时间 5. 索引的缺点…

[足式机器人]Part3机构运动微分几何学分析与综合Ch03-1 空间约束曲线与约束曲面微分几何学——【读书笔记】

本文仅供学习使用 本文参考: 《机构运动微分几何学分析与综合》-王德伦、汪伟 《微分几何》吴大任 Ch01-4 平面运动微分几何学 3.1 空间曲线微分几何学概述3.1.1 矢量表示3.1.2 Frenet标架 连杆机构中的连杆与连架杆构成运动副,该运动副元素的特征点或特…

【Stable Diffusion】雨天、湿身

一、Models 1.1、Wet Clothes (Clothing Style) [LoHA] WECL SEE-THROUGH WET WET HAIR BIKINI OR SWIMSUIT UNDER CLOTHES NO BRA BRA VISIBLE THROUGH CLOTHES MISC SHIRTS MISC CLOTHES1.2、Rain 雨 Multiply Style rain style1.3、Wet T-Shirt LORA <lora:wetshirt:…

5.1 web浏览安全

数据参考&#xff1a;CISP官方 目录 Web应用基础浏览器所面临的安全威胁养成良好的Web浏览安全意识如何安全使用浏览器 一、Web应用基础 1、Web应用的基本概念 Web ( World wide Web) 也称为万维网 脱离单机Web应用在互联网上占据了及其重要的地位Web应用的发展&#xf…

最新Kali Linux安装教程:从零开始打造网络安全之旅

Kali Linux&#xff0c;全称为Kali Linux Distribution&#xff0c;是一个操作系统(2013-03-13诞生)&#xff0c;是一款基于Debian的Linux发行版&#xff0c;基于包含了约600个安全工具&#xff0c;省去了繁琐的安装、编译、配置、更新步骤&#xff0c;为所有工具运行提供了一个…

计算机竞赛 python 机器视觉 车牌识别 - opencv 深度学习 机器学习

1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 基于python 机器视觉 的车牌识别系统 &#x1f947;学长这里给一个题目综合评分(每项满分5分) 难度系数&#xff1a;3分工作量&#xff1a;3分创新点&#xff1a;3分 &#x1f9ff; 更多资…

【量化课程】02_4.数理统计的基本概念

2.4_数理统计的基本概念 数理统计思维导图 更多详细内容见notebook 1.基本概念 总体&#xff1a;研究对象的全体&#xff0c;它是一个随机变量&#xff0c;用 X X X表示。 个体&#xff1a;组成总体的每个基本元素。 简单随机样本&#xff1a;来自总体 X X X的 n n n个相互…

梯度下降介绍

什么是梯度 梯度是微积分中一个很重要的概念&#xff0c;在单变量的函数中&#xff0c;梯度其实就是函数的微分&#xff0c;代表着函数在某个给定点的切线的斜率&#xff1b;在多变量函数中&#xff0c;梯度是一个向量&#xff0c;向量有方向&#xff0c;梯度的方向就指出了函…

809协议nodejs编写笔记(还在更新)

一、总体流程 数据首先通过receiver接受层接收&#xff0c;去掉标识头和标识尾&#xff1b;再进入depacker解包层进行解包&#xff0c;把标识头分解出来并解析&#xff1b;之后发给handler处理层根据不同的消息id选择使用不同的业务逻辑&#xff1b;如果有应答&#xff0c;则通…

陪诊小程序开发|陪诊陪护小程序让看病不再难

陪诊小程序通过与医疗机构的合作&#xff0c;整合了医疗资源&#xff0c;让用户能够更加方便地获得专业医疗服务。用户不再需要面对繁琐的挂号排队&#xff0c;只需通过小程序预约服务&#xff0c;便能够享受到合适的医疗资源。这使得用户的就医过程变得简单高效&#xff0c;并…

CSS中的position属性有哪些值,并分别描述它们的作用。

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ static⭐ relative⭐ absolute⭐ fixed⭐ sticky⭐ 写在最后 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几何带你启航前端之旅 欢迎来到前端入门之旅&#xff01;这个专栏是为那…

Tomcat多实例部署及nginx+tomcat的负载均衡和动静分离

Tomcat多实例部署 安装 jdk、tomcat&#xff08;流程可看之前博客&#xff09; 配置 tomcat 环境变量 [rootlocalhost ~]# vim /etc/profile.d/tomcat.sh#tomcat1 export CATALINA_HOME1/usr/local/tomcat/tomcat1 export CATALINA_BASE1/usr/local/tomcat/tomcat1 export T…

pdf怎么转换成jpg图片?这几个转换方法了解一下

pdf怎么转换成jpg图片&#xff1f;转换PDF文件为JPG图片格式在现代工作中是非常常见的需求&#xff0c;比如将PDF文件中的图表、表格或者图片转换为JPG格式后使用在PPT演示、网页设计等场景中。 【迅捷PDF转换器】是一款非常实用的工具&#xff0c;可以将PDF文件转换成多种不同…

HTML详解连载(5)

HTML详解连载&#xff08;5&#xff09; 专栏链接 [link](http://t.csdn.cn/xF0H3)下面进行专栏介绍 开始喽行高&#xff1a;设置多行文本的间距属性名属性值行高的测量方法 行高-垂直居中技巧 字体族属性名属性值示例扩展 font 复合属性使用场景复合属性示例注意 文本缩进属性…

YOLO v8目标跟踪详细解读(二)

上一篇&#xff0c;结合代码&#xff0c;我们详细的介绍了YOLOV8目标跟踪的Pipeline。大家应该对跟踪的流程有了大致的了解&#xff0c;下面我们将对跟踪中出现的卡尔曼滤波进行解读。 1.卡尔曼滤波器介绍 卡尔曼滤波&#xff08;kalman Filtering&#xff09;是一种利用线性…

Python学习 -- 常用函数与实例详解

在Python编程中&#xff0c;数据转换是一项关键任务&#xff0c;它允许我们在不同数据类型之间自由流动&#xff0c;从而提高代码的灵活性和效率。本篇博客将深入探讨常用的数据转换函数&#xff0c;并通过实际案例为你展示如何巧妙地在不同数据类型之间转换。 数据类型转换函…

分布式监控平台——Zabbix

市场上常用的监控软件&#xff1a; 传统运维&#xff1a;zabbix、 Nagios 一、zabbix概述 作为一个运维&#xff0c;需要会使用监控系统查看服务器状态以及网站流量指标&#xff0c;利用监控系统的数据去了解上线发布的结果&#xff0c;和网站的健康状态。 利用一个优秀的监…

【数学建模】--因子分析模型

因子分析有斯皮尔曼在1904年首次提出&#xff0c;其在某种程度上可以被看成时主成分分析的推广和扩展。 因子分析法通过研究变量间的相关稀疏矩阵&#xff0c;把这些变量间错综复杂的关系归结成少数几个综合因子&#xff0c;由于归结出的因子个数少于原始变量的个数&#xff0c…