孙怡带你深度学习(2)--PyTorch框架认识

文章目录

  • PyTorch框架认识
    • 1. Tensor张量
      • 定义与特性
      • 创建方式
    • 2. 下载数据集
      • 下载测试
      • 展现下载内容
    • 3. 创建DataLoader(数据加载器)
    • 4. 选择处理器
    • 5. 神经网络模型
      • 构建模型
    • 6. 训练数据
      • 训练集数据
      • 测试集数据
    • 7. 提高模型学习率
  • 总结

PyTorch框架认识

PyTorch是一个由Facebook人工智能研究院(FAIR)在2016年发布的开源深度学习框架,专为GPU加速的深度神经网络(DNN)编程而设计。它以其简洁、灵活和符合Python风格的特点,在科研和工业生产中得到了广泛应用。

1. Tensor张量

在PyTorch中,张量(Tensor)是核心数据结构,它是一个多维数组,用于存储和变换数据。张量类似于Numpy中的数组,但具有更丰富的功能和灵活性,特别是在支持GPU加速方面。

定义与特性

  • 多维数组:张量可以看作是一个n维数组,其中n可以是任意正整数。它可以是标量(零维数组)、向量(一维数组)、矩阵(二维数组)或具有更高维度的数组。
  • 数据类型统一:张量中的元素具有相同的数据类型,这有助于在GPU上进行高效的并行计算。
  • 支持GPU加速:PyTorch中的张量可以存储在CPU或GPU上,通过将张量转移到GPU上,可以利用GPU的强大计算能力来加速深度学习模型的训练和推理过程。

创建方式

  • 直接使用torch.tensor():根据提供的Python列表或Numpy数组创建张量。
  • 下载数据集时:transform=ToTensor()直接将数据转化为Tensor张量类型。

2. 下载数据集

在PyTorch中,有许多封装了很多与图像相关的模型、数据集,那么如何获取数据集呢?

导入datasets模块

from torchvision import datasets #封装了很多与图像相关的模型,数据集

以datasets模块中的MNIST数据集为例,包含70000张手写数字图像:60000张用于训练,10000张用于测试。图像是灰度的,28*28像素,并且居中的,以减少预处理和加快运行。

下载测试

我们来下载MNIST数据集

from torchvision.transforms import ToTensor # 数据转换,张量,将其他类型数据转换为tensor张量
"""-----下载训练集数据集-----"""
training_data = datasets.MNIST(root="data",train=True,# 取训练集download=True,transform=ToTensor(),# 张量,图片是不能直接传入神经网络模型的
) # 对于pytorch库能够识别的数据,一般是tensor张量"""-----下载测试集数据集-----"""
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# numpy数组只能在CPU上运行,Tensor可以在GPU上运行,这在深度学习中可以显著提高计算速度

在这里插入图片描述

下载完成之后可在project栏查看。

展现下载内容

我们来查看部分图片(第59000张到第59009张):

"""-----展现手写字图片-----"""
# tensor -->numpy  矩阵类型数据
from matplotlib import pyplot as plt
figure = plt.figure() # 创建一个新的图形
for i in range(9):img,label = training_data[i+59000] #提取第59000张图片figure.add_subplot(3,3,i+1) #图像窗口中创建多个小窗口,小窗口用于显示图片plt.title(label)plt.axis("off")# 关闭当前轴的坐标轴plt.imshow(img.squeeze(),cmap="gray")a = img.squeeze()# squeeze()从张量img中去掉维度为1的。如果该维度不为1则张量不会改变
plt.show()

图片信息获取时,得到的张量数据类型是这样的:

在这里插入图片描述

我们通过squeeze()函数,去掉维度为1的。这样我们就可以得到图片的高宽大小,将它展现出来:

在这里插入图片描述

3. 创建DataLoader(数据加载器)

在PyTorch中,创建DataLoader的主要作用是将数据集(Dataset)加载到模型中,以便进行训练或推理。DataLoader通过封装数据集,提供了一个高效、灵活的方式来处理数据。

DataLoader通过batch_size参数将数据集自动划分为多个小批次(batch),每一批次的放入模型训练,减少内存的使用,提高训练速度。

