pytorch张量基础

好的,为了编写一篇全面且详细的指南,涵盖 PyTorch 中张量的所有知识,并为学习机器学习和深度学习打好基础,我将会提供一个结构化的内容,包括基础知识、进阶知识、实际应用和一些优化技巧。这个文档大纲如下:

  1. 引言
  2. 张量的基础知识
    1. 张量的概念
    2. 张量的属性
    3. 张量的创建
  3. 张量的操作
    1. 基本运算
    2. 索引和切片
    3. 形状变换
  4. 自动微分
    1. 基本概念
    2. 停止梯度传播
  5. 张量的设备管理
    1. 检查和移动张量
    2. CUDA 张量
  6. 高级操作
    1. 张量的视图
    2. 广播机制
    3. 分块和拼接
    4. 张量的复制
  7. 内存优化和管理
    1. 稀疏张量
    2. 内存释放
  8. 应用实例
    1. 线性回归
    2. 神经网络基础
  9. 总结

1. 引言

在机器学习和深度学习中,张量(Tensor)是核心的数据结构。了解和掌握张量的操作是学习 PyTorch 和构建神经网络模型的必要基础。张量可以表示从标量到高维数组的数据结构,它在 PyTorch 的计算图中扮演着基础角色。本指南旨在全面介绍 PyTorch 中张量的相关知识,帮助读者从基础打好深度学习的基础。

2. 张量的基础知识

1. 张量的概念

张量是一个数组的通用化,可以表示标量(0维)、向量(1维)、矩阵(2维)及更高维的数组。通俗来说,张量是一种多维数据结构,其本质上是一个多维数组。

2. 张量的属性

张量有多个重要属性,用来描述其数据和结构:

  • 形状(shape):描述张量的维度结构,例如 (2, 3) 表示一个包含 2 行 3 列的矩阵。
  • 数据类型(dtype):指定张量中元素的类型,例如 torch.float32torch.int64 等。
  • 设备(device):指示张量存储的设备,可以是 CPU 或 GPU。
  • 步幅(stride):步幅表示连续两个元素在各个维度上的步进距离。
import torchtensor = torch.tensor([[1., 2., 3.], [4., 5., 6.]])print(tensor.shape)    # torch.Size([2, 3])
print(tensor.dtype)    # torch.float32
print(tensor.device)   # cpu
print(tensor.stride()) # (3, 1)

3. 张量的创建

可以通过多种方式创建张量,包括从已有数据创建、使用随机数生成和从其他张量创建。

# 从数据创建
scalar = torch.tensor(5.0)          # 标量
vector = torch.tensor([1.0, 2.0, 3.0])  # 向量
matrix = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # 矩阵# 使用随机数创建
rand_tensor = torch.rand(2, 3)     # 均匀分布
randn_tensor = torch.randn(2, 3)   # 标准正态分布# 从其他张量创建
zeros_tensor = torch.zeros_like(matrix)  # 创建与 matrix 形状相同的全零张量

3. 张量的操作

1. 基本运算

张量支持基本的算术运算,包括加、减、乘、除。

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])# 加法
c = a + b# 减法
d = a - b# 乘法
e = a * b# 除法
f = a / b# 点积
dot_prod = torch.dot(a, b)  # 32.0# 矩阵乘法
matrix1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
matrix2 = torch.tensor([[5.0, 6.0], [7.0, 8.0]])
matrix_mul = torch.mm(matrix1, matrix2)  # [[19.0, 22.0], [43.0, 50.0]]

2. 索引和切片

张量支持多种索引和切片操作,类似于 NumPy。

tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])# 索引
element = tensor[1, 2]  # 6.0# 切片
subset = tensor[:, 1]  # tensor([2.0, 5.0])

3. 形状变换

在不复制数据的情况下,PyTorch 支持多种形状变换操作。

