Pytorch Tutorial

本教程将详细展示如何使用PyTorch训练神经网络,并给出完整代码和关键注释(ipython),建议使用CoLab的GPU来编译代码。

目录

  • 环境初始化
  • 数据准备
  • 模型搭建
  • 模型训练
    • 优化器
    • 训练和评估函数
    • 开始训练
  • 可视化

环境初始化

!pip install torchprofile 1>/dev/null 
#torchprofile用于分析PyTorch模型,帮助理解模型的计算复杂度和参数数量等;1>/dev/null表示不显示安装过程中的任何输出信息,为了简洁。#一堆库
import random
from collections import OrderedDict, defaultdictimport numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from torchprofile import profile_macs
from torchvision.datasets import *
from torchvision.transforms import *
from tqdm.auto import tqdm#定义随机种子,为了能复现
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

 

数据准备

本教程用 CIFAR-10 作为训练数据集。该数据集总共有60,000张图片。这些图片被分为50,000张训练图片和10,000张测试图片,包含来自10个类的图像,其中每个图像的大小为3x32x32,即大小为32x32像素的3通道彩色图像。

transforms = {"train": Compose([RandomCrop(32, padding=4),RandomHorizontalFlip(),ToTensor(),]),"test": ToTensor(),
}dataset = {}
for split in ["train", "test"]:dataset[split] = CIFAR10(root="data/cifar10",train=(split == "train"),download=True,transform=transforms[split],)#可视化图像:从测试数据集中为每个类别抽取四个样本,并将这些样本的图像和标签可视化。
samples = [[] for _ in range(10)]  #创建一个列表samples,其中包含10个子列表,对应于数据集中的10个类别。每个子列表用于存储该类别的样本图像
for image, label in dataset["test"]:  #遍历测试数据集dataset["test"]的每一个图像和标签对。对于每个样本,检查其标签对应的子列表中的样本数量。如果某个类别的样本数少于4个,就将当前图像添加到该类别对应的子列表中if len(samples[label]) < 4:samples[label].append(image)plt.figure(figsize=(20, 9)) #设置一个适合显示这些图像的图形大小。
for index in range(40):label = index % 10 #01234567890123...每行依次显示每个label的图像image = samples[label][index // 10] #0000000000111...每列依次显示单个label的图像# 图片格式由 CHW 转换到 HWC,为了可视化image = image.permute(1, 2, 0)# 将类索引转换为类名label = dataset["test"].classes[label]# 画图 4 * 10plt.subplot(4, 10, index + 1)plt.imshow(image)plt.title(label)plt.axis("off")
plt.show()

在这里插入图片描述

为了训练神经网络,我们需要批量输入数据。我们创建批处理大小为512的数据加载器 (data loaders):

dataflow = {}
for split in ['train', 'test']:dataflow[split] = DataLoader(dataset[split],batch_size=512,shuffle=(split == 'train'),num_workers=0,pin_memory=True,)for inputs, targets in dataflow["train"]:print("[inputs] dtype: {}, shape: {}".format(inputs.dtype, inputs.shape))print("[targets] dtype: {}, shape: {}".format(targets.dtype, targets.shape))break

在这里插入图片描述

 

模型搭建

我们将使用VGG-11的一个变体(具有更少的下样本和更小的分类器)作为我们的模型。