import torch
from torch.utils.data import DataLoader
"""
创建数据DataLoader(数据加载器)
batch_size:将数据集分成多份,每一份为batch_size(指定数值)个数据。
优点:减少内存的使用,提高训练速度
"""
train_dataloder = DataLoader(training_data,batch_size=64)# 64张图片为一个包
test_datalodar = DataLoader(test_data,batch_size=64)
# 查看打包好的数据
for x,y in test_datalodar: #x是表示打包好的每一个数据包print(f"Shape of x [N, C, H, W]:{x.shape}")print(f"Shape of y:{y.shape} {y.dtype}")break
-----------------------
Shape of x [N, C, H, W]:torch.Size([64, 1, 28, 28])
Shape of y:torch.Size([64]) torch.int64

4. 选择处理器

我们知道,电脑中的处理器有CPU和GPU两种,CPU擅长执行复杂的指令和逻辑操作,而GPU则擅长处理大量并行计算任务。

所以,在可以的条件下,我们选择使用GPU处理器来学习深度学习,因为计算量比较大:

"""---判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU"""
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
----------------
Using cuda device

5. 神经网络模型

通过调用类的形式来使用神经网络,神经网络的模型:nn.module。

构建模型

我们在构建时,得明确神经网络模型的结构:输入层–隐藏层–输出层,而在每一个隐藏层进入下一层时,都会有一个激活函数计算,所以我们也按着这个架构层次定义函数:

class NeuralNetwork(nn.Module): #通过调用类的形式来使用神经网络,神经网络的模型:nn.moduledef __init__(self): # self类自己本身super().__init__() #继承的父类初始化self.flatten = nn.Flatten()# 输入层,展开一个对象flattenself.hidden1 = nn.Linear(28*28,256)# 隐藏层,第1个参数:有多少神经元传入进来;第二个参数,有多少数据传出去self.hidden2 = nn.Linear(256,128)self.hidden3 = nn.Linear(128,64)self.hidden4 = nn.Linear(64,32)self.out = nn.Linear(32,10)#输出层,输出必须与类别数量相同,输入必须是上一层的个数def forward(self,x): #前向传播(该名字不要轻易改),告诉它数据的流向x = self.flatten(x)x = self.hidden1(x)x = torch.sigmoid(x) #激活函数x = self.hidden2(x)x = torch.sigmoid(x)x = self.hidden3(x)x = torch.sigmoid(x)x = self.hidden4(x)x = torch.sigmoid(x)x = self.out(x)return x
model = NeuralNetwork().to(device) #将刚刚创建的模型传入到GPU
print(model)
-----------------------
NeuralNetwork((flatten): Flatten(start_dim=1, end_dim=-1)(hidden1): Linear(in_features=784, out_features=256, bias=True)(hidden2): Linear(in_features=256, out_features=128, bias=True)(hidden3): Linear(in_features=128, out_features=64, bias=True)(hidden4): Linear(in_features=64, out_features=32, bias=True)(out): Linear(in_features=32, out_features=10, bias=True)
)

6. 训练数据

训练数据时,需要注意的参数:

  • optimizer优化器

在PyTorch中,创建Optimizer的主要作用是管理并更新模型中可学习参数(即权重和偏置)的值,以便最小化某个损失函数(loss function)。

  1. 梯度清零:在每次迭代开始时,Optimizer会调用**.zero_grad()**方法来清除之前累积的梯度,这是因为在PyTorch中,梯度是累加的,如果不清零,则下一次的梯度计算会包含前一次的梯度,导致错误的更新。
  2. 梯度计算:在模型进行前向传播(forward pass)和损失计算之后,Optimizer并不直接参与梯度的计算。梯度的计算是通过调用损失函数的**.backward()**方法完成的,该方法会计算损失函数关于模型中所有可学习参数的梯度,并将这些梯度存储在相应的参数对象中。
  3. 参数更新:在梯度计算完成后,Optimizer会调用**.step()**方法来根据计算得到的梯度以及选择的优化算法(如SGD、Adam等)来更新模型的参数。这一步骤是优化过程中最关键的部分,它决定了模型学习的方向和速度。
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
  • loss_fn损失函数

在PyTorch中,**nn.CrossEntropyLoss()**是一个常用的损失函数,它结合了 nn.LogSoftmax()nn.NLLLoss()(负对数似然损失)在一个单独的类中。

loss_fn = nn.CrossEntropyLoss()

训练集数据

