【PyTorch入门】使用PyTorch构建一个简单的图像分类模型

本次分享一个简单的使用PyTorch进行图像分类模型搭建的小案例,让大家对PyTorch的流程有一个认知。

1. 导入必要的库

import torch
import torch.nn as nn
import torchvision
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.utils.data as Data
from torchvision import datasets, models, transforms

解释:

  • torch:用于构建深度学习模型的核心库。
  • torch.nn:提供神经网络相关的模块,如层、损失函数等。
  • torchvision:提供与计算机视觉相关的工具,尤其是常用数据集和预训练模型。
  • numpy:用于处理数组和进行数值计算。
  • matplotlib.pyplot:用于图像显示和绘图。
  • torch.autograd.Variable:用于在自动求导时跟踪张量
  • torch.nn.functional:包含神经网络常用的函数,如激活函数等。
  • torch.utils.data:数据加载工具,用于高效读取数据。

2. 设备设置和数据预处理

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"设备已设置为:{device}")transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])print("数据预处理已完成:将图像转换为Tensor并进行标准化。")

解释:

  • 使用 transforms.Compose() 将多个变换组合在一起。
  • transforms.ToTensor():将PIL图像或NumPy数组转换为PyTorch张量,并且自动将像素值从 [0, 255] 归一化到 [0, 1]。
  • transforms.Normalize(mean, std):标准化图像,使其每个通道的均值为0,标准差为1。这里的均值和标准差是 (0.5, 0.5, 0.5),即将每个通道的像素值从 [0, 1] 映射到 [-1, 1]。

3. 加载数据集

# 加载 CIFAR-10 数据集
trainset = torchvision.datasets.CIFAR10(root='./', train=True, download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=False, num_workers=4)testset = torchvision.datasets.CIFAR10(root='./', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=4)# 定义类名
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 输出数据集信息
print("CIFAR-10 数据集加载完成。")
print(f"训练集样本数: {len(trainset)}")
print(f"测试集样本数: {len(testset)}")# 显示几个样例图像
def imshow(img):img = img / 2 + 0.5  # 反标准化npimg = img.numpy()  # 转为NumPy格式plt.imshow(np.transpose(npimg, (1, 2, 0)))  # 转换维度以适应imshow显示plt.show()# 获取训练数据中的一个batch
dataiter = iter(trainloader)
images, labels = next(dataiter)# 输出真实标签
print('真实标签: ', ' '.join([f'{classes[labels[j]]:5s}' for j in range(4)]))# 显示图像
imshow(torchvision.utils.make_grid(images))

输出:

Files already downloaded and verified
CIFAR-10 数据集加载完成。
训练集样本数: 50000
测试集样本数: 10000
真实标签:  frog  truck truck deer 

在这里插入图片描述
可以看到图像有青蛙,卡车和鹿。

代码详解:

  • 数据集加载与预处理:
    我们首先使用 torchvision.datasets.CIFAR10 加载 CIFAR-10 数据集,并应用了图像的标准化和转换(ToTensor)。
    数据集分为训练集和测试集,分别使用 trainloader 和 testloader 来加载。
  • 显示样例图像:
    我们定义了一个 imshow 函数,用于显示图像。在 imshow 函数中,图像会先进行反标准化操作(img / 2 + 0.5),然后将图像转换为 NumPy 数组,并调整维度以适应 matplotlib 的显示格式。
  • 输出真实标签:
    classes 中定义了 CIFAR-10 数据集的各个类标签。对于我们展示的每个样本,打印出它们的真实标签。
  • 展示图像:
    imshow 函数会展示一个 batch 的图像,torchvision.utils.make_grid 会将该 batch 中的图像拼接成一张大图进行展示。
    我们还输出了该 batch 中每个图像的真实标签。

4. 定义神经网络架构

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xprint("神经网络结构已定义:卷积层和全连接层。")

解释:

  • 定义一个简单的卷积神经网络 Net,继承自 nn.Module。
  • init 方法定义了网络的层:
  1. conv1 和 conv2 是卷积层,conv1 输入通道为3(RGB图像),输出通道为6,卷积核大小为5x5,conv2 的输入通道为6,输出通道为16。
  2. pool 是最大池化层,大小为2x2。
  3. fc1, fc2, fc3 是全连接层。
  • forward 方法定义了数据流动的顺序:
  1. 先通过卷积层和激活函数 ReLU,再经过池化层。
  2. 展平输出为全连接层的输入。
  3. 通过全连接层输出结果。

5. 初始化网络、定义损失函数和优化器

net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
print(net)
print("损失函数和优化器已定义:交叉熵损失和SGD优化器。")

输出:

Net((conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))(fc1): Linear(in_features=400, out_features=120, bias=True)(fc2): Linear(in_features=120, out_features=84, bias=True)(fc3): Linear(in_features=84, out_features=10, bias=True)
)
损失函数和优化器已定义:交叉熵损失和SGD优化器。
  • 初始化 Net 类实例 net。
  • criterion = nn.CrossEntropyLoss():交叉熵损失函数,适用于多分类问题。
  • optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9):随机梯度下降优化器,学习率设置为0.001,动量为0.9

