《深度学习》——pytorch框架及项目

文章目录

  • pytorch
    • 特点
    • 基本概念
  • 项目
    • 项目实现
      • 导入所需库
      • 下载训练数据和测试数据
      • 对训练和测试样本进行分批次
      • 展示手写图片
      • 判断pytorch是否支持GPU
      • 定义神经网络模型
      • 定义训练函数
      • 定义测试函数
      • 创建交叉熵损失函数和优化器
      • 通过多轮训练降低损失值得到最终结果
      • 注意

pytorch

PyTorch 是一个开源的深度学习框架,由 Facebook 的人工智能研究团队开发。它在学术界和工业界都得到了广泛的应用,下面从多个方面详细介绍:

特点

  • 动态计算图:与 TensorFlow 的静态计算图不同,PyTorch 使用动态计算图。这意味着在运行时可以动态地改变计算图的结构,使得代码的编写和调试更加直观和灵活。例如,在训练循环中可以根据不同的条件来改变计算流程。
  • Python 优先:PyTorch 深度集成 Python,代码风格简洁易懂,易于上手。开发者可以利用 Python 丰富的库和工具进行数据处理、可视化等操作。
  • 强大的 GPU 支持:PyTorch 能够充分利用 NVIDIA GPU 的并行计算能力,通过简单的代码就可以将张量和模型转移到 GPU 上进行加速计算,大大提高了训练和推理的速度。
  • 丰富的工具和库:提供了许多高级工具和库,如 Torchvision(用于计算机视觉任务)、Torchaudio(用于音频处理任务)等,方便开发者快速搭建和训练模型。

基本概念

  • 一、张量(Tensor):类似于 NumPy 的多维数组,但可以在 GPU 上运行以加速计算。例如,创建一个简单的张量

    import torch# 创建一个2x3的随机张量
    x = torch.rand(2, 3)
    print(x)
    

    在这里插入图片描述

  • 二、自动求导(Autograd):PyTorch 的自动求导机制可以自动计算张量的梯度,这对于训练神经网络非常重要。在定义张量时,只需要设置requires_grad=True,PyTorch 就会跟踪所有与之相关的操作,并在需要时计算梯度。

    import torch
    # 创建一个需要计算梯度的张量
    x = torch.tensor([2.0], requires_grad=True)
    y = x**2
    # 计算梯度
    y.backward()
    print(x.grad)  # 输出导数 2x,即 4
    
  • 三、模块(Module):torch.nn.Module是所有神经网络模块的基类。通过继承Module类,可以方便地定义自己的神经网络模型。

import torch
import torch.nn as nn# 定义一个简单的全连接神经网络
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc = nn.Linear(10, 1)def forward(self, x):return self.fc(x)# 创建模型实例
model = SimpleNet()
  • 四、优化器(Optimizer)
    PyTorch 提供了多种优化器,如 SGD、Adam、RMSprop 等,用于更新模型的参数。优化器根据计算得到的梯度来调整模型的参数,以最小化损失函数。

项目

下面我们使用BP神经网络来实现手写数字识别项目,此项目数据集来自MNIST 数据集由美国国家标准与技术研究所(NIST)整理而成,包含手写数字的图像,主要用于数字识别的训练和测试。该数据集被分为两部分:训练集和测试集。训练集包含 60,000 张图像,用于模型的学习和训练;测试集包含 10,000 张图像,用于评估训练好的模型在未见过的数据上的性能。

  • 图像格式:数据集中的图像是灰度图像,即每个像素只有一个值表示其亮度,取值范围通常为 0(黑色)到 255(白色)。
  • 图像尺寸:每张图像的尺寸为 28x28 像素,总共有 784 个像素点。
  • 标签信息:每个图像都有一个对应的标签,标签是 0 到 9 之间的整数,表示图像中手写数字的值

项目实现

导入所需库

import torch
from torch import nn #导入神经网络模块
from torch.utils.data import DataLoader # 数据包管理工具,打包数据
from torchvision import datasets # 封装了很对与图像相关的模型,数据集
from torchvision.transforms import ToTensor # 数据转换,张量,将其他类型的数据转换成tensor张量

下载训练数据和测试数据

'''下载训练数据集(包含训练集图片+标签)'''
training_data = datasets.MNIST( # 跳转到函数的内部源代码,pycharm 按下ctrl+鼠标点击root='data', # 表示下载的手写数字 到哪个路径。60000train=True, # 读取下载后的数据中的数据集download=True, # 如果你之前已经下载过了,就不用再下载了transform=ToTensor(), # 张量,图片是不能直接传入神经网络模型# 对于pytorch库能够识别的数据一般是tensor张量
)'''下载测试数据集(包含训练图片+标签)'''
test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor(),# Tensor是在深度学习中提出并广泛应用的数据类型,它与深度学习框架(如pytorch,TensorFlow)
)# numpy数组只能在cpu上运行。Tensor可以在GPU上运行,这在深度学习应用中可以显著提高计算速度。
print(len(training_data))
print(len(test_data))