class VGG(nn.Module):ARCH = [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'] #网络架构:8个卷积层 + 4个池化层def __init__(self) -> None:super().__init__()layers = []counts = defaultdict(int)def add(name: str, layer: nn.Module) -> None: #用于添加层到layers列表layers.append((f"{name}{counts[name]}", layer))counts[name] += 1in_channels = 3for x in self.ARCH:if x != 'M':# conv-bn-reluadd("conv", nn.Conv2d(in_channels, x, 3, padding=1, bias=False))add("bn", nn.BatchNorm2d(x))add("relu", nn.ReLU(True))in_channels = xelse:# maxpooladd("pool", nn.MaxPool2d(2))self.backbone = nn.Sequential(OrderedDict(layers)) #将层列表转换为有序字典,并通过nn.Sequential创建一个顺序容器self.backbone,这使得输入数据可以顺序通过定义的所有层。self.classifier = nn.Linear(512, 10) #线性层(nn.Linear),用于将卷积网络提取的特征映射到类别标签上。def forward(self, x: torch.Tensor) -> torch.Tensor:# backbone: [N, 3, 32, 32] => [N, 512, 2, 2]x = self.backbone(x)  #特征提取# avgpool: [N, 512, 2, 2] => [N, 512]x = x.mean([2, 3])  #对特征图进行全局平均池化# classifier: [N, 512] => [N, 10]x = self.classifier(x) #通过分类器得到最终的类别预测return x model = VGG().cuda() #创建了VGG类的一个实例,并通过.cuda()方法将模型的所有参数和缓冲区移动到GPU上,以利用GPU加速计算。(假设你的环境支持CUDA)
#详细看一下模型结构
print(model.backbone)

在这里插入图片描述

#详细分析一下模型的参数#计算模型大小:
num_params = 0
for param in model.parameters():if param.requires_grad:num_params += param.numel()
print("#Params:", num_params)
#Params: 9228362#计算模型的计算开销,由multiply–accumulate operations (MACs,乘法累加操作)来衡量
num_macs = profile_macs(model, torch.zeros(1, 3, 32, 32).cuda())
print("#MACs:", num_macs)
#MACs: 606164480#该模型有9.2M个参数,需要606M次乘法和累加操作进行一次推理。

 

模型训练

优化器

optimizer = SGD(model.parameters(),lr=0.4,momentum=0.9,weight_decay=5e-4, #权重衰减(L2正则化),有助于防止模型过拟合
)num_epochs = 20
steps_per_epoch = len(dataflow["train"])# 分段线性学习率调度。学习率随着训练步数的增加先线性增大,达到一定值后再线性减小。
lr_lambda = lambda step: np.interp([step / steps_per_epoch],[0, num_epochs * 0.3, num_epochs],[0, 1, 0]
)[0]# Visualize the learning rate schedule
steps = np.arange(steps_per_epoch * num_epochs)
plt.plot(steps, [lr_lambda(step) * 0.4 for step in steps])
plt.xlabel("Number of Steps")
plt.ylabel("Learning Rate")
plt.grid("on")
plt.show()scheduler = LambdaLR(optimizer, lr_lambda) #应用学习率调度器

在这里插入图片描述
 

训练和评估函数

def train(model: nn.Module,dataflow: DataLoader,criterion: nn.Module,optimizer: Optimizer,scheduler: LambdaLR,
) -> None:model.train() #告诉PyTorch模型现在处于训练模式,这对于某些特定层如Dropout和BatchNorm是必要的,因为它们在训练和评估时的行为不同for inputs, targets in tqdm(dataflow, desc='train', leave=False):# Move the data from CPU to GPUinputs = inputs.cuda()targets = targets.cuda()# 在每次的参数更新前,需要将梯度归零,防止梯度在反向传播时累积optimizer.zero_grad()# Forward inferenceoutputs = model(inputs)loss = criterion(outputs, targets)# Backward propagationloss.backward()# 根据计算出的梯度更新模型参数optimizer.step()# 根据学习率调度器更新学习率scheduler.step()@torch.inference_mode() # 禁用梯度计算
def evaluate(model: nn.Module,dataflow: DataLoader
) -> float:model.eval() ##告诉PyTorch模型现在处于评估模式,关闭Dropout和BatchNorm的特定训练行为。num_samples = 0num_correct = 0for inputs, targets in tqdm(dataflow, desc="eval", leave=False):# Move the data from CPU to GPUinputs = inputs.cuda()targets = targets.cuda()# Inferenceoutputs = model(inputs)# 将模型输出(通常是逻辑值或概率)转换为类别索引outputs = outputs.argmax(dim=1)# Update metricsnum_samples += targets.size(0)num_correct += (outputs == targets).sum()return (num_correct / num_samples * 100).item()

 

开始训练

for epoch_num in tqdm(range(1, num_epochs + 1)):train(model, dataflow["train"], criterion, optimizer, scheduler)metric = evaluate(model, dataflow["test"])print(f"epoch {epoch_num}:", metric)

在这里插入图片描述

可视化

可视化模型的预测,看看模型的真实表现。

plt.figure(figsize=(20, 10))
for index in range(40):image, label = dataset["test"][index+66]# Model inferencemodel.eval()with torch.inference_mode():pred = model(image.unsqueeze(dim=0).cuda())pred = pred.argmax(dim=1)# Convert from CHW to HWC for visualizationimage = image.permute(1, 2, 0)# Convert from class indices to class namespred = dataset["test"].classes[pred]label = dataset["test"].classes[label]# Visualize the imageplt.subplot(4, 10, index + 1)plt.imshow(image)plt.title(f"pred: {pred}" + "\n" + f"label: {label}")plt.axis("off")
plt.show()

在这里插入图片描述

Perfect!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/274606.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux学习:基础开发工具的使用(1)

目录 1. Linux软件包管理器&#xff1a;yum工具1.1 yum是什么&#xff08;软件商城&#xff09;1.2 yum的使用1.3 yum的背景生态 2. 项目开发与集成开发环境3. vim编辑器3.1 vim编辑器的常见模式与模式切换3.3 vim编辑器的使用3.3.1 命令模式下的常见命令&#xff1a;3.3.2 vim…

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的石头剪刀布手势识别系统详解(深度学习模型+UI界面代码+训练数据集)

摘要&#xff1a;本篇博客深入探讨了使用深度学习技术开发石头剪刀布手势识别系统的过程&#xff0c;并分享了完整代码。该系统利用先进的YOLOv8、YOLOv7、YOLOv6、YOLOv5算法&#xff0c;并对这几个版本进行性能对比&#xff0c;如mAP、F1 Score等关键指标。文章详细阐述了YOL…

vscode使用npm命令无反应,而终端可以的解决办法

如若你遇到这种情况 使用命令 get-command npm 去下面这个路径把它删掉就可以了

HarmonyOS的功能及场景应用

一、基本介绍 鸿蒙HarmonyOS主要应用的设备包括智慧屏、平板、手表、智能音箱、IoT设备等。具体来说&#xff0c;鸿蒙系统是一款面向全场景(移动办公、运动健康、社交通信、媒体娱乐等)的分布式操作系统&#xff0c;能够支持手机、平板、智能穿戴、智慧屏、车机等多种终端设备…

《量子计算:下一个大风口,还是一个热炒概念?》

引言 量子计算,作为一项颠覆性的技术,一直以来备受关注。它被认为是未来计算领域的一次革命,可能改变我们对计算能力和数据处理的理解。然而,随着技术的不断进步和商业应用的探索,人们开始思考,量子计算到底是一个即将到来的大风口,还是一个被过度炒作的概念? 量子计…

工业物联网平台在水务环保、暖通制冷、电力能源等行业的应用

随着科技的不断发展&#xff0c;工业物联网平台作为连接物理世界与数字世界的桥梁&#xff0c;正逐渐成为推动各行业智能化转型的关键力量。在水务环保、暖通制冷、电力能源等行业&#xff0c;工业物联网平台的应用尤为广泛&#xff0c;对于提升运营效率、降低能耗、优化管理等…

16. C++标准库

C标准库兼容C语言标准函数库&#xff0c;可以在C标准库中直接使用C语言标准函数库文件&#xff0c;同时C标准库增加了自己的源代码文件&#xff0c;新增文件使用C编写&#xff0c;多数代码放在std命名空间中&#xff0c;所以连接C标准库文件后还需要 using namespace std;。 【…

【RabbitMQ】RabbitMQ的交换机

交换机类型 在上文中&#xff0c;都没有交换机&#xff0c;生产者直接发送消息到队列。而一旦引入交换机&#xff0c;消息发送的模式会有很大变化&#xff1a;可以看到&#xff0c;在订阅模型中&#xff0c;多了一个exchange角色&#xff0c;而且过程略有变化&#xff1a; Pub…

【wps】wps与office办公函数储备使用(结合了使用案例 持续更新)

【wps】wps与office办公函数储备使用(结合了使用案例 持续更新) 1、TODAY函数 返回当前电脑系统显示的日期 TODAY函数&#xff1a;表示返回当前电脑系统显示的日期。 公式用法&#xff1a;TODAY() 2、NOW函数 返回当前电脑系统显示的日期和时间 NOW函数&#xff1a;表示返…

Day29:安全开发-JS应用DOM树加密编码库断点调试逆向分析元素属性操作

目录 JS原生开发-DOM树-用户交互 JS导入库开发-编码加密-逆向调试 思维导图 JS知识点&#xff1a; 功能&#xff1a;登录验证&#xff0c;文件操作&#xff0c;SQL操作&#xff0c;云应用接入&#xff0c;框架开发&#xff0c;打包器使用等 技术&#xff1a;原生开发&#x…

GaussDB(DWS)运维利刃:TopSQL工具解析

在生产环境中&#xff0c;难免会面临查询语句出现异常中断、阻塞时间长等突发问题&#xff0c;如果没能及时记录信息&#xff0c;事后就需要投入更多的人力及时间成本进行问题的定位和解决&#xff0c;有时还无法定位到错误出现的地方。在本期《GaussDB(DWS)运维利刃&#xff1…

在 Python 中从键盘读取用户输入

文章目录 如何在 Python 中从键盘读取用户输入input 函数使用input读取键盘输入使用input读取特定类型的数据处理错误从用户输入中读取多个值 getpass 模块使用 PyInputPlus 自动执行用户输入评估总结 如何在 Python 中从键盘读取用户输入 原文《How to Read User Input From t…

小家电显示驱动芯片SM1616特点与相关型号推荐

电饭煲、电磁炉、空调和机顶盒等等小家电通常需要使用显示驱动芯片来控制和驱动显示屏。这些显示驱动芯片的主要功能是将处理器的信号转换成显示屏能够理解的信号&#xff0c;从而显示出相应的文字和图像。 具体来说&#xff0c;电饭煲、电磁炉、空调等家等小家电通常会有一个或…

四桥臂三相逆变器动态电压恢复器(DVR)MATLAB仿真

微❤关注“电气仔推送”获得资料&#xff08;专享优惠&#xff09; 简介 四桥臂三相逆变器 电路 的一般形式如图 1&#xff0c;为 便于分析 &#xff0c;将其等效成图所示的电路 。以直流母线电压Ud的 1&#xff0f;2处为参考点 &#xff0c;逆变器三相和零线相 输 出可等效成…

深度学习:如何面对隐私和安全方面的挑战

深度学习技术的广泛应用推动了人工智能的快速发展&#xff0c;但同时也引发了关于隐私和安全的深层次担忧。如何在保护用户隐私的同时实现高效的模型训练和推理&#xff0c;是深度学习领域亟待解决的问题。差分隐私、联邦学习等技术的出现&#xff0c;为这一挑战提供了可能的解…

pytorch的梯度图与autograd.grad和二阶求导

前向与反向 这里我们从 一次计算 开始比如 zf(x,y) 讨论若我们把任意对于tensor的计算都看为函数&#xff08;如将 a*b&#xff08;数值&#xff09; 看为 mul(a,b)&#xff09;&#xff0c;那么都可以将其看为2个过程&#xff1a;forward-前向&#xff0c;backward-反向在pyto…

3.7号freeRtoS

1. 串口通信 配置串口为异步通信 设置波特率&#xff0c;数据位&#xff0c;校验位&#xff0c;停止位&#xff0c;数据的方向 同步通信 在同步通信中&#xff0c;数据的传输是在发送端和接收端之间通过一个共享的时钟信号进行同步的。这意味着发送端和接收端的时钟需要保持…

进电子厂了,感触颇多...

作者&#xff1a;三哥 个人网站&#xff1a;https://j3code.cn 本文已收录到语雀&#xff1a;https://www.yuque.com/j3code/me-public-note/lpgzm6y2nv9iw8ec 是的&#xff0c;真进电子厂了&#xff0c;但主人公不是我。 虽然我不是主人公&#xff0c;但是我经历的过程是和主…

Qt 实现诈金花的牌面值分析工具

诈金花是很多男人最爱的卡牌游戏 , 每当你拿到三张牌的时候, 生活重新充满了期待和鸟语花香. 那么我们如果判断手中的牌在所有可能出现的牌中占据的百分比位置呢. 这是最终效果: 这是更多的结果: 在此做些简单的说明: 炸弹(有些地方叫豹子) > 同花顺 > 同花 > 顺…

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的常见车型识别系统(Python+PySide6界面+训练代码)

摘要&#xff1a;本文深入探讨了如何应用深度学习技术开发一个先进的常见车型识别系统。该系统核心采用最新的YOLOv8算法&#xff0c;并与早期的YOLOv7、YOLOv6、YOLOv5等版本进行性能比较&#xff0c;主要评估指标包括mAP和F1 Score等。详细解析了YOLOv8的工作机制&#xff0c…