pytorch中dataloader自定义数据集

前言

在深度学习中我们需要使用自己的数据集做训练,因此需要将自定义的数据和标签加载到pytorch里面的dataloader里,也就是自实现一个dataloader。

数据集处理

以花卉识别项目为例,我们分别做出图片的训练集和测试集,训练集的标签和测试集的标签

flower_data/
├── train_filelist/
│   ├── image_0001.jpg
│   └── ...
├── val_filelist/
│   ├── image_1001.jpg
│   └── ...
├── train.txt  # 格式:文件名 标签
└── val.txt

 数据目录的组织方式如上所示。

首先看图片的处理。图片只要做好编号放在同一个文件夹里就好了。

再看标签的处理。标签处理我们自己规定了一种形式,就是图像文件的名称+空格+分类标签。

可以看到前面第一列数据是图像名称,第二列数据是图像的分组,同样的数字为一组。比如分组为0的图像就是同一种花朵。

自定义dataset

源码

import os.path
import numpy as np
import torch
from PIL import Image  # 从PIL库导入Image类
from torch.utils.data import Datasetclass FlowerDataSet(Dataset):"""花朵分类任务数据集类,继承自torch的Dataset类"""def __init__(self, root_dir, ann_file, transform=None):"""初始化数据集实例Args:root_dir (str): 数据集根目录路径ann_file (str): 标注文件路径transform (callable, optional): 数据预处理变换函数"""self.ann_file = ann_fileself.root_dir = root_dir# 加载图片路径与标签的映射字典 {文件名: 标签}self.image_label = self.load_annotations()# 构建完整图片路径列表 [root_dir/文件名1, ...]self.image = [os.path.join(self.root_dir, img) for img in list(self.image_label.keys())]# 构建标签列表 [标签1, 标签2, ...]self.label = [lbl for lbl in list(self.image_label.values())]  # 重命名为lbl避免与导入的label冲突self.transform = transformdef __len__(self):"""返回数据集样本数量"""return len(self.image)def __getitem__(self, index):"""获取单个样本数据Args:index (int): 样本索引Returns:tuple: (预处理后的图像数据, 对应的标签)"""# 打开图片文件image = Image.open(self.image[index])# 获取对应标签label = self.label[index]# 应用数据预处理if self.transform:image = self.transform(image)# 将标签转换为torch张量label = torch.from_numpy(np.array(label))return image, labeldef load_annotations(self):"""加载标注文件,解析图片文件名和标签的映射关系Returns:dict: {图片文件名: 对应标签} 的字典"""data_infos = {}with open(self.ann_file) as f:# 读取所有行并分割,每行格式应为 "文件名 标签"samples = [x.strip().split(' ') for x in f.readlines()]for filename, label in samples:# 将标签转换为int64类型的numpy数组data_infos[filename] = np.array(label, dtype=np.int64)return data_infos

解析

1、将标签数据进行读取,组成一个哈希表,哈希表的键是图像的文件名称,哈希表的值是分组标签。

    def load_annotations(self):"""加载标注文件,解析图片文件名和标签的映射关系Returns:dict: {图片文件名: 对应标签} 的字典"""data_infos = {}with open(self.ann_file) as f:# 读取所有行并分割,每行格式应为 "文件名 标签"samples = [x.strip().split(' ') for x in f.readlines()]for filename, label in samples:# 将标签转换为int64类型的numpy数组data_infos[filename] = np.array(label, dtype=np.int64)return data_infos

上面的代码里,在录入标签的时候使用数组进行记录,这是为了兼容多标签的场景。如果不考虑兼容问题,仅考虑在单标签场景下的简单实现,可以用下面的代码:

def load_annotations(self):data_infos = {}with open(self.ann_file) as f:for line in f:filename, label = line.strip().split()  # 直接解包data_infos[filename] = int(label)        # 存为 Python 整数return data_infos# 在 __getitem__ 中直接转为张量
label = torch.tensor(self.labels[index], dtype=torch.long)

