1、torchvision及其数据集的介绍
1.1 torchvision介绍
torchvision 是 PyTorch 的一个官方库,专门用于计算机视觉任务。它提供了以下核心功能:
- 预训练模型:如 ResNet、VGG、EfficientNet 等。
- 数据集:内置常用视觉数据集(如 MNIST、CIFAR、ImageNet)。
- 数据转换工具:用于图像预处理和数据增强(如裁剪、旋转、归一化)。
- 工具函数:如张量转换、可视化等。
torchvision库的官方文档:torchvision库的官方文档
1.2 torchvision 中的数据集
torchvision.datasets 模块包含多个经典数据集,可直接下载并使用(torchvision.dataset模块数据集文档)。以下是常用数据集及其特点:
- MNIST
描述:手写数字(0-9)灰度图像。
样本:60,000 训练 + 10,000 测试。
图像大小:28×28 像素。
用途:入门级分类任务。
加载示例:
from torchvision import datasets
train_data = datasets.MNIST(root='./data', train=True, download=True)
MNIST文档
- CIFAR-10 / CIFAR-100
描述:
CIFAR-10:10 类彩色图像(飞机、猫等)。
CIFAR-100:100 细分类别。
样本:50,000 训练 + 10,000 测试。
图像大小:32×32 像素。
用途:小图像分类基准。
加载示例:
cifar10 = datasets.CIFAR10(root='./data', train=True, download=True)
CIFAR-10文档
CIFAR-100文档
- ImageNet
描述:大规模图像分类数据集(1,000 类)。
样本:约 120 万训练 + 50,000 验证。
图像大小:可变(需调整至相同尺寸)。
注意:需手动下载数据(非直接通过 torchvision)。
加载示例:
imagenet = datasets.ImageNet(root='./data', split='train')
ImageNet文档
- FashionMNIST
描述:时尚单品(衣服、鞋等)灰度图像,替代 MNIST。
样本:同 MNIST。
用途:比 MNIST 更具挑战性的分类。
FashionMNIST文档
- COCO
描述:目标检测、分割、字幕生成数据集。
样本:超 20 万标注图像。
加载示例:
coco = datasets.CocoDetection(root='./data', annFile='annotations.json')
COCO文档
- 其他数据集
KITTI:自动驾驶场景(检测、深度估计)。
SVHN:街景门牌号数字识别。
CelebA:人脸属性识别(20 万名人图像)。
通用数据加载方法
所有数据集可通过 torch.utils.data.DataLoader 批量加载:
from torchvision import datasets, transforms
from torch.utils.data import DataLoadertransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])train_data = datasets.MNIST(root='./data', train=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
关键点总结
- 内置数据集:直接通过 datasets.XYZ 调用,自动下载(部分需手动准备)。
- 数据增强:配合 transforms 模块实现标准化、翻转等操作。
- 灵活扩展:支持自定义数据集(继承 torch.utils.data.Dataset)。
通过 torchvision,用户可以快速实验模型,无需重复实现数据预处理流程。
二、使用实例
2.1 CIFAR-10 dataset介绍
2.1.1 数据集内容介绍
Torchvision 中的 CIFAR-10 是一个经典的计算机视觉数据集,广泛用于图像分类任务的基准测试。
-
数据集概述
名称:CIFAR-10(Canadian Institute For Advanced Research)
用途:图像分类(10个类别)
数据内容:
60,000 张 32x32 像素的彩色图像(RGB)。
分为 50,000 张训练图像和 10,000 张测试图像。
均匀分布在 10 个类别中,每个类别有 6,000 张图像。 -
类别标签
10 个类别分别为:
飞机(airplane)、汽车(automobile)、鸟(bird)、猫(cat)、鹿(deer)、
狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)。 -
在 Torchvision 中的使用方法
加载数据集
import torchvision
import torchvision.transforms as transforms# 定义数据预处理(标准化、Tensor转换等)
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 下载并加载训练集和测试集
train_dataset = torchvision.datasets.CIFAR10(root='./data', # 数据保存路径train=True, # 加载训练集download=True, # 如果本地不存在则下载transform=transform # 应用预处理
)test_dataset = torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform
)
数据加载器(DataLoader)
from torch.utils.data import DataLoadertrain_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
- 数据集特点
- 图像尺寸小:32x32 像素,适合快速实验和算法验证。
- 低分辨率挑战:由于尺寸小,细节较少,对模型的特征提取能力要求较高。
- 标准化数据:通常需要归一化到 [-1, 1] 或 [0, 1] 范围。
- 常见用途
- 基准测试:用于对比不同模型的分类性能(如ResNet、VGG等)。
- 教学示例:深度学习入门常用的“Hello World”级数据集(类似MNIST的彩色版)。
- 扩展功能
- 数据增强:通过 transforms 动态增强训练数据:
- 数据增强实验:可通过 torchvision.transforms 添加旋转、翻转等操作。
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomRotation(10),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
- 与CIFAR-100的关系:Torchvision 还提供 CIFAR-100 数据集,包含 100 个细粒度类别,每类 600 张图像。
- 注意事项
- 下载问题:若国内下载慢,可手动下载并解压到 root 指定路径(官方链接)。
- 内存占用:全部加载后约占 ~200MB 内存,适合大多数设备。
CIFAR-10数据集介绍:CIFAR-10数据集
2.1.2 数据集使用语法
基本语法
torchvision.datasets.CIFAR10(root: Union[str, Path], train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)
参数说明
- root (str 或 pathlib.Path) – 数据集的根目录,其中目录 cifar-10-batches-py 存在,如果 download 设置为 True,则将保存到该 cifar-10-batches-py 中。
- train (bool, optional) – 如果为 True,则从训练集创建数据集,否则从测试集创建。
- transform (callable, optional) (转换) – 接收 PIL 图像并返回转换后的版本的函数/转换。例如,transforms.RandomCrop
- target_transform (callable, optional) – – 接收目标并对其进行转换的函数/转换。
- download (bool, optional) – 如果为 true,则从 Internet 下载数据集并将其放入根目录中。如果数据集已下载,则不会再次下载。
Special-members:
__getitem__(index: int) → Tuple[Any, Any]
Parameters: 参数
- index (int) – Index
Returns: 结果
- (image, target) where target is index of the target class.
Return type: 返回类型 :
- tuple
2.2 CIFAR-10 dataset使用实例
2.2.1 创建数据集实例
import torchvisiontrain_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)print(test_set[0])
print(test_set.classes)img, target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])
img.show()
运行结果:
(<PIL.Image.Image image mode=RGB size=32x32 at 0x1CC0CC19AF0>, 3)
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
<PIL.Image.Image image mode=RGB size=32x32 at 0x1CC0CC19AF0>
3
cat
结果表示索引数据集返回的是一个PIL类型的图像以及一个整数target,该整数表示某个类如cat。在classes属性中存放了一个类列表,target作为列表的下标代表对应的类。
image.show()显示图片:
2.2.2 使用transform操作
复习:Pytorch中的Transforms学习
使用transform操作对对数据集统计进行处理:
import torchvisiondataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()
])train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=False)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=False)img, target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])
再次运行:
tensor([[[0.6196, 0.6235, 0.6471, ..., 0.5373, 0.4941, 0.4549],[0.5961, 0.5922, 0.6235, ..., 0.5333, 0.4902, 0.4667],[0.5922, 0.5922, 0.6196, ..., 0.5451, 0.5098, 0.4706],...,[0.2667, 0.1647, 0.1216, ..., 0.1490, 0.0510, 0.1569],[0.2392, 0.1922, 0.1373, ..., 0.1020, 0.1137, 0.0784],[0.2118, 0.2196, 0.1765, ..., 0.0941, 0.1333, 0.0824]],[[0.4392, 0.4353, 0.4549, ..., 0.3725, 0.3569, 0.3333],[0.4392, 0.4314, 0.4471, ..., 0.3725, 0.3569, 0.3451],[0.4314, 0.4275, 0.4353, ..., 0.3843, 0.3725, 0.3490],...,[0.4863, 0.3922, 0.3451, ..., 0.3804, 0.2510, 0.3333],[0.4549, 0.4000, 0.3333, ..., 0.3216, 0.3216, 0.2510],[0.4196, 0.4118, 0.3490, ..., 0.3020, 0.3294, 0.2627]],[[0.1922, 0.1843, 0.2000, ..., 0.1412, 0.1412, 0.1294],[0.2000, 0.1569, 0.1765, ..., 0.1216, 0.1255, 0.1333],[0.1843, 0.1294, 0.1412, ..., 0.1333, 0.1333, 0.1294],...,[0.6941, 0.5804, 0.5373, ..., 0.5725, 0.4235, 0.4980],[0.6588, 0.5804, 0.5176, ..., 0.5098, 0.4941, 0.4196],[0.6275, 0.5843, 0.5176, ..., 0.4863, 0.5059, 0.4314]]])
3
cat
可以发现,PIL图像转换为tensor格式了。
2.2.3 使用Tensorboard记录
复习:Pytorch中Tensorboard的学习
使用for循环,将测试集中的前10张张量格式的图片添加到日志中:writer.add_image(“test_set”, img, i)
import torchvision
from torch.utils.tensorboard import SummaryWriterdataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()
])train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=False)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=False)writer = SummaryWriter("runs")
for i in range(10):img, target = test_set[i]writer.add_image("test_set", img, i)writer.close()
运行结束后,在终端执行命令:
tensorboard --logdir=E:\my_pycharm_projects\project1\runs
TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)
打开网址:
CIFAR-10使用说明:CIFAR使用语法说明