41目标检测数据集
import os
import pandas as pd
import torch
import torchvision
import matplotlib.pylab as plt
from d2l import torch as d2l# 数据集下载链接
# http://d2l-data.s3-accelerate.amazonaws.com/banana-detection.zip# 读取数据集
#@save
def read_data_bananas(is_train=True):"""读取香蕉检测数据集中的图像和标签"""data_dir = '../data/banana-detection/'csv_fname = os.path.join(data_dir, 'bananas_train' if is_trainelse 'bananas_val', 'label.csv')csv_data = pd.read_csv(csv_fname)# 将 img_name 列设置为索引,以便后续操作中根据图片名称索引标签。csv_data = csv_data.set_index('img_name')images, targets = [], [] # images 用于存储图像,targets 用于存储标签。for img_name, target in csv_data.iterrows():images.append(torchvision.io.read_image(os.path.join(data_dir, 'bananas_train' if is_train else'bananas_val', 'images', f'{img_name}')))# 这里的target包含(类别,左上角x,左上角y,右下角x,右下角y),# 其中所有图像都具有相同的香蕉类(索引为0)targets.append(list(target))# 将 targets 列表转换为 PyTorch 张量,并增加一个维度(通过 unsqueeze(1))。# 对标签进行归一化处理(除以 256)。return images, torch.tensor(targets).unsqueeze(1) / 256 # 增加维度以匹配其他张量的形状# 图像的小批量的形状为(批量大小、通道数、高度、宽度)# 标签的小批量的形状为(批量大小,m,5),其中m是数据集的任何图像中边界框可能出现的最大数量。#@save
class BananasDataset(torch.utils.data.Dataset):"""一个用于加载香蕉检测数据集的自定义数据集"""def __init__(self, is_train):self.features, self.labels = read_data_bananas(is_train)print('read ' + str(len(self.features)) + (f' training examples' ifis_train else f' validation examples'))def __getitem__(self, idx):return (self.features[idx].float(), self.labels[idx])def __len__(self):return len(self.features)#@save
def load_data_bananas(batch_size):"""加载香蕉检测数据集"""train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),batch_size, shuffle=True)val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),batch_size)return train_iter, val_iterbatch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)
batch = next(iter(train_iter))# print(batch[0].shape, batch[1].shape)
# torch.Size([32, 3, 256, 256]) torch.Size([32, 1, 5])# 效果演示
imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
# batch[0] 是包含图像数据的张量,形状为 (batch_size, channels, height, width)
# batch[0][0:10] 选择前 10 个图像。
# .permute(0, 2, 3, 1) 将张量的维度重新排列变为 (batch_size, height, width, channels)
# / 255 将像素值归一化到 [0, 1] 之间
# 图像的像素值通常在0到255之间。如果不进行归一化,像素值直接使用原始范围。
# 图像库在显示图像时,需要将像素值映射到一个合理的范围内。
# 在0到1范围内时,显示库可以更好地处理和展示这些图像。axes = d2l.show_images(imgs, 2, 5, scale=2)
# d2l.show_images 是一个用于显示多张图像的函数。
# imgs 是预处理后的图像张量。
# 2, 5 指定了图像将被显示为 2 行 5 列的网格。
# scale=2 指定了图像的缩放比例。# batch[1]是包含图像标签的张量torch.Size([32, 1, 5])
for ax, label in zip(axes, batch[1][0:10]): d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])# d2l.show_bboxes 是一个用于在图像上绘制边界框的函数。# ax 是当前图像的坐标轴。# label[0][1:5] 提取标签中的边界框坐标(标签格式为 [class, x_min, y_min, x_max, y_max])。# * edge_size 将边界框坐标缩放到图像的实际尺寸。# colors=['w'] 指定边界框的颜色为白色。
plt.show()
运行结果: