什么是Pytorch?

在这里插入图片描述

当谈及深度学习框架时,PyTorch 是当今备受欢迎的选择之一。作为一个开源的机器学习库,PyTorch 为研究人员和开发者们提供了一个强大的工具来构建、训练以及部署各种深度学习模型。你可能会问,PyTorch 是什么,它有什么特点,以及如何使用它呢?

什么是 PyTorch?

PyTorch 是一个基于 Python 的机器学习库,专注于强大的张量计算(tensor computation)和动态计算图(dynamic computation graph)。与其他框架相比,它的一个显著特点就是动态计算图,这意味着你可以在运行时定义和修改计算图,从而更灵活地构建复杂的模型。PyTorch 由 Facebook 的人工智能研究小组开发,已经得到了广泛的认可和采用。

PyTorch 的特点

  1. 动态计算图: PyTorch 的动态计算图使得模型构建和调试变得更加直观。你可以像编写 Python 代码一样编写神经网络结构,而不需要事先定义静态图。

  2. 张量操作: PyTorch 提供了丰富的张量操作功能,它们类似于 NumPy 数组,但是可以在 GPU 上运行以加速计算,适用于大规模的数据处理和深度学习任务。

  3. 自动求导: PyTorch 自动处理了求导过程,无需手动计算梯度。这使得训练模型变得更加方便和高效。

  4. 模块化设计: PyTorch 的模块化设计使得构建复杂的神经网络变得简单。你可以通过组合不同的模块来创建自己的模型。

如何使用 PyTorch?

让我们通过一个简单的示例来看看如何使用 PyTorch 来构建一个基本的神经网络:

import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的神经网络类
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(784, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 创建神经网络实例、损失函数和优化器
net = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001)# 加载数据并进行训练
for epoch in range(5):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss}")
print("Finished Training")

分析环节:

可能会有很多小伙伴不明白,我会进行整个代码的详细分析,逐行解释每个部分的作用和功能。

import torch
import torch.nn as nn
import torch.optim as optim

这部分代码导入了PyTorch库的必要模块,包括torchtorch.nn以及torch.optimtorch是PyTorch的核心模块,提供了张量等基本数据结构和操作;torch.nn提供了神经网络相关的类和函数;torch.optim提供了各种优化器,用于更新神经网络的参数。

# 定义一个简单的神经网络类
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(784, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x

这部分定义了一个简单的神经网络类SimpleNN,该类继承自nn.Module,是PyTorch中自定义神经网络的一种标准做法。网络有两个全连接层(线性层):fc1fc2forward方法定义了前向传播过程,首先通过fc1进行线性变换,然后使用ReLU激活函数,最后通过fc2输出。

# 创建神经网络实例、损失函数和优化器
net = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001)

在这部分,我们实例化了刚刚定义的SimpleNN类,创建了一个神经网络netnn.CrossEntropyLoss()是交叉熵损失函数,适用于多类别分类问题。optim.SGD是随机梯度下降优化器,用于更新网络的权重和偏置。

# 加载数据并进行训练
for epoch in range(5):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss}")
print("Finished Training")

这部分是训练过程的主体。我们使用一个外层循环进行多次训练迭代(5次),每次迭代中,我们遍历训练数据集,计算并更新网络的参数。

  • for epoch in range(5)::外层循环迭代5次,表示5个训练轮次。

  • running_loss = 0.0:用于记录每个训练轮次的累计损失。

  • for i, data in enumerate(trainloader, 0)::遍历训练数据集。enumerate函数用于同时获取数据的索引i和数据本身data

  • inputs, labels = data:将数据拆分为输入和标签。

  • optimizer.zero_grad():清零梯度,准备进行反向传播。

  • outputs = net(inputs):将输入数据输入神经网络,得到输出。

  • loss = criterion(outputs, labels):计算输出和真实标签之间的损失。

  • loss.backward():进行反向传播,计算梯度。

  • optimizer.step():使用优化器更新网络的参数。

  • running_loss += loss.item():累计损失。

  • print(f"Epoch {epoch+1}, Loss: {running_loss}"):打印每个轮次的训练损失。

  • print("Finished Training"):训练完成后打印提示。

整个代码实现了对一个简单的神经网络的训练过程,通过反向传播更新网络参数,使得模型能够逐渐拟合训练数据,从而实现分类任务。

案例分析

我们要说个典型案例:使用 PyTorch 进行图像分类。通过构建神经网络模型、加载数据集、定义损失函数和优化器,可以训练出一个能够识别不同类别的图像的分类器。

