PyTorch API 详细中文文档,基于PyTorch2.5


PyTorch API 详细中文文档

按模块分类,涵盖核心函数与用法示例


目录

  1. 张量操作 (Tensor Operations)
  2. 数学运算 (Math Operations)
  3. 自动求导 (Autograd)
  4. 神经网络模块 (torch.nn)
  5. 优化器 (torch.optim)
  6. 数据加载与处理 (torch.utils.data)
  7. 设备管理 (Device Management)
  8. 模型保存与加载
  9. 分布式训练 (Distributed Training)
  10. 实用工具函数

1. 张量操作 (Tensor Operations)

1.1 张量创建
函数描述示例
torch.tensor(data, dtype, device)从数据创建张量torch.tensor([1,2,3], dtype=torch.float32)
torch.zeros(shape)创建全零张量torch.zeros(2,3)
torch.ones(shape)创建全一张量torch.ones(5)
torch.rand(shape)均匀分布随机张量torch.rand(3,3)
torch.randn(shape)标准正态分布张量torch.randn(4,4)
torch.arange(start, end, step)创建等差序列torch.arange(0, 10, 2)[0,2,4,6,8]
torch.linspace(start, end, steps)线性间隔序列torch.linspace(0, 1, 5)[0, 0.25, 0.5, 0.75, 1]
1.2 张量属性
属性/方法描述示例
.shape张量维度x = torch.rand(2,3); x.shape → torch.Size([2,3])
.dtype数据类型x.dtype → torch.float32
.device所在设备x.device → device(type='cpu')
.requires_grad是否追踪梯度x.requires_grad = True
1.3 张量变形
函数描述示例
.view(shape)调整形状(不复制数据)x = torch.arange(6); x.view(2,3)
.reshape(shape)类似 view,但自动处理内存连续性x.reshape(3,2)
.permute(dims)调整维度顺序x = torch.rand(2,3,4); x.permute(1,2,0)
.squeeze(dim)去除大小为1的维度x = torch.rand(1,3); x.squeeze(0)shape [3]
.unsqueeze(dim)添加大小为1的维度x = torch.rand(3); x.unsqueeze(0)shape [1,3]

2. 数学运算 (Math Operations)

2.1 逐元素运算
函数描述示例
torch.add(x, y)加法torch.add(x, y)x + y
torch.mul(x, y)乘法torch.mul(x, y)x * y
torch.exp(x)指数运算torch.exp(torch.tensor([1.0]))[2.7183]
torch.log(x)自然对数torch.log(torch.exp(tensor([2.0])))[2.0]
torch.clamp(x, min, max)限制值范围torch.clamp(x, min=0, max=1)
2.2 矩阵运算
函数描述示例
torch.matmul(x, y)矩阵乘法x = torch.rand(2,3); y = torch.rand(3,4); torch.matmul(x, y)
torch.inverse(x)矩阵求逆x = torch.rand(3,3); inv_x = torch.inverse(x)
torch.eig(x)特征值分解eigenvalues, eigenvectors = torch.eig(x)
2.3 统计运算
函数描述示例
torch.sum(x, dim)沿维度求和x = torch.rand(2,3); torch.sum(x, dim=1)
torch.mean(x, dim)沿维度求均值torch.mean(x, dim=0)
torch.max(x, dim)沿维度求最大值values, indices = torch.max(x, dim=1)
torch.argmax(x, dim)最大值索引indices = torch.argmax(x, dim=1)

3. 自动求导 (Autograd)