# 重塑
reshaped = tensor.view(3, 2)  # tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])# 转置
transposed = tensor.t()       # tensor([[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]])# 增加或减少维度
unsqueezed = tensor.unsqueeze(0)  # 增加第0维
squeezed = tensor.squeeze()       # 去除所有维度为1的维度

4. 自动微分

PyTorch 提供强大的自动微分功能,称为Autograd。它可以自动计算张量的梯度,适用于优化和训练神经网络。

1. 基本概念

张量可以设置 requires_grad=True 以启用自动微分。计算张量的梯度使用 backward() 方法。

x = torch.tensor([2.0, 3.0], requires_grad=True)
y = x[0] ** 2 + x[1] ** 3
y.backward()
print(x.grad)  # tensor([ 4.0, 27.0])

2. 停止梯度传播

在某些情况下,比如模型评估或推理时,需要停止梯度传播以提高性能并节省内存。

with torch.no_grad():y = x[0] ** 2 + x[1] ** 3# 使用 detach() 方法创建一个新的张量,该张量与原始张量共享数据,但不进行梯度追踪
detached_tensor = x.detach()

5. 张量的设备管理

1. 检查和移动张量

张量可以在 CPU 或 GPU 上进行计算。PyTorch 提供了简单的方法来检查和移动张量到不同的设备。

tensor = torch.tensor([1.0, 2.0, 3.0])# 检查是否有可用的 GPU
if torch.cuda.is_available():tensor = tensor.to('cuda')print(tensor.device)  # cuda:0# 将张量移动回 CPU
tensor = tensor.to('cpu')
print(tensor.device)  # cpu

2. CUDA 张量

使用 CUDA 张量可以显著提高计算速度,特别是在深度学习中。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tensor = torch.tensor([1.0, 2.0, 3.0], device=device)

6. 高级操作

1. 张量的视图

视图允许我们在不复制数据的情况下,改变张量的形状。

original_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
view_tensor = original_tensor.view(6)  # tensor([1, 2, 3, 4, 5, 6])# 修改视图
view_tensor[0] = 10
print(original_tensor)  # tensor([[10,  2,  3], [ 4,  5,  6]])

2. 广播机制

广播机制使得不同形状的张量能够进行相同大小的运算。

a = torch.tensor([1, 2, 3])
b = torch.tensor([[1], [2], [3]])
result = a + b
# result: tensor([[2, 3, 4],
#                 [3, 4, 5],
#                 [4, 5, 6]])

3. 分块和拼接

可以使用 split() 和 cat() 等函数进行分块和拼接。

tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])# 分割张量
split_tensors = torch.split(tensor, split_size_or_sections=2, dim=1)# 拼接张量
tensor_a = torch.tensor([[1, 2], [3, 4]])
tensor_b = torch.tensor([[5, 6], [7, 8]])
concat_tensor = torch.cat((tensor_a, tensor_b), dim=1)

4. 张量的复制

用于创建独立副本,clone() 和 detach() 是常用方法。

tensor = torch.tensor([1, 2, 3], requires_grad=True)
cloned_tensor = tensor.clone()
detached_tensor = tensor.detach()

7. 内存优化和管理

1. 稀疏张量

对于稀疏矩阵和张量,PyTorch 提供了稀疏张量表示,以便节省内存和计算资源。

indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
values = torch.tensor([3, 4, 5], dtype=torch.float32)
sparse_tensor = torch.sparse_coo_tensor(indices, values, [2, 3])
print(sparse_tensor)

2. 内存释放

为了在训练和评估期间节省内存,可以释放不再需要的张量。

# 使用 del 语句手动删除对象
del tensor# 清空 GPU 切实可行的张量以释放内存
torch.cuda.empty_cache()

8. 应用实例

通过实际应用实例,可以更好地理解和掌握 PyTorch 张量的使用方式。

1. 线性回归

利用 PyTorch 张量实现简单的线性回归模型。