我们将创建了一个卷积神经网络(CNN)模型,加载CIFAR-10数据集,通过定义损失函数和优化器,进行模型的训练。这个模型可以用来对CIFAR-10数据集中的图像进行分类,识别不同的物体类别。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms# 步骤 2:加载和预处理数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)# 使用 torchvision 加载 CIFAR-10 数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)# 创建一个 DataLoader,用于批量加载数据
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)# 步骤 3:定义神经网络模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)  # 输入通道数为3,输出通道数为6,卷积核大小为5x5self.pool = nn.MaxPool2d(2, 2)  # 最大池化,窗口大小为2x2self.conv2 = nn.Conv2d(6, 16, 5)  # 输入通道数为6,输出通道数为16,卷积核大小为5x5self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 全连接层,输入维度为16x5x5,输出维度为120self.fc2 = nn.Linear(120, 84)  # 全连接层,输入维度为120,输出维度为84self.fc3 = nn.Linear(84, 10)  # 全连接层,输入维度为84,输出维度为10(类别数)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))  # 使用ReLU激活函数x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)  # 将张量展平,以适应全连接层x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 创建神经网络实例
net = Net()# 步骤 4:定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数,适用于分类问题
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)  # 使用随机梯度下降进行优化# 步骤 5:训练神经网络模型
for epoch in range(2):  # 进行两个 epoch 的训练running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = dataoptimizer.zero_grad()  # 梯度归零,防止累加outputs = net(inputs)  # 前向传播,得到预测结果loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播,计算梯度optimizer.step()  # 更新参数running_loss += loss.item()  # 累加损失if i % 2000 == 1999:print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")  # 打印损失running_loss = 0.0
print("Finished Training")  # 训练完成

案例通过加载 CIFAR-10 数据集,构建一个简单的卷积神经网络,定义损失函数和优化器,并进行模型训练。训练过程中,我们采用了随机梯度下降(SGD)优化算法,使用交叉熵损失函数来优化分类任务。每个 epoch 的训练过程会在控制台输出损失值,以便我们监控训练的进展情况。

总结而言,PyTorch 是一个功能强大且易用的深度学习框架,适用于各种机器学习和深度学习任务。它的动态计算图、张量操作和自动求导等特性使得模型的构建和训练变得更加高效和灵活。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/103863.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

为什么需要单元测试?

为什么需要单元测试? 从产品角度而言,常规的功能测试、系统测试都是站在产品局部或全局功能进行测试,能够很好地与用户的需要相结合,但是缺乏了对产品研发细节(特别是代码细节的理解)。 从测试人员角度而言…

Mysql简短又易懂

MySql 连接池:的两个参数 最大连接数:可以同时发起的最大连接数 单次最大数据报文:接受数据报文的最大长度 数据库如何存储数据 存储引擎: InnoDB:通过执行器对内存和磁盘的数据进行写入和读出 优化SQL语句innoDB会把需要写入或者更新的数…

Java如何调用接口API并返回数据(两种方法)

Java如何调用接口API并返回数据(两种方法) java处理请求接口后返回的json数据-直接处理json字符串 处理思路: 将返回的数据接收到一个String对象中(有时候需要自己选择性的取舍接收) 再将string转换为JSONObject对象 …

Go 语言在 Windows 上的安装及配置

1. Go语言的下载 Golang官网:All releases - The Go Programming Language Golang中文网:Go下载 - Go语言中文网 - Golang中文社区 两个网站打开的内容只有语言不同而已,网站上清晰的标注了不同操作系统需要对应安装哪个版本,其中…

港联证券|燃气板块午后走高,美能能源涨停,水发燃气大幅拉升

燃气板块21日午后快速拉升,到发稿,美能动力涨停,水发燃气涨超7%,蓝天燃气涨超5%,贵州燃气涨逾4%。 消息面上,受澳大利亚LNG工厂罢工忧虑影响,欧洲基准天然气价格一度大涨18%。 有报导称&#x…

npm报错:xxx packages are looking for funding run `npm fund` for details(解决办法)

报错信息:30 packages are looking for funding run npm fund for details 报错原因:这里是开发者捐赠支持的提示,打开一个github的链接之后,会显示是否需要打赏捐赠的信息。 解决方案:这个打赏是资源的,因…

YOLOV8 win10部署笔记

文章目录 1. 背景2. 部署过程2.1 快速安装 1. 背景 看了B站许多up主的视频,感觉YOLOV8各方面都很优秀,作为新手对它的期待很大,于是想实际跑跑看,边实践,边学习,记录过程。 本篇主要是博主在windows平台上…

控制Unity发布的PC包的窗体

大家好,我是阿赵。   用Unity发布PC包接入某些渠道时,有时候会收到一些特殊的需求,比如控制窗口最大化(比如某些情况强制显示窗体)、最小化(比如老板键)、强制规定窗体置顶等。虽然我一直认为这些需求都是流氓软件行为,但作为一…