3.1 梯度计算
函数/属性描述示例
x.backward()反向传播计算梯度x = torch.tensor(2.0, requires_grad=True); y = x**2; y.backward()
x.grad查看梯度值x.grad4.0(若 y = x²
torch.no_grad()禁用梯度追踪with torch.no_grad(): y = x * 2
detach()分离张量(不追踪梯度)y = x.detach()
3.2 梯度控制
函数描述
x.retain_grad()保留非叶子节点的梯度
torch.autograd.grad(outputs, inputs)手动计算梯度

示例

x = torch.tensor(3.0, requires_grad=True)  
y = x**3 + 2*x  
dy_dx = torch.autograd.grad(y, x)  # 返回 (torch.tensor(29.0),)  

4. 神经网络模块 (torch.nn)

4.1 层定义
描述示例
nn.Linear(in_features, out_features)全连接层layer = nn.Linear(784, 256)
nn.Conv2d(in_channels, out_channels, kernel_size)卷积层conv = nn.Conv2d(3, 16, kernel_size=3)
nn.LSTM(input_size, hidden_size)LSTM 层lstm = nn.LSTM(100, 50)
nn.Dropout(p=0.5)Dropout 层dropout = nn.Dropout(0.2)
4.2 激活函数
函数描述示例
nn.ReLU()ReLU 激活F.relu(x)nn.ReLU()(x)
nn.Sigmoid()Sigmoid 函数torch.sigmoid(x)
nn.Softmax(dim)Softmax 归一化F.softmax(x, dim=1)
4.3 损失函数
描述示例
nn.MSELoss()均方误差loss_fn = nn.MSELoss()
nn.CrossEntropyLoss()交叉熵损失loss = loss_fn(outputs, labels)
nn.BCELoss()二分类交叉熵loss_fn = nn.BCELoss()

5. 优化器 (torch.optim)

5.1 优化器定义
描述示例
optim.SGD(params, lr)随机梯度下降optimizer = optim.SGD(model.parameters(), lr=0.01)
optim.Adam(params, lr)Adam 优化器optimizer = optim.Adam(model.parameters(), lr=0.001)
optim.RMSprop(params, lr)RMSprop 优化器optimizer = optim.RMSprop(params, lr=0.01)
5.2 优化器方法
方法描述示例
optimizer.zero_grad()清空梯度optimizer.zero_grad()
optimizer.step()更新参数loss.backward(); optimizer.step()
optimizer.state_dict()获取优化器状态state = optimizer.state_dict()

6. 数据加载与处理 (torch.utils.data)

6.1 数据集类
类/函数描述示例
Dataset自定义数据集基类继承并实现 __len____getitem__
DataLoader(dataset, batch_size, shuffle)数据加载器loader = DataLoader(dataset, batch_size=64, shuffle=True)

自定义数据集示例

class MyDataset(Dataset):  def __init__(self, data, labels):  self.data = data  self.labels = labels  def __len__(self):  return len(self.data)  def __getitem__(self, idx):  return self.data[idx], self.labels[idx]  
6.2 数据预处理 (TorchVision)
from torchvision import transforms  transform = transforms.Compose([  transforms.Resize(256),          # 调整图像大小  transforms.ToTensor(),           # 转为张量  transforms.Normalize(mean=[0.5], std=[0.5])  # 标准化  
])  

7. 设备管理 (Device Management)

7.1 设备切换
函数/方法描述示例
.to(device)移动张量/模型到设备x = x.to('cuda:0')
torch.cuda.is_available()检查 GPU 是否可用if torch.cuda.is_available(): ...
torch.cuda.empty_cache()清空 GPU 缓存torch.cuda.empty_cache()

8. 模型保存与加载

函数描述示例
torch.save(obj, path)保存对象(模型/参数)torch.save(model.state_dict(), 'model.pth')
torch.load(path)加载对象model.load_state_dict(torch.load('model.pth'))
model.state_dict()获取模型参数字典params = model.state_dict()

9. 分布式训练 (Distributed Training)

函数/类描述示例
nn.DataParallel(model)单机多卡并行model = nn.DataParallel(model)
torch.distributed.init_process_group()初始化分布式训练需配合多进程使用

10. 实用工具函数

函数描述示例
torch.cat(tensors, dim)沿维度拼接张量torch.cat([x, y], dim=0)
torch.stack(tensors, dim)堆叠张量(新建维度)torch.stack([x, y], dim=1)
torch.split(tensor, split_size, dim)分割张量chunks = torch.split(x, 2, dim=0)

常见问题与技巧

  1. GPU 内存不足

    • 使用 batch_size 较小的值
    • 启用混合精度训练 (torch.cuda.amp)
    • 使用 torch.utils.checkpoint 节省内存
  2. 梯度爆炸/消失

    • 使用梯度裁剪:torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    • 调整权重初始化方法
  3. 模型推理模式

    model.eval()  # 关闭 Dropout 和 BatchNorm 的随机性  
    with torch.no_grad():  outputs = model(inputs)  
    

文档说明

  • 本文档基于 PyTorch 2.5 编写,部分 API 可能不兼容旧版本。
  • 更详细的参数说明请参考 PyTorch 官方文档。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/9565.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(done) MIT6.S081 2023 学习笔记 (Day6: LAB5 COW Fork)

网页:https://pdos.csail.mit.edu/6.S081/2023/labs/cow.html 任务1:Implement copy-on-write fork(hard) (完成) 现实中的问题如下: xv6中的fork()系统调用会将父进程的用户空间内存全部复制到子进程中。如果父进程很大,复制过程…

如何将xps文件转换为txt文件?xps转为pdf,pdf转为txt,提取pdf表格并转为txt

文章目录 xps转txt方法一方法二 pdf转txt整页转txt提取pdf表格,并转为txt 总结另外参考XPS文件转换为TXT文件XPS文件转换为PDF文件PDF文件转换为TXT文件提取PDF表格并转为TXT示例代码(部分) 本文测试代码已上传,路径如下&#xff…

C++,STL,【目录篇】

文章目录 一、简介二、内容提纲第一部分:STL 概述第二部分:STL 容器第三部分:STL 迭代器第四部分:STL 算法第五部分:STL 函数对象第六部分:STL 高级主题第七部分:STL 实战应用 三、写作风格四、…

[STM32 - 野火] - - - 固件库学习笔记 - - -十三.高级定时器

一、高级定时器简介 高级定时器的简介在前面一章已经介绍过,可以点击下面链接了解,在这里进行一些补充。 [STM32 - 野火] - - - 固件库学习笔记 - - -十二.基本定时器 1.1 功能简介 1、高级定时器可以向上/向下/两边计数,还独有一个重复计…

Mybatis是如何进行分页的?

大家好,我是锋哥。今天分享关于【Mybatis是如何进行分页的?】面试题。希望对大家有帮助; Mybatis是如何进行分页的? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 MyBatis 实现分页的方式有很多种,最常见…

四.3 Redis 五大数据类型/结构的详细说明/详细使用( hash 哈希表数据类型详解和使用)

四.3 Redis 五大数据类型/结构的详细说明/详细使用( hash 哈希表数据类型详解和使用) 文章目录 四.3 Redis 五大数据类型/结构的详细说明/详细使用( hash 哈希表数据类型详解和使用)2.hash 哈希表常用指令(详细讲解说明)2.1 hset …

编译dpdk19.08.2中example时一系列报错解决

dpdk19.08编译过程全解 dpdk 介绍问题描述编译过程执行Step 1报错一解决方式 报错二解决方式 继续执行Step 248的时候报错 49没有修改成功输入60退出 使用过程执行make报错一解决方式 继续make报错二解决方式 继续make执行生成文件helloworld报错三解决方式 执行make 完成参考链…

openeuler 22.03 lts sp4 使用 cri-o 和 静态 pod 的方式部署 k8s-v1.32.0 高可用集群

前情提要 整篇文章会非常的长…可以选择性阅读,另外,这篇文章是自己学习使用的,用于生产,还请三思和斟酌 静态 pod 的部署方式和二进制部署的方式是差不多的,区别在于 master 组件的管理方式是 kubectl 还是 systemctl有 kubeadm 工具,为什么还要用静态 pod 的方式部署?…

渗透测试之WAF规则触发绕过规则之规则库绕过方式

目录 Waf触发规则的绕过 特殊字符替换空格 实例 特殊字符拼接绕过waf Mysql 内置得方法 注释包含关键字 实例 Waf触发规则的绕过 特殊字符替换空格 用一些特殊字符代替空格,比如在mysql中%0a是换行,可以代替空格 这个方法也可以部分绕过最新版本的…

C# dataGridView1获取选中行的名字

在视觉项目中编写的框架需要能够选择产品或复制产品等方便后续换型,视觉调试仅需调试相机图像、调试视觉相关参数、标定,再试跑调试优化参数。 C# dataGridView1 鼠标点击某一行能够计算出是那一行 使用CellMouseClick事件 首先,在Form的构造…

Ubuntu介绍、与centos的区别、基于VMware安装Ubuntu Server 22.04、配置远程连接、安装jdk+Tomcat

目录 ?编辑 一、Ubuntu22.04介绍 二、Ubuntu与Centos的区别 三、基于VMware安装Ubuntu Server 22.04 下载 VMware安装 1.创建新的虚拟机 2.选择类型配置 3.虚拟机硬件兼容性 4.安装客户机操作系统 5.选择客户机操作系统 6.命名虚拟机 7.处理器配置 8.虚拟机内存…

Linux基础指令

基本文件操作 补充: “cd -” 可以前往刚才所在目录 “ls 文件路径” 列举指定路径的文件 “ls -a”列出隐藏文件 “ls -l”可以缩写为“ll” 周边概念 读取操作 “cat 文件名”阅读文本文件内容,可以使用Tab键补全文件…

【HarmonyOS之旅】基于ArkTS开发(三) -> 兼容JS的类Web开发(三)

目录 1 -> 生命周期 1.1 -> 应用生命周期 1.2 -> 页面生命周期 2 -> 资源限定与访问 2.1 -> 资源限定词 2.2 -> 资源限定词的命名要求 2.3 -> 限定词与设备状态的匹配规则 2.4 -> 引用JS模块内resources资源 3 -> 多语言支持 3.1 -> 定…

【JavaWeb06】Tomcat基础入门:架构理解与基本配置指南

文章目录 🌍一. WEB 开发❄️1. 介绍 ❄️2. BS 与 CS 开发介绍 ❄️3. JavaWeb 服务软件 🌍二. Tomcat❄️1. Tomcat 下载和安装 ❄️2. Tomcat 启动 ❄️3. Tomcat 启动故障排除 ❄️4. Tomcat 服务中部署 WEB 应用 ❄️5. 浏览器访问 Web 服务过程详…

C语言练习(29)

13个人围成一圈&#xff0c;从第1个人开始顺序报号1、2、3。凡报到“3”者退出圈子&#xff0c;找出最后留在圈子中的人原来的序号。本题要求用链表实现。 #include <stdio.h> #include <stdlib.h>// 定义链表节点结构体 typedef struct Node {int num;struct Nod…

简要介绍C语言和c++的共有变量,以及c++特有的变量

在C语言和C中&#xff0c;变量是用来存储数据的内存位置&#xff0c;它们的使用方式和特性在两种语言中既有相似之处&#xff0c;也有不同之处。以下分别介绍C语言和C的共有变量以及C特有的变量。 C语言和C的共有变量 C语言和C都支持以下类型的变量&#xff0c;它们在语法和基…

【UE插件】Sphinx关键词语音识别

视频教程&#xff1a; Unreal Engine - Speech Recognition - Free Pluginhttps://www.youtube.com/watch?vKBcXNnSdWog&t622s 官方教程&#xff1a; Sphinx: Speech Recognition Plugin | Unreal Engine Community Wikihttps://unrealcommunity.wiki/speech-recognition…

图漾相机——C++语言属性设置

文章目录 前言1.SDK API功能介绍1.1 Device组件下的API测试1.1.1 相机工作模式设置&#xff08;TY_TRIGGER_PARAM_EX&#xff09;1.1.2 TY_INT_FRAME_PER_TRIGGER1.1.3 TY_INT_PACKET_DELAY1.1.4 TY_INT_PACKET_SIZE1.1.5 TY_BOOL_GVSP_RESEND1.1.6 TY_BOOL_TRIGGER_OUT_IO1.1.…

Spring AI 在微服务中的应用:支持分布式 AI 推理

1. 引言 在现代企业中&#xff0c;微服务架构 已成为开发复杂系统的主流方式&#xff0c;而 AI 模型推理 也越来越多地被集成到业务流程中。如何在分布式微服务架构下高效地集成 Spring AI&#xff0c;使多个服务可以协同完成 AI 任务&#xff0c;并支持分布式 AI 推理&#x…

研发的立足之本到底是啥?

0 你的问题&#xff0c;我知道&#xff01; 本文深入T型图“竖线”的立足之本&#xff1a;专业技术 技术赋能业务能力。研发在学习投入精力最多&#xff0c;也误区最多。 某粉丝感发展遇到瓶颈&#xff0c;项目都会做&#xff0c;但觉无提升&#xff0c;想跳槽。于是&#x…