6.开始训练

nums_epoch = 2
print(f"开始训练,共训练 {nums_epoch} 轮。")for epoch in range(nums_epoch):_loss = 0.0for i, (inputs, labels) in enumerate(trainloader, 0):inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()_loss += loss.item()if i % 3000 == 2999:  # 每3000个batch输出一次损失print(f"[{epoch + 1}, {i + 1}] 损失: {(_loss / 3000):.3f}")_loss = 0.0print("训练完成。")

输出:

开始训练,共训练 2 轮。
[1, 3000] 损失: 2.138
[1, 6000] 损失: 1.740
[1, 9000] 损失: 1.582
[1, 12000] 损失: 1.501
[2, 3000] 损失: 1.428
[2, 6000] 损失: 1.381
[2, 9000] 损失: 1.333
[2, 12000] 损失: 1.301
训练完成。

解释:

  • 训练模型2个epoch(nums_epoch = 2)。
  • 遍历 trainloader 中的每个批次:
  • 将输入和标签数据传送到GPU或CPU。
  • 使用网络对输入进行前向传播,得到输出。
  • 计算损失函数,进行反向传播,更新参数。
  • 每3000个batch输出一次损失信息。

7. 显示图像并输出预测结果

def imshow(img):img = img / 2 + 0.5  # 反标准化npimg = img.numpy()  # 转为NumPy格式plt.imshow(np.transpose(npimg, (1, 2, 0)))  # 将图像维度调整为 (height, width, channels)plt.show()dataiter = iter(testloader)
images, labels = next(dataiter)  # 使用next()获取数据
print("获取一个batch的测试图像和标签。图像形状:", images.shape)
imshow(torchvision.utils.make_grid(images))  # 显示图像
print('图像真实分类: ', ' '.join([f'{classes[labels[j]]:5s}' for j in range(4)]))outputs = net(images.to(device))
_, predicted = torch.max(outputs, 1)print('图像预测分类: ', ' '.join([f'{classes[predicted[j]]:5s}' for j in range(4)]))

输出:

获取一个batch的测试图像和标签。图像形状: torch.Size([4, 3, 32, 32])
图像真实分类:  cat   dog   cat   bird 
图像预测分类:  dog   dog   dog   dog 

在这里插入图片描述

解释:

  • 定义了一个 imshow() 函数来显示图像。
  • 使用 next(dataiter) 从 testloader 中获取一个batch的数据。
  • 输出该batch的图像形状(images.shape)以及图像本身。
  • 使用训练好的模型 net 对图像进行预测,并输出预测的分类标签。

8. 计算测试集准确率

correct, total = 0, 0
with torch.no_grad():for images, labels in testloader:images, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (labels == predicted).sum().item()accuracy = 100 * correct / total
print(f"测试集准确率: {accuracy:.2f}%")

使用 torch.no_grad() 禁止计算梯度,提高推理时的效率

本次分享就结束了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/505468.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【SpringAOP】Spring AOP 底层逻辑:切点表达式与原理简明阐述

前言 🌟🌟本期讲解关于spring aop的切面表达式和自身实现原理介绍~~~ 🌈感兴趣的小伙伴看一看小编主页:GGBondlctrl-CSDN博客 🔥 你的点赞就是小编不断更新的最大动力 &am…

智慧公厕大数据驱动下的公共卫生管理与优化

在快速发展的城市化进程中,公共卫生问题日益凸显,成为城市管理的重要议题。智慧公厕,作为公共卫生设施的一次革命性创新,正借助物联网技术的东风,引领公共卫生进入一个全新的生态时代。本文将深入探讨智慧公厕如何利用…

Git:Cherry-Pick 的使用场景及使用流程

前面我们说了 Git合并、解决冲突、强行回退等解决方案 >> 点击查看 这里再说一下 Cherry-Pick功能,Cherry-Pick不是merge,只是把部分功能代码Cherry-Pick到远程的目标分支 git cherry-pick功能简介: git cherry-pick 是用来从一个分…

Windows 安装 Docker 和 Docker Compose

🚀 作者主页: 有来技术 🔥 开源项目: youlai-mall ︱vue3-element-admin︱youlai-boot︱vue-uniapp-template 🌺 仓库主页: GitCode︱ Gitee ︱ Github 💖 欢迎点赞 👍 收藏 ⭐评论 …

前端 图片上鼠标画矩形框,标注文字,任意删除

效果: 页面描述: 对给定的几张图片,每张能用鼠标在图上画框,标注相关文字,框的颜色和文字内容能自定义改变,能删除任意画过的框。 实现思路: 1、对给定的这几张图片,用分页器绑定…

Elasticsearch—索引库操作(增删查改)

Elasticsearch中Index就相当于MySQL中的数据库表 Mapping映射就类似表的结构。 因此我们想要向Elasticsearch中存储数据,必须先创建Index和Mapping 1. Mapping映射属性 Mapping是对索引库中文档的约束,常见的Mapping属性包括: type:字段数据类…