在这里插入图片描述

训练样本和测试样本的数量
在这里插入图片描述

对训练和测试样本进行分批次

# 创建训练数据的 DataLoader 对象
# DataLoader 是 PyTorch 中用于批量加载数据的实用工具类,它可以帮助我们更高效地处理大规模数据集。
train_dataloader = DataLoader(training_data, batch_size=64)  # 建议用2的指数当作一个包的数量
test_dataloader = DataLoader(test_data, batch_size=64)

展示手写图片

'''展示手写体图片,把训练数据集中的59000张图片展示一下'''from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img,label = training_data[i+59000] # 提取第59000张图片figure.add_subplot(3,3,i+1) # 图像窗口中创建多个小窗口,小窗口用于显示图片plt.title(label)plt.axis('off') # plt.show(I) # 显示矢量plt.imshow(img.squeeze(),cmap='gray') # plt.imshow()将numpy数组data中的数据显示为图像,并在图形窗口显示a = img.squeeze() # img.squeeze()从张量img中去掉维度为1的。如果该维度的大小不为1则张量不会改变。
plt.show()

在这里插入图片描述

判断pytorch是否支持GPU

'''判断是否支持GPU'''
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')

如果支持GPU输出cuda,map系统则输出map,使用CPU则输出CPU
在这里插入图片描述

定义神经网络模型

'''定义神经网络 类的继承这种方式'''
class NeuralNetwork(nn.Module): # 通过调用类的形式来使用神经网络,神经网络模型nn.moduledef __init__(self): # self类自己本身super().__init__() # 继承的父类初始化self.flatten = nn.Flatten() # 展开,创建一个展开对象flattenself.hidden1 = nn.Linear(28*28,128)self.hidden2 = nn.Linear(128,256)self.out = nn.Linear(256,10)def forward(self,x): # 向前传播,数据的流向x = self.flatten(x) # 图像展开x = self.hidden1(x)x = torch.sigmoid(x)x = self.hidden2(x)x = torch.sigmoid(x) # 激活函数x = self.out(x)return x
model = NeuralNetwork().to(device)
print(model)

定义训练函数

def train(dataloader,model,loss_fn,optimizer):model.train() # 告诉模型,要开始训练,模型中w进行随机化操作,已经更新w,在训练过程中w会被修改# pytorch提供两种方式来切换训练和测试的模式,分别是:model.train()和model.eval()# 一般用法:在训练之前写model.train(),在测试时写model.eval()batch_size_num = 1for x,y in dataloader: # 其中batch为每一个数据的编号x,y=x.to(device),y.to(device) # 将训练数据和标签传入gpupred = model.forward(x) # .forward可以被省略,父类中已经对次功能进行了设置。自动初始化w权值loss = loss_fn(pred,y) # 通过交叉熵损失函数计算损失值loss# Backpropagation 进来个batch的数据,计算一次梯度,更新一次网络optimizer.zero_grad() #梯度值清零loss.backward() # 反向传播计算每一个参数的梯度值woptimizer.step() # 根据梯度更新网络w参数loss_value = loss.item() # 从tensor数据中提取数据出来,tensor获取损失值if batch_size_num % 100 == 0:print(f'loss:{loss_value:7f}  [number:{batch_size_num}]' )batch_size_num += 1

定义测试函数

def test(dataloader,model,loss_fn):size = len(dataloader.dataset) # 10000num_batches = len(dataloader) # 打包的数据model.eval() # 测试,w就不能再更新test_loss,correct = 0,0with torch.no_grad(): # 一个上下文管理器,关闭梯度计算。当你确定不会调用Tensor.backward()的时候。这可以减少计算内存for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()#test_loss是会自动累加每一个批次的损失值correct  +=(pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)#dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值对应的索引号b = (pred.argmax(1) == y).type(torch.float)test_loss /=num_batches#能来衡量模型测试的好坏。correct /= size#平均的正确率print(f'Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}')

创建交叉熵损失函数和优化器

loss_fn = nn.CrossEntropyLoss()#创建交叉熵损失函数对象,因为手写字识别中一共有10个数字,输出会有10个结果
#一会改成adam优化器
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)#创建一个优化器,SGD为随机梯度下降算法
#params:要训练的参数,一般我们传入的都是model.parameters()
# #lr:learning_rate学习率,也就是步长。