from torch import nn #导入神经网络模块
def train(dataloader,model,loss_fn,optimizer):model.train()# 设置模型为训练模式batch_size_num =1# 迭代次数 for x,y in dataloader:x,y = x.to(device),y.to(device)  # 将数据和标签发送到指定设备  pred = model.forward(x)  # 前向传播  loss = loss_fn(pred,y)  # 计算损失  optimizer.zero_grad()  # 清除之前的梯度  loss.backward()  # 反向传播  optimizer.step()  # 更新模型参数  loss_value = loss.item()  # 获取损失值if batch_size_num %200 == 0:  # 每200次迭代打印一次损失  print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1
------------------------
loss:1.039446 [number:200]
loss:0.754774 [number:400]
loss:0.553383 [number:600]
loss:0.573400 [number:800]

测试集数据

def test(dataloader,model,loss_fn):size = len(dataloader.dataset) # 获取测试集的总大小。num_batches = len(dataloader) # 计算数据加载器中的批次数量。model.eval() # 将模型设置为评估模式。test_loss,correct = 0,0 # 初始化总损失和正确预测的数量。with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batchescorrect /= sizecorrect = round(correct, 4)print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")---------------------
Test result: Accuracy:89.96%,Avg loss:0.36642977581092506

我们可以看到,这个模型的正确率不是特别的高,那么接下来我们来提高模型的学习率。

7. 提高模型学习率

遍历了指定的训练周期(epochs)数,并在每个周期中调用 train 函数来训练模型。

"""-----调整学习率-----"""
epochs = 10
for t in range(epochs):print(f"Epoch {t+1} \n-------------------------")train(train_dataloder,model,loss_fn,optimizer)
print("Done!")
test(test_datalodar,model,loss_fn)
---------------
仅展示优化后的结果:
Test result: Accuracy:97.33000000000001%,Avg loss:0.10455594740913303

总结

本篇介绍了:

  1. PyTorch的框架
  2. 数据类型张量,数据集的获取
  3. 如何构建对应神经网络的模型
  4. 如何优化算法:一、修改optimizer优化器的算法;二、遍历合适的训练周期(epochs)数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/426259.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

UE5安卓项目打包安装

Android studio安装 参考:https://docs.unrealengine.com/5.2/zh-CN/how-to-set-up-android-sdk-and-ndk-for-your-unreal-engine-development-environment/ 打开android studio的官网:Download Android Studio & App Tools - Android Developers …

深度学习-生成式检索-论文速读-2024-09-14

深度学习-生成式检索-论文速读-2024-09-14 前言: 生成式检索(Generative Retrieval, GR)是一种结合了生成模型和检索系统的人工智能技术方法。这种方法在处理信息检索任务时,不仅依赖于已有数据的检索,还能生成新的、…

解锁SQL无限可能 | 基于SQL实现的一种时序数据的波峰个数检测算法

目录 0 算法原理 1 数据准备 2 问题分析 3 小结 数字化建设通关指南专栏原价99,现在活动价39.9,按照阶梯式增长,直到恢复原价 0 算法原理 波峰识别算法 序列数据是按照时间进行采集,其中400个点一个周期,一条数据…

【零散技术】Odoo17通过Controller下载PDF

序言:时间是我们最宝贵的财富,珍惜手上的每个时分 Odoo作为一款开源ERP,拥有极佳的拓展性,Odoo的Controller框架也让它具备了作为微信小程序后端的能力,那么就存在 需要通过小程序来下载PDF的业务情况。 目录 1.功能代码 1.1 manifest 设置 …

Tensorflow—第五讲卷积神经网络

本讲概述 卷积实际上就是特征提取。本讲我们先了解学习卷积神经网络基础知识,再一步步地学习搭建卷积神经网络,最后会运用卷积神经网络对cifar10 数据集分类。在本讲的最后附上几个经典卷积神经网络:LeNet、AlexNet、VGGNet、InceptionNet和…

在Linux中安装FFmpeg

在Linux中安装FFmpeg有两种方法。 安装FFmpeg(方法一) 第一步,下载FFmpeg。 登录地址:John Van Sickle - FFmpeg Static Builds下载安装包ffmpeg-git-amd64-static.tar.xz。然后使用WinSCP将安装包上传到文件夹/usr/local/src中…

vue2基础系列教程之插槽slot你不得不知道的知识点及面试高频问题