在Jmeter中跨线程组传递变量(token)--设置全局变量

参考资料: Jmeter跨线程组传递参数(token)_jmeter获取token传递给下一个线程组详解-CSDN博客 最近工作中遇到一个问题,就是如何跨线程组传递变量,比如token,后来找到一些资料解决了该问题,目前有两种方式都可以解决,我…

【C++】揭开C++类与对象的神秘面纱(首卷)(类的基础操作详解、实例化艺术及this指针的深究)

文章目录 一、类的定义1.类定义格式2.类访问限定符3.类域 二、类的实例化1.实例化概念2.对象的大小 三、隐藏的this指针与相关练习1.this指针的引入与介绍练习1练习2练习3 一、类的定义 1.类定义格式 在讲解类的作用之前,我们来看看类是如何定义的,在C中…

前端JavaScript中some方法的运用

一.前言 在我们的日常工作中,有时候仅仅需要找到某个数组中的值,就可以返还结果的话,笔者建议就可以使用some方法,这比遍历整个数组高效一些。 二.应用 首先,看官方定义:JavaScri…

安装vue脚手架出现的一系列问题

安装vue脚手架出现的一系列问题 前言使用 npm 安装 vue/cli2.权限问题及解决方法一:可以使用管理员权限进行安装。方法二:更改npm全局安装路径 前言 由于已有较长时间未进行 vue 项目开发,今日着手准备开发一个新的 vue 项目时,在…

基于Python实现的通用小规模搜索引擎

基于Python实现的通用小规模搜索引擎 1.项目简介 1.1背景 《信息内容安全》网络信息内容获取技术课程项目设计 一个至少能支持10个以上网站的爬虫程序,且支持增量式数据采集;并至少采集10000个实际网页;针对采集回来的网页内容, 能够实现网页文本的分…

鸿蒙面试 2025-01-10

写了鉴权工具,你在项目中申请了那些权限?(常用权限) 位置权限 : ohos.permission.LOCATION_IN_BACKGROUND:允许应用在后台访问位置信息。 ohos.permission.LOCATION:允许应用访问精确的位置信息…

【硬件测试】基于FPGA的BPSK+帧同步系统开发与硬件片内测试,包含高斯信道,误码统计,可设置SNR

目录 1.硬件片内测试效果 2.算法涉及理论知识概要 2.1 bpsk 2.2 帧同步 3.Verilog核心程序 4.开发板使用说明和如何移植不同的开发板 5.完整算法代码文件获得 1.硬件片内测试效果 本文是之前写的文章 《基于FPGA的BPSK帧同步系统verilog开发,包含testbench,高斯信道,误…

MySQL 视图 存储过程与存储函数

第十四章_视图、第十五章 _存储过程与存储函数 1.常见的数据库对象 1. 表(Table) 用于存储结构化数据的基本对象,由行(记录)和列(字段)组成。 2. 视图(View) 基于一…

Chrome_60.0.3112.113_x64 单文件版 下载

单文件,免安装,直接用~ Google Chrome, 免費下載. Google Chrome 60.0.3112.113: Chrome 是 Google 開發的網路瀏覽器。它的特點是速度快,功能多。 下载地址: https://blog.s3.sh.cn/thread-150-1-1.htmlhttps://blog.s3.sh.cn/thread-150-1-1.html

CTFshow—文件包含

Web78-81 Web78 这题是最基础的文件包含,直接?fileflag.php是不行的,不知道为啥,直接用下面我们之前在命令执行讲过的payload即可。 ?filephp://filter/readconvert.base64-encode/resourceflag.php Web79 这题是过滤了php,…

python学opencv|读取图像(二十九)使用cv2.getRotationMatrix2D()函数旋转缩放图像

【1】引言 前序已经学习了如何平移图像,相关文章链接为: python学opencv|读取图像(二十七)使用cv2.warpAffine()函数平移图像-CSDN博客 在此基础上,我们尝试旋转图像的同时缩放图像。 【2】…

24下半年软考「单独划线」合格标准已公布!

2024年下半年计算机技术与软件专业技术资格考试单独划线地区合格标准已公布! 其中初级和中级单独划线地区合格标准各科目均为39分,高级各科目为40分,符合单独划线地区的同学可以去申请证书了。 一、证书效力 在单独划线地区报名参加相关职业…

Linux第一课:c语言 学习记录day06

四、数组 冒泡排序 两两比较,第 j 个和 j1 个比较 int a[5] {5, 4, 3, 2, 1}; 第一轮:i 0 n:n个数,比较 n-1-i 次 4 5 3 2 1 // 第一次比较 j 0 4 3 5 2 1 // 第二次比较 j 1 4 3 2 5 1 // 第三次比较 j 2 4 3 2 1 5 // …

前端用json-server来Mock后端返回的数据处理

<html><body><div class"login-container"><h2>登录</h2><div class"login-form"><div class"form-group"><input type"text" id"username" placeholder"请输入用户名&q…