2、遍历哈希表,将文件名和标签分别存在两个数组里。这里注意,为了方便后面dataloader按照batch去读取图片,这里要将图片的全路径加到文件名里。

        # 构建完整图片路径列表 [root_dir/文件名1, ...]self.image = [os.path.join(self.root_dir, img) for img in list(self.image_label.keys())]# 构建标签列表 [标签1, 标签2, ...]self.label = [lbl for lbl in list(self.image_label.values())]  # 重命名为lbl避免与导入的label冲突

3、在dataloader向显卡/cpu加载数据的时候会调用getitem方法。比如一个batch里有64个数据,dataloader就会调用64次该方法,将64组图片和标签全部获取后交给运算单元去处理。

    def __getitem__(self, index):"""获取单个样本数据Args:index (int): 样本索引Returns:tuple: (预处理后的图像数据, 对应的标签)"""# 打开图片文件image = Image.open(self.image[index])# 获取对应标签label = self.label[index]# 应用数据预处理if self.transform:image = self.transform(image)# 将标签转换为torch张量label = torch.from_numpy(np.array(label))return image, label

测试dataloader

import os
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from dataloader import FlowerDataSet  # 假设你的数据集类在dataloader.py中def denormalize(image_tensor):"""将归一化的图像张量转换为可显示的格式"""mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])image = image_tensor.numpy().transpose((1, 2, 0))  # 转换维度顺序image = std * image + mean  # 反归一化image = np.clip(image, 0, 1)  # 限制像素值范围return imagedef test_dataloader():# 定义数据预处理data_transforms = {'train': transforms.Compose([transforms.Resize(64),transforms.RandomRotation(45),transforms.CenterCrop(64),transforms.RandomHorizontalFlip(p=0.5),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'valid': transforms.Compose([transforms.Resize(64),transforms.CenterCrop(64),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 检查文件路径是否存在print("[1/5] 检查文件路径...")required_files = {'train_txt': './flower_data/train.txt','val_txt': './flower_data/val.txt','train_dir': './flower_data/train_filelist','val_dir': './flower_data/val_filelist'}for name, path in required_files.items():if not os.path.exists(path):print(f"❌ 文件/目录不存在: {path}")returnprint(f"✅ {name}: {path} 存在")# 初始化数据集print("\n[2/5] 加载数据集...")try:train_dataset = FlowerDataSet(root_dir=required_files['train_dir'],ann_file=required_files['train_txt'],transform=data_transforms['train'])val_dataset = FlowerDataSet(root_dir=required_files['val_dir'],ann_file=required_files['val_txt'],transform=data_transforms['valid'])print("✅ 数据集加载成功")except Exception as e:print(f"❌ 数据集加载失败: {str(e)}")return# 打印数据集信息print("\n[3/5] 数据集统计:")print(f"训练集样本数: {len(train_dataset)}")print(f"验证集样本数: {len(val_dataset)}")# 检查单个样本print("\n[4/5] 检查单个样本:")sample_idx = 0try:img, label = train_dataset[sample_idx]print(f"图像张量形状: {img.shape} (应接近 torch.Size([3, 64, 64]))")print(f"标签类型: {type(label)} (应为 torch.Tensor)")print(f"标签值: {label.item()} (应为整数)")except Exception as e:print(f"❌ 样本检查失败: {str(e)}")# 可视化样本print("\n[5/5] 可视化训练集样本...")try:plt.figure(figsize=(8, 8))img_show = denormalize(img)plt.imshow(img_show)plt.title(f"Label: {label.item()}")plt.axis('off')plt.show()except Exception as e:print(f"❌ 可视化失败: {str(e)}")# 检查DataLoaderprint("\n[附加] 检查DataLoader:")train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)for loader, name in [(train_loader, '训练集'), (val_loader, '验证集')]:print(f"\n{name} DataLoader测试:")try:batch = next(iter(loader))images, labels = batchprint(f"批次图像形状: {images.shape} (应接近 [batch, 3, 64, 64])")print(f"批次标签示例: {labels[:5].numpy()}")print(f"像素值范围: [{images.min():.3f}, {images.max():.3f}]")except Exception as e:print(f"❌ {name} DataLoader错误: {str(e)}")if __name__ == '__main__':test_dataloader()