# 数据集
x_train = torch.tensor([[1.0], [2.0], [3.0]])
y_train = torch.tensor([[2.0], [4.0], [6.0]])# 初始化参数
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)def model(x):return w * x + b# 损失函数
def loss_fn(y_pred, y):return ((y_pred - y) ** 2).mean()# 训练模型
learning_rate = 0.01
for epoch in range(1000):y_pred = model(x_train)loss = loss_fn(y_pred, y_train)loss.backward()with torch.no_grad():w -= learning_rate * w.gradb -= learning_rate * b.gradw.grad.zero_()b.grad.zero_()print(f'w: {w}, b: {b}')

2. 神经网络基础

张量在神经网络中的应用,是构建复杂模型的基础。

import torch.nn as nn# 简单的神经网络
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(1, 10)self.relu = nn.ReLU()self.fc2 = nn.Linear(10, 1)def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return outmodel = SimpleNN()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 训练模型
for epoch in range(1000):y_pred = model(x_train)loss = criterion(y_pred, y_train)optimizer.zero_grad()loss.backward()optimizer.step()print(list(model.parameters()))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/438616.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TARA分析方法论——威胁分析和风险评估方法

一、什么是TARA分析方法论 威胁分析和风险评估(Threat Analysis and Risk Assessment) 通过识别整车/项目的网络安全资产,分析其中的潜在的安全威胁,综合考虑威胁攻击可行性、危害影响等因素,识别出整车/项目可能存在…

Python并发编程(2)——初始Python多线程

左手编程,右手年华。大家好,我是一点,关注我,带你走入编程的世界。 公众号:一点sir,关注领取python编程资料 前言 什么是多线程? 为什么需要多线程? 多线程的优点和缺点&#xff1f…

前端规范工程-5:Git提交信息规范(commitlint + czg)

前面讲的都是在git提交之前的一些检查流程,然而我们git提交信息的时候,也应该是需要规范的。直接进入主题: 目录 需安装插件清单commitlint 介绍安装配置配置commit-msg钩子提交填写commit信息czg后续方式一:push触动build并上传…

Windows UAC权限详解以及因为权限不对等引发软件工具无法正常使用的实例分析

目录 ​1、什么是UAC? 2、微软为什么要设计UAC? 3、标准用户权限与管理员权限 4、程序到底以哪种权限运行?与哪些因素有关? 4.1、给程序设置以管理员权限运行的属性 4.2、当前登录用户的类型 4.3、如何通过代码判断某个进程…

2.1MyBatis——ORM对象关系映射

2.1MyBatis——ORM对象关系映射 1. 验证映射配置2.ResultType和ResultMap2.1ResultMap是最终的ORM依据2.2ResultType和ResultMap的使用区别 3.具体的转换逻辑3.1 TypeHandle类型转换 5.总结 概括的说,MyBatis中,对于映射关系的声明是由开发者在xml文件手…

手机USB连接不显示内部设备,设备管理器显示“MTP”感叹号,解决方案

进入小米驱动下载界面,等小米驱动下载完成后,解压此驱动文件压缩包。 5、小米USB驱动安装方法:右击“计算机”,从弹出的右键菜单中选择“管理”项进入。 6、在打开的“计算机管理”界面中,展开“设备管理器”项&…

【数据分享】2000—2023年我国省市县三级逐年植被覆盖度(FVC)数据(Shp/Excel格式)

之前我们分享过2000—2023年逐月植被覆盖度(FVC)栅格数据(可查看之前的文章获悉详情)和Excel和Shp格式的省市县三级逐月FVC数据(可查看之前的文章获悉详情),原始的逐月栅格数据来源于高吉喜学者…

深度学习:迁移学习

目录 一、迁移学习 1.什么是迁移学习 2.迁移学习的步骤 1、选择预训练的模型和适当的层 2、冻结预训练模型的参数 3、在新数据集上训练新增加的层 4、微调预训练模型的层 5、评估和测试 二、迁移学习实例 1.导入模型 2.冻结模型参数 3.修改参数 4.创建类&#xff…

GAN|对抗| 生成器更新|判别器更新过程