通过多轮训练降低损失值得到最终结果

epochs = 20
for t in range(epochs):print(f'epoch{t+1}\n--------------------')train(train_dataloader,model, loss_fn, optimizer)
print('Done!')
test(test_dataloader,model, loss_fn)

在这里插入图片描述

可通过结果看出在使用ADam优化器步长为0.001时,训练20轮得到的正确率为97.5%,损失值为0.119。

注意

  • 可以通过修改优化器来提高准确率
  • 对于不同的神经网络要使用不同的激活函数,对于sigmoid函数来说隐藏层过多时会产生梯度消失,因为sigmoid函数的偏导在0~0.25之间随着反向传播的进行,梯度会不断累乘。即使初始梯度较大,但经过多层的累乘后,梯度值会迅速变小,趋近于 0,从而导致梯度消失。因此可用ReLU、tanh、P-ReLU、R-ReLU、maxout等来代替sigmoid函数。ReLU函数偏导为1,不会产生梯度消失问题。
  • 当梯度在传递过程中不断增大,变得非常大时,就会导致梯度爆炸现象。此时,模型参数会因为梯度值过大而发生大幅度的更新,使得模型无法收敛,甚至可能导致数值溢出,使得训练过程崩溃。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/15503.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【批量获取图片信息】批量获取图片尺寸、海拔、分辨率、GPS经纬度、面积、位深度、等图片属性里的详细信息,提取出来后导出表格,基于WPF的详细解决方案

摄影工作室通常会有大量的图片素材,在进行图片整理和分类时,需要知道每张图片的尺寸、分辨率、GPS 经纬度(如果拍摄时记录了)等信息,以便更好地管理图片资源,比如根据图片尺寸和分辨率决定哪些图片适合用于…

windows生成SSL的PFX格式证书

生成crt证书: 安装openssl winget install -e --id FireDaemon.OpenSSL 生成cert openssl req -x509 -newkey rsa:2048 -keyout private.key -out certificate.crt -days 365 -nodes -subj "/CN=localhost" 转换pfx openssl pkcs12 -export -out certificate.pfx…

UnityShader学习笔记——高级纹理

——内容源自唐老狮的shader课程 目录 1.立方体纹理 1.1.概念 1.2.用处 1.3.如何采样 1.4.优缺点 2.天空盒 2.1.概念 2.2.优点 2.3.设置 3.动态生成立方体纹理 3.1.原因 3.2.实现 3.3.代码运行中生成 4.反射 4.1.原理 4.2.补充知识 5.折射 5.1.原理 5.2.菲涅…

IBM服务器刀箱Blade安装Hyper-V Server 2019 操作系统

案例:刀箱某一blade,例如 blade 5 安装 Hyper-V Server 2019 操作系统(安装进硬盘) 刀箱USB插入安装系统U盘,登录192.168... IBM BlandeCenter Restart Blande 5,如果Restart 没反应,那就 Power Off Blade 然后再 Power On 重启后进入BIOS界面设置usb存储为开机启动项 …

C++20新特性

作者:billy 版权声明:著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处 前言 C20 是 C 标准中的一个重要版本,引入了许多新特性和改进,包括模块(Modules)、协程…

WPS如何接入DeepSeek(通过JS宏调用)

WPS如何接入DeepSeek 一、文本扩写二、校对三、翻译 本文介绍如何通过 WPS JS宏调用 DeepSeek 大模型,实现自动化文本扩写、校对和翻译等功能。 一、文本扩写 1、随便打开一个word文档,点击工具栏“工具”。 2、点击“开发工具”。 3、点击“查看代码”…

【SQL server】关于SQL server彻底的卸载删除。

1.未彻底卸载删除SQL Server会出现的问题 如果没有彻底删除之前的SQL server,就可能会出现这个 当要安装新的实例的时候因为之前安装过sql server没有删除干净而导致下图问题,说实例名已经存在。 2.首先要先关闭服务 “开始R”可以快速进入运行&#…

对话框补充以及事件处理机制 (2025.2.10)

作业 1> 将鼠标事件和键盘事件相关代码重新实现一遍 2> 将文本编辑器功能完善 主函数main.cpp #include "widget.h"#include <QApplication>int main(int argc, char *argv[]) {QApplication a(argc, argv);Widget w;w.show();return a.exec(); } 头…

企业级Mysql实战