在测试代码中,分别测试了文件路径,dataset是否正常创建,dataset样本数量,dataset样本格式,dataset数据可视化,dataloader数据样式。

在打印日志的时候需要注意,dataset和dataloader里面的变量都是张量形式的,所以需要转换成python标量再打印。比如从dataset里取出的标签label是一个一维张量,需要通过label.item()进行转换。

 在遍历的时候为了简化代码,将两个dataloader放在同一个循环语句中处理,并且通过增加name变量来区分两个dataloader。

for loader, name in [(train_loader, '训练集'), (val_loader, '验证集')]:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/42990.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

业之峰与宏图智能战略携手,开启家装数字化新篇章

3月8日,业之峰装饰集团董事长张钧携高管团队与宏图智能董事长庭治宏及核心团队,在业之峰总部隆重举行了战略合作签约仪式,标志着双方将携手探索业之峰的数字化转型之路,共同推动家装行业的变革与发展。 近年来,家装行业…

区块链赋能,为木材货场 “智” 造未来

区块链赋能,为木材货场 “智” 造未来 在当今数字化浪潮席卷的时代,软件开发公司不断探索创新,为各行业带来高效、智能的解决方案。今天,让我们聚焦于一家软件开发公司的杰出成果 —— 区块链木材货场服务平台,深入了…

Suricata 检测日志中的时间戳不正确

参考连接 Incorrect Timestamp in Suricata Detection Logs - Help - Suricata 问题现象: 使用 Suricata 时遇到一个问题,即检测日志 (eve.json) 中的 and 字段间歇性地显示 2106 年。这似乎偶尔发生,并影响其中一个…

【第34节】windows原理:PE文件的导出表和导入表

目录 一、导出表 1.1 导出表概述 1.2 说明与使用 二、导入表 2.1 导入表概述 2.2 说明与使用 一、导出表 1.1 导出表概述 (1)导出行为和导出表用途:PE文件能把自身的函数、变量或者类,提供给其他PE文件使用,这…

【计算机网络】深入解析TCP/IP参考模型:从四层架构到数据封装,全面对比OSI

TCP/IP参考模型 导读一、历史背景二、分层结构2.1 网络接口层(Network Interface Layer)2.2 网络层(Internet Layer)2.3 传输层(Transport Layer)2.4 应用层(Application Layer) 三、…

项目实战-角色列表

抄上一次写过的代码: import React, { useState, useEffect } from "react"; import axios from axios; import { Button, Table, Modal } from antd; import { BarsOutlined, DeleteOutlined, ExclamationCircleOutlined } from ant-design/icons;const…

LeetCode1两数之和