vue2中对插槽的介绍,花了大量的章节篇幅,可想而知,它在框架中的重要性。 slot及slot-scope自 2.6.0 起被废弃。新推荐的语法请查阅v-slot,就语法我们这里就一笔带过,主要学习新的语法 你不能不知道的slot知识点 插槽的作用&#…

C++知识要点总结笔记

文章目录 前言一、c基础1.指针和引用指针和引用的区别函数指针 2.数据类型整型 short int long 和 long long无符号类型强制类型转换怎样判断两个浮点数是否相等? 3.关键字conststaticconst和static的区别define 和 typedef 的区别define 和 inline 的区别define和c…

PostgreSQL(PG)(二十二)

🌻🌻 目录 🌻🌻 一、PostgreSQL 简介1.1、PG 的历史1.2、PG的社区1.2.1 纯社区1.2.2 完善的组织结构1.2.3 开源许可独特性 1.3 、PostgreSQL与MySQL的比较 二、PostgresQL的下载安装2.1、Windows上安装 PostgreSQL2.2、远程 连接 …

Pikachu靶场之XSS

先来点鸡汤,少就是多,慢就是快。 环境搭建 攻击机kali 192.168.146.140 靶机win7 192.168.146.161 下载zip,pikachu - GitCode 把下载好的pikachu-master,拖进win7,用phpstudy打开网站根目录,.....再用…

Unity3D下如何播放RTSP流?

技术背景 在Unity3D中直接播放RTSP(Real Time Streaming Protocol)流并不直接支持,因为Unity的内置多媒体组件(如AudioSource和VideoPlayer)主要设计用于处理本地文件或HTTP流,而不直接支持RTSP。所以&…

【话题讨论】AI时代程序员核心力:技术深耕,跨界学习,软硬兼备

目录 引言 一、AI辅助编程对程序员工作的影响 1.1 AI工具如何提升工作效率 1.2 AI工具的风险 1.3 应对策略 二、程序员应重点发展的核心能力 2.1 核心竞争力 2.2 企业和教育机构的调整 三、人机协作模式下的职业发展规划 3.1 持续学习的重要性 3.2 选择适合自己的…

Python3网络爬虫开发实战(17)爬虫的管理和部署(第一版)

文章目录 一、 Scrapyd 分布式部署1.1 了解 Scrapyd1.2 准备工作1.3 访问 Scrapyd1.4 Scrapyd 的功能1.5 ScrapydAPI 的使用 二、Scrapyd-Client 的使用2.1 准备工作2.2 Scrapyd-Client 的功能2.3 Scrapyd-Client 部署 三、Scrapyd 对接 Docker3.1 准备工作3.2 对接 Docker 四、…

Java Web服务运行一段时间后出现cpu升高导致的性能下降问题排查

背景 有个web服务,运行一段时间后,出现cpu逐渐占用高,服务处理请求整体性能下降问题。 异常情况时, 同时jvm的cpu上涨 最终表现为,处理内部逻辑执行耗时变高。 排查原因 原来服务的jvm启动参数带了 -XX:-TieredCom…

Java项目实战II基于Java+Spring Boot+MySQL的校园社团信息管理系统(源码+数据库+文档)

目录 一、前言 二、技术介绍 三、系统实现 四、论文参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 在当今高校…

马匹行为识别系统源码分享

马匹行为识别检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Vis…

计算机毕业设计 在线新闻聚合平台的设计与实现 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点…

【计算机毕设-大数据方向】基于Hadoop的智能交通数据分析可视化系统的设计与实现

💗博主介绍:✌全平台粉丝5W,高级大厂开发程序员😃,博客之星、掘金/知乎/华为云/阿里云等平台优质作者。 【源码获取】关注并且私信我 【联系方式】👇👇👇最下边👇👇&…

Redis的存储原理和数据模型

一、Redis是单线程还是多线程呢? 我们通过跑redis的代码,查看运行的程序可以得知,Redis本身其实是个多线程,其中包括redis-server,bio_close_file,bio_aof_fsync,bio_lazy_free,io_t…

Mina protocol - 体验教程

Mina protocol - 体验教程 一、零知识证明( Zero Knowledge Proof )1、零知识证明(ZKP)的基本流程工作流程: 2、zkApp 的优势:3、zkApp 每个方法的编译过程: 二、搭建第一个zkapp先决条件1、下载或者更新 zkApp CLI​2…