Mysql企业级sql编写实战 1 一对多&#xff0c;列表展示最新记录字段1.1 场景1.2 需求1.3 实现1.3.1 表及数据准备1.3.2 Sql编写 2 区间统计&#xff08;if/case when&#xff09;2.1 场景2.2 需求2.3 实现2.2.1 表及数据准备2.3.2 sql编写 3 多类别分组统计&#xff08;竖表转横…

C语言基础第04天:数据的输出和输出

C语言基础:04天笔记 内容提要 回顾C语言数据的输入输出 回顾 运算符 算术运算符 结果:数值 - * / % (正) -(负) -- i和i 相同点:i自身都会增1 不同点:他们运算的最终结果是不同的. i先使用 ,后加1; i先计算,后使用 赋值运算符 结果:赋值后的变量的值 赋值顺序:由右…

DeepSeek训练成本与技术揭秘

引言&#xff1a;在当今人工智能蓬勃发展的时代&#xff0c;DeepSeek 宛如一颗耀眼的新星&#xff0c;突然闯入大众视野&#xff0c;引发了全球范围内的热烈讨论。从其惊人的低成本训练模式&#xff0c;到高性能的模型表现&#xff0c;无一不让业界为之侧目。它打破了传统认知&…

数组与指针1

1. 数组名的理解 1.1 数组名是数组首元素的地址 int arr[10] {1,2,3,4,5,6,7,8,9,10};int *p &arr[0]; 这里我们使用 &arr[0] 的方式拿到了数组第一个元素的地址&#xff0c;但是其实数组名本来就是地址&#xff0c;而且是数组首元素的地址。如下&#xff1a; 1.2…

Axure原型图怎么通过链接共享

一、进入Axure 二、点击共享 三、弹出下面弹框&#xff0c;点击发布就可以了 发布成功后&#xff0c;会展示链接&#xff0c;复制即可共享给他人 四、发布失败可能的原因 Axure未更新&#xff0c;首页菜单栏点击帮助选择Axure更新&#xff0c;完成更新重复以上步骤即可

软件模拟I2C案例(寄存器实现)

引言 在经过前面对I2C基础知识的理解&#xff0c;对支持I2C通讯的EEPROM芯片M24C02的简单介绍以及涉及到的时序操作做了整理。接下来&#xff0c;我们就正式进入该案例的实现环节了。本次案例是基于寄存器开发方式通过软件模拟I2C通讯协议&#xff0c;然后去实现相关的需求。 阅…

脚手架开发【实战教程】prompts + fs-extra

创建项目 新建文件夹 mycli_demo 在文件夹 mycli_demo 内新建文件 package.json {"name": "mycli_demo","version": "1.0.0","bin": {"mycli": "index.js"},"author": "","l…

【大模型】DeepSeek-V3技术报告总结

系列综述&#xff1a; &#x1f49e;目的&#xff1a;本系列是个人整理为了学习DeepSeek相关知识的&#xff0c;整理期间苛求每个知识点&#xff0c;平衡理解简易度与深入程度。 &#x1f970;来源&#xff1a;材料主要源于DeepSeek官方技术报告进行的&#xff0c;每个知识点的…

只需三步!5分钟本地部署deep seek——MAC环境

MAC本地部署deep seek 第一步:下载Ollama第二步:下载deepseek-r1模型第三步&#xff1a;安装谷歌浏览器插件 第一步:下载Ollama 打开此网址&#xff1a;https://ollama.com/&#xff0c;点击下载即可&#xff0c;如果网络比较慢可使用文末百度网盘链接 注&#xff1a;Ollama是…

力扣hot100刷题第一天

哈希 1. 两数之和 题目 给定一个整数数组 nums 和一个整数目标值 target&#xff0c;请你在该数组中找出 和为目标值 target 的那 两个 整数&#xff0c;并返回它们的数组下标。 你可以假设每种输入只会对应一个答案&#xff0c;并且你不能使用两次相同的元素。你可以按任意…

Linux(CentOS)安装 Nginx

CentOS版本&#xff1a;CentOS 7 Nginx版本&#xff1a;1.24.0 两种安装方式&#xff1a; 一、通过 yum 安装&#xff0c;最简单&#xff0c;一键安装&#xff0c;全程无忧。 二、通过编译源码包安装&#xff0c;需具备配置相关操作。 最后附&#xff1a;设置 Nginx 服务开…

项目6:基于大数据校园一卡通数据分析和可视化

1、项目简介 本项目是基于大数据的清华校园卡数据分析系统&#xff0c;通过Hadoop&#xff0c;spark等技术处理校园卡交易、卡号和商户信息数据。系统实现消费类别、男女消费差异、学院消费排行和年级对比等分析&#xff0c;并通过Web后端和可视化前端展示结果。项目运行便捷&…