**思路:**懒得写了,如代码所示 /*** Note: The returned array must be malloced, assume caller calls free().*/ struct hashTable {int key;//存值int val;//存索引UT_hash_handle hh; }; int* twoSum(int* nums, int numsSize, int target, int* re…

去噪算法大比拼

目录 效果图: 实现代码: 密集抖动 pip install pykalman 效果图: 实现代码: import numpy as np import cv2 import matplotlib.pyplot as plt from scipy.ndimage import gaussian_filter1d from scipy.signal import butter, filtfilt, savgol_filter from pykalma…

STM32_HAL开发环境搭建【Keil(MDK-ARM)、STM32F1xx_DFP、 ST-Link、STM32CubeMX】

安装Keil(MDK-ARM)【集成开发环境IDE】 我们会在Keil(MDK-ARM)上去编写代码、编译代码、烧写代码、调试代码。 Keil(MDK-ARM)的安装方法: 教学视频的第02分03秒开始看。 安装过程中请修改一下下面两个路径,避免占用C盘空间。 Core就是Keil(MDK-ARM)的…

深入理解MySQL聚集索引与非聚集索引

在数据库管理系统中,索引是提升查询性能的关键。MySQL支持多种类型的索引,其中最基础也是最重要的两种是聚集索引和非聚集索引。本文将深入探讨这两种索引的区别,并通过实例、UML图以及Java代码示例来帮助您更好地理解和应用它们。 一、概念…

【leetcode】拆解与整合:分治并归的算法逻辑

前言 🌟🌟本期讲解关于力扣的几篇题解的详细介绍~~~ 🌈感兴趣的小伙伴看一看小编主页:GGBondlctrl-CSDN博客 🔥 你的点赞就是小编不断更新的最大动力 🎆那么废话不…

wx162基于springboot+vue+uniapp的在线办公小程序

开发语言:Java框架:springbootuniappJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包&#…

陈宛汮签约2025火凤凰风赏大典全球形象大使

原标题:陈宛汮签约2025火凤凰风赏大典全球形象大使 共工新闻社香港3月29日电 陈宛汮,华语原创女歌手。“星宝在闪耀”公益活动联合发起人,自闭症儿童康复推广大使。代表作:《荣耀火凤凰》《爱在醉千年》。 从2025年1月1日至2025年12月31日&a…

【深度学习入门_机器学习理论】极致梯度提升原理(XGBoost)

XGBoost(eXtreme Gradient Boosting)是一种高效、灵活且广泛应用的机器学习算法,属于梯度提升决策树(Gradient Boosting Decision Tree, GBDT) 的优化实现。它在分类、回归、排序等结构化/表格数据的预测任务中表现尤为…

Oracle初识:登录方法、导入dmp文件

目录 一、登录方法 以sys系统管理员的身份登录 ,无需账户和密码 以账户密码的用户身份登录 二、导入dmp文件 方法一:PLSQL导入dmp文件 一、登录方法 Oracle的登录方法有两种。 以sys系统管理员的身份登录 ,无需账户和密码 sqlplus / a…

STM32F103_LL库+寄存器学习笔记01 - 梳理CubeMX生成的LL库最小的裸机系统框架

《STM32 - 在机器人领域,LL库相比HAL优势明显》在机器人、自动化设备领域使用MCU开发项目,必须用LL库。 本系列笔记记录使用LL库的开发过程,首先通过CubeMX生成LL库代码,梳理LL库源码。通过学习LL库源码,弄清楚寄存器的…

Vue3当中el-tree树形控件使用

tree悬停tooltip效果 文本过长超出展示省略号 如果文本超出悬停显示tooltip效果 反之不显示 这里直接控制固定宽度限制 试了监听宽度没效果<template><el-treeshow-checkbox:check-strictly"true":data"data"node-key"id":props"…

最大数字(java)(DFS实现)

1.最大数字 - 蓝桥云课 因为N最大是10 的17次方&#xff0c; 所以可以利用字符串来处理输入的数字的每一位 并且是从高到低依次处理的 然后通过函数charAt(i)来获取第i位的字符 再减去‘0’就可以将字符转化为整型了 假设每一位数字都是x 然后通过两种操作 加或者减来操…

04 单目标定实战示例

看文本文,您将获得以下技能: 1:使用opencv进行相机单目标定实战 2:标定结果参数含义和数值分析 3:Python绘制各标定板姿态,查看图像采集多样性 4:如果相机画幅旋转90,标定输入参数该如何设置? 5:图像尺寸缩放,标定结果输出有何影响? 6:单目标定结果应用类别…

手机销售终端MPR+LTC项目项目总体方案P183(183页PPT)(文末有下载方式)

资料解读&#xff1a;手机销售终端 MPRLTC 项目项目总体方案 详细资料请看本解读文章的最后内容。在当今竞争激烈的市场环境下&#xff0c;企业的销售模式和流程对于其发展起着至关重要的作用。华为终端正处于销售模式转型的关键时期&#xff0c;波士顿 - 华为销售终端 MPRLTC …