如上图所示,生成对抗网络存在上述内容: 真实数据集;生成器;生成器损失函数;判别器;判别器损失函数;生成器、判别器更新(生成器和判别器就是小偷和警察的关系,他们共用的…

kubernetes基础操作(pod生命周期)

pod生命周期 一、Pod生命周期 我们一般将pod对象从创建至终的这段时间范围称为pod的生命周期,它主要包含下面的过程: ◎pod创建过程 ◎运行初始化容器(init container)过程 ◎运行主容器(main container&#xff…

记录一次病毒启动脚本

在第一次下载软件时,目录中配了一个使用说明,说是需要通过start.bat 这个文件来启动程序,而这个 start.bat 就是始作俑者: 病毒作者比较狡猾,其中start.bat 用记事本打开是乱码,但是可以通过将这个批处理…

spring揭秘24-springmvc02-5个重要组件

文章目录 【README】【1】HanderMapping-处理器映射容器【1.1】HanderMapping实现类【1.1.1】SimpleUrlHandlerMapping 【2】Controller(二级控制器)【2.1】AbstractController抽象控制器(控制器基类) 【3】ModelAndView(模型与视…

java入门基础(一篇搞懂)

​ 如果您觉得这篇文章对您有帮助的话 欢迎您分享给更多人哦 感谢大家的点赞收藏评论,感谢您的支持!!! 首先给大家推荐比特博哥,java入门安装的JDk和IDEA社区版的安装视频 JDK安装与环境变量的配置 IDEA社区的安装与使…

帝国CMS系统开启https后,无法登陆后台的原因和解决方法

今天本地配置好了帝国CMS7.5,传去服务器后,使用http访问一切正常。但是当开启了https(SSL)后,后台竟然无法登陆进去了。 输入账号密码后,点击登陆,跳转到/e/admin/ecmsadmin.php就变成页面一片…

SpringBoot基础(三):Logback日志

SpringBoot基础系列文章 SpringBoot基础(一):快速入门 SpringBoot基础(二):配置文件详解 SpringBoot基础(三):Logback日志 目录 一、日志依赖二、日志格式1、记录日志2、默认输出格式3、springboot默认日志配置 三、日志级别1、基础设置2、…

golang-基础知识(流程控制)

1 条件判断if和switch 所有的编程语言都有这个if,表示如果满足条件就做某事,不满足就做另一件事,go中的if判断和其它语言的区别主要有以下两点 1. go里面if条件判断不需要括号 2. go的条件判断语句中允许声明一个变量,这个变量…

FPGA-UART串口接收模块的理解

UART串口接收模块 背景 在之前就有写过关于串口模块的文章——《串口RS232的学习》。工作后很多项目都会用到串口模块,又来重新理解一下FPGA串口接收的代码思路。 关于串口相关的参数,以及在文章《串口RS232的学习》中已有详细的描述,这里就…

单调队列与单调栈<2>——单调栈

单调栈的定义 单调递增栈 栈中元素从栈底到栈顶是递增的。 单调递减栈 栈中元素从栈底到栈顶是递减的。 单调栈的核心内容 我们从左到右遍历元素,构造单调栈(从栈顶到栈底递增或减):在 i 从左往右遍历的过程中,我…

手写堆排序

手写堆排序 摘要:本文记录使用go语言实现堆排序 堆的构建 堆性质: 对于每个小堆,父节点与两个子节点比较,父节点比左子节点大,也比右子节点大。 有五个数: 1,2,3,4,5 分别进行入栈。过程如下 (1) 堆为…

(作业)第三期书生·浦语大模型实战营(十一卷王场)--书生入门岛通关第3关Git 基础知识

任务编号 任务名称 任务描述 1 破冰活动 提交一份自我介绍。 2 实践项目 创建并提交一个项目。 破冰活动 提交一份自我介绍。 每位参与者提交一份自我介绍。 提交地址:https://github.com/InternLM/Tutorial 的 camp3 分支~ 安装并设置git 克隆仓库并…