[管理与领导-43]:IT基层管理者 - 个人管理 - 管理中从角色定位迈步

前言: 管理者的“四位” : ‣ 定位——在什么位置做什么事情; ‣ 到位——全力以赴把事情做好; ‣ 不越位——不要把别人的工作做了; ‣ 补位——同事临时“缺位” ,及时补位,提升效率&…

Linux:shell脚本:基础使用(6)《正则表达式-awk工具》

简介 awk是行处理器: 相比较屏幕处理的优点,在处理庞大文件时不会出现内存溢出或是处理缓慢的问题,通常用来格式化文本信息 awk处理过程: 依次对每一行进行处理,然后输出 1)awk命令会逐行读取文件的内容进行处理 2)a…

clickhouse-压测

一、数据集准备 数据集可以使用官网数据集,也可以用ssb-dbgen来准备 1.准备数据 这里最后生成表的数据行数为60亿行,数据量为300G左右 git clone https://github.com/vadimtk/ssb-dbgen.git cd ssb-dbgen/ make1.1 生成数据 # -s 指生成多少G的数据…

在线转换器有哪些优势?在线Word转PDF操作分享

我们如果想要将两者不同格式文件进行格式转换,就需要下载安装转换器。如果出门带的设备没有安装转换软件客户端,就无法使用,会比较麻烦。现在有了在线转换工具,只需要打开相应的网页就可使用,那么在线Word转PDF的操作是…

matlab 点云精配准(1)——point to point ICP(点到点的ICP)

目录 一、算法原理参考文献二、代码实现三、结果展示四、参考链接本文由CSDN点云侠原创,爬虫自重。如果你不是在点云侠的博客中看到该文章,那么此处便是不要脸的爬虫。 一、算法原理 参考文献 [1] BESL P J,MCKAY N D.A method for registration of 3-Dshapes[J].IEEE Tran…

vue3——递归组件的使用

该文章是在学习 小满vue3 课程的随堂记录示例均采用 <script setup>&#xff0c;且包含 typescript 的基础用法 一、使用场景 递归组件 的使用场景&#xff0c;如 无限级的菜单 &#xff0c;接下来就用菜单的例子来学习 二、具体使用 先把菜单的基础内容写出来再说 父…

interview1-DB篇

需要项目经验可自行上Gitee寻找项目资源 一、Redis篇 1、缓存 缓存的要点可分为穿透、击穿、雪崩&#xff0c;双写一致、持久化&#xff0c;数据过期、淘汰策略。 &#xff08;1&#xff09;穿透、击穿、雪崩 1.缓存穿透 查询一个不存在的数据&#xff0c;mysql查询不到数据…

网络面试题(172.22.141.231/26,该IP位于哪个网段? 该网段拥有多少可用IP地址?广播地址是多少?)

此题面试中常被问到&#xff0c;一定要会172.22.141.231/26&#xff0c;该IP位于哪个网段&#xff1f; 该网段拥有多少可用IP地址&#xff1f;广播地址是多少&#xff1f; 解题思路&#xff1a; 网络地址&#xff1a;172.22.141.192 10101100.00010110.10001101.11000000 广播…

javascript常用的东西

JavaScript 是一门强大的编程语言&#xff0c;用于为网页添加交互性和动态性。也可以锻炼人们的逻辑思维&#xff0c;是一个非常好的东西。 一、变量和数据类型&#xff1a; 变量&#xff1a; 变量是用于存储数据值的容器。在 JavaScript 中&#xff0c;你可以使用 var、let…

git分支

一、引言 分支的命名规范以及管理方式对项目的版本发布至关重要&#xff0c;为了解决实际开发过程中版本发布时代码管理混乱、冲突等比较头疼的问题&#xff0c;我们将在文中阐述如何更好的管理代码分支。 二、总览&#xfeff; 从上图可以看到主要包含下面几个分支&#xff…

真伪定时器

首先观察一下下面两组代码区别在哪里&#xff1f; 第一组代码 setInterval(() > {// 1.5s 的同步逻辑 }, 1000);第二组代码 function fn() {setTimeout(() > {// 1.5s 的同步逻辑fn();}, 1000); }fn();两组代码都有定时功能&#xff0c;看起来也都是每隔1s执行一次任务…

ubuntu20搭建环境使用的一下指令

1.更新源 sudo vim etc/apt/sources.listdeb http://mirrors.aliyun.com/ubuntu/ xenial main deb-src http://mirrors.aliyun.com/ubuntu/ xenial maindeb http://mirrors.aliyun.com/ubuntu/ xenial-updates main deb-src http://mirrors.aliyun.com/ubuntu/ xenial-updates…