pytorch常用的模块函数汇总(1)

目录

torch:核心库,包含张量操作、数学函数等基本功能

torch.nn:神经网络模块,包括各种层、损失函数和优化器等

torch.optim:优化算法模块,提供了各种优化器,如随机梯度下降 (SGD)、Adam、RMSprop 等。

torch.autograd:自动求导模块,用于计算张量的梯度


torch:核心库,包含张量操作、数学函数等基本功能

  1. torch.tensor(): 创建张量
  2. torch.zeros()torch.ones(): 创建全零或全一张量
  3. torch.rand(): 创建随机张量
  4. torch.from_numpy(): 从 NumPy 数组创建张量
  5. torch.add()torch.sub()torch.mul()torch.div(): 加法、减法、乘法、除法

  6. torch.mm()torch.matmul(): 矩阵乘法

  7. torch.exp()torch.log()torch.sin()torch.cos(): 指数、对数、正弦、余弦等数学函数

  8. torch.index_select(): 按索引选取张量的子集

  9. torch.masked_select(): 根据掩码选取张量的子集

    切片操作:类似 Python 中的列表切片操作,如 tensor[2:5]
  10. torch.view()torch.reshape(): 改变张量的形状

  11. torch.squeeze()torch.unsqueeze(): 压缩或扩展张量的维度

  12. torch.mean()torch.sum()torch.max()torch.min(): 计算张量均值、和、最大值、最小值等

  13. torch.broadcast_tensors(): 对张量进行广播操作

  14. torch.cat(): 拼接张量

  15. torch.stack(): 堆叠张量

  16. torch.split(): 分割张量

torch.nn:神经网络模块,包括各种层、损失函数和优化器等

  • 神经网络层

    • torch.nn.Linear(in_features, out_features): 全连接层,进行线性变换。
    • torch.nn.Conv2d(in_channels, out_channels, kernel_size): 2D卷积层。
    • torch.nn.MaxPool2d(kernel_size): 2D 最大池化层。
    • torch.nn.ReLU(): ReLU 激活函数。
    • torch.nn.Sigmoid(): Sigmoid 激活函数。
    • torch.nn.Dropout(p): Dropout 层,用于防止过拟合。

备注:Sigmoid 激活函数是一种常用的非线性激活函数,其作用可以总结如下:

将输入映射到 (0, 1) 范围内:输出范围在 0 到 1 之间,可以将任意实数输入映射到 0 到 1 之间。这种特性在某些情况下很有用,比如对于二分类任务,Sigmoid 函数的输出可以被解释为样本属于正类的概率。

引入非线性变换: Sigmoid 函数是一种非线性函数,可以引入神经网络的非线性变换能力,使得神经网络可以学习更加复杂的模式和关系。在深度神经网络中,非线性激活函数的使用可以帮助神经网络学习非线性模式,提高网络的表达能力。

输出平滑且连续: Sigmoid 函数具有平滑的 S 形曲线,在定义域内都是可导的,这使得在反向传播算法中计算梯度变得相对容易。这一点对于神经网络的训练至关重要。

  • 损失函数

torch.nn.CrossEntropyLoss(): 交叉熵损失函数,常用于多分类问题。

交叉熵损失函数用于衡量两个概率分布之间的差异,通常用于多分类任务中。在神经网络的多分类任务中,输入模型的输出是一个概率分布,表示每个类别的预测概率,而交叉熵损失函数则用于比较这个预测概率分布与实际标签的分布之间的差异。

torch.nn.CrossEntropyLoss() 来计算交叉熵损失函数,它会自动将模型的输出通过 Softmax 函数转换为概率分布,并计算交叉熵损失。

torch.nn.MSELoss(): 均方误差损失函数,常用于回归问题。

均方误差损失函数用于衡量模型输出与实际目标之间的差异,通常在回归任务中使用。该损失函数计算预测值与真实值之间的平方差,并将所有样本的平方差求平均得到最终的损失值。

  • 优化器

    • torch.optim.SGD(model.parameters(), lr=learning_rate): 随机梯度下降优化器。
    • torch.optim.Adam(model.parameters(), lr=learning_rate): Adam 优化器。
  • 模型定义相关

    • torch.nn.Module: 所有神经网络模型的基类,需要继承这个类。
    • model.forward(input_tensor): 定义前向传播。
  • 数据处理相关

    • torch.utils.data.Dataset: PyTorch 数据集的基类,需要自定义数据集时使用。
    • torch.utils.data.DataLoader(dataset, batch_size, shuffle): 数据加载器,用于批量加载数据。
  • torch.optim:优化算法模块,提供了各种优化器,如随机梯度下降 (SGD)、Adam、RMSprop 等。

  • 优化器(Optimizer)类

    • torch.optim.SGD(params, lr=0.01, momentum=0, weight_decay=0):随机梯度下降优化器,实现了带动量的随机梯度下降
    • torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):Adam 优化器,结合了动量方法和 RMSProp 方法,通常在深度学习中表现良好。
    • torch.optim.Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0):Adagrad 优化器,自适应地为参数分配学习率。它根据参数的历史梯度信息对每个参数的学习率进行调整。这意味着对于不同的参数,Adagrad可以为其分配不同的学习率,从而更好地适应参数的更新需求
    • torch.optim.RMSprop(params, lr=0.01, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):RMSprop 优化器,有效地解决了 Adagrad 学习率下降较快的问题(RMSprop对梯度平方项进行指数加权平均)。

备注:

1.在优化算法中,momentum(动量)是一种用于加速模型训练的技巧。动量项的引入旨在解决标准随机梯度下降在训练过程中可能遇到的震荡和收敛速度慢的问题。

动量项的引入可以帮助优化算法在参数更新时更好地利用之前的更新方向,从而在一定程度上减少参数更新的波动,加快收敛速度,并有助于跳出局部极小值。具体来说,动量项在参数更新时会考虑之前的更新方向,并对当前的更新方向进行一定程度的调整。

在 PyTorch 的 torch.optim.SGD 中,动量可以通过设置 momentum 参数来控制。通常情况下,动量的取值范围在 0 到 1 之间,常见的默认取值为 0.9。当动量设为0.9时,每次迭代,都会保留上一次速度的 90%,并使用当前梯度微调最终的更新方向。

总结来说,动量项的引入可以提高随机梯度下降的稳定性和收敛速度,有助于在训练神经网络时更快地找到较优的解。

2. 

在优化算法中,weight decay(权重衰减)是一种用于控制模型参数更新的正则化技术。权重衰减通过在优化过程中对参数进行惩罚,防止其取值过大,从而有助于降低过拟合的风险。

具体来说,在SGD中的weight_decay参数是对模型的权重进行L2正则化,即在计算梯度时额外增加一个关于参数的惩罚项。这个惩罚项会使得优化算法更倾向于选择较小的权重值,从而降低模型的复杂度,减少过拟合的风险。

在PyTorch中,torch.optim.SGD中的weight_decay参数用于控制权重衰减的程度。通常情况下,weight_decay的取值为一个小的正数,比如 0.001 或 0.0001。设置了weight_decay之后,在计算梯度时会额外考虑到权重的惩罚项,从而影响参数的更新方式。

总结来说,权重衰减是一种正则化技术,通过对模型参数的惩罚来控制模型的复杂度,减少过拟合的风险,提高模型的泛化能力。

3.

  • L1正则化:L1正则化会给模型的损失函数添加一个关于权重绝对值的惩罚项,即L1范数(权重的绝对值之和)。在梯度下降过程中,L1正则化会导致部分权重直接变为0,因此可以实现稀疏性,有特征选择的效果。L1正则化倾向于产生稀疏的权重矩阵,可以用于特征选择和降维。

L1正则化的惩罚项是模型权重的L1范数,即权重的绝对值之和。在优化过程中,为了最小化损失函数并减少正则化项的影响,优化算法会尝试将权重调整到较小的值。由于L1正则化的几何形状在坐标轴上拐角处就会与坐标轴相交,这就导致了在坐标轴上许多点都是对称的,因此在这些点上的梯度不唯一。这意味着在这些对称点上,优化算法更有可能将权重调整为0,从而导致稀疏性。

  • L2正则化:L2正则化会给模型的损失函数添加一个关于权重平方的惩罚项,即L2范数的平方(权重的平方和)。在梯度下降过程中,L2正则化会使得权重都变得比较小,但不会直接导致稀疏性。L2正则化对异常值比较敏感,因为它会平方每个权重,使得异常值对损失函数的影响更大。

  • 总的来说,L1正则化和L2正则化都是常用的正则化技术,它们在模型训练过程中都有助于控制模型的复杂度,减少过拟合的风险。选择使用哪种正则化方法通常取决于具体的问题和数据特点,以及对模型稀疏性的需求。在实际应用中,有时也会将L1和L2正则化结合起来,形成弹性网络正则化(Elastic Net regularization),以兼顾两种正则化方法的优势。

4. 

betas 是 Adam 算法中的两个超参数之一,它控制了梯度的一阶矩估计和二阶矩估计的指数衰减率。betas 是一个长度为2的元组,通常形式为 (beta1, beta2)。在 Adam 算法中,beta1 控制了一阶矩估计(梯度的均值)的衰减率,beta2 控制了二阶矩估计(梯度的平方的均值)的衰减率。

通常情况下,beta1 的默认值为 0.9,beta2 的默认值为 0.999。这意味着在每次迭代中,一阶矩估计将保留当前梯度的 90%,而二阶矩估计将保留当前梯度的平方的 99.9%。这些衰减率的选择使得 Adam 算法能够在训练过程中自适应地调整学习率,并对梯度的变化做出快速或缓慢的响应,从而更有效地更新模型参数。

总之,betas 参数在 Adam 算法中起着调节梯度一阶和二阶矩估计衰减率的作用,通过合理设置 betas 可以影响算法的收敛性和稳定性。

  • 调整学习率的函数

    • torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)根据给定的函数 lr_lambda 调整学习率。
    • torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1):每个 step_size 个 epoch 将学习率降低为原来的 gamma 倍。
    • torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1):在指定的里程碑上将学习率降低为原来的 gamma 倍。
  • 其他常用函数

    • zero_grad():用于将模型参数的梯度清零,通常在每个 batch 后调用。
    • step(closure):用于执行单步优化器的更新,需要传入一个闭包函数 closure
    • state_dict() 和 load_state_dict():用于保存和加载优化器的状态字典,方便恢复训练。
  • torch.autograd:自动求导模块,用于计算张量的梯度

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/288378.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

工业物联网关的应用及相关产品-天拓四方

随着科技的飞速发展,智能制造业已成为工业领域的转型方向。在这一转变中,工业物联网关发挥着至关重要的作用。作为连接物理世界与数字世界的桥梁,工业物联网关不仅实现了设备与设备、设备与云平台之间的互联互通,更通过实时数据采…

Fabric Measurement

Fabric Measurement 布料测量

低功耗、低成本 NAS 的可能性

使用现状:多台工作电脑,家里人手一台,还在两个住处 有好几台工作电脑,不同电脑有不同的用途,最大的问题就是各个电脑上文件的同步问题,这里当然就需要局域网里的公共文件夹,在NAS的问题上查了网…

FreeRTOS(三)

第二部分 事件组 一、事件组的简介 1、事件 事件是一种实现任务间通信的机制,主要用于实现多任务间的同步,但事件通信只能是事件类型的通信,无数据传输。其实事件组的本质就是一个整数(16/32位)。可以是一个事件发生唤醒一个任务&#xff…

【C语言进阶篇】编译和链接

【C语言进阶篇】编译和链接 🥕个人主页:开敲🍉 🔥所属专栏:C语言🍓 🌼文章目录🌼 编译环境与运行环境 1. 翻译环境 2. 编译环境:预编译(预处理)编…

Mac上的Gatekeeper系统跟运行时保护

文章目录 问题:无法打开“xxx.xxx”,因为无法验证开发者。macOS无法验证此App是否包含恶意软件。如何解决? 参考资料门禁运行时保护 问题:无法打开“xxx.xxx”,因为无法验证开发者。macOS无法验证此App是否包含恶意软件…

解析SpringBoot自动装配原理前置知识:解析条件注释的原理

什么是自动装配? Spring提供了向Bean中自动注入依赖的这个功能,这个过程就是自动装配。 SpringBoot的自动装配原理基于大量的条件注解ConditionalOnXXX,因此要先来了解一下条件注解相关的源码。 以ConditionalOnClass为例 首先来查看Conditi…

兼顾陪读|本科学历律师自费赴美国加州大学伯克利分校访学

S律师拟陪同孩子赴海外就读,决定以访问学者身份,申请美国J类签证出国以兼顾陪读。因本科学历,无文章且有地域要求,自己申请无果后做了全权委托。为此我们酌情制定了三条申请策略,最终落实加州大学伯克利分校的访学职位…

NSString有哪些创建对象的方法?创建的对象分别存储在什么区域?

NSString有哪些创建对象的方法?创建的对象分别存储在什么区域? 一般通过NSString创建对象的方法有: NSString *string1 "123";NSString *string2 [[NSString alloc] initWithString:"123"];NSString *string3 [NSSt…

解决方案:如何安装neo4j软件

文章目录 一、安装JDK二、安装neo4j 一、安装JDK 第一步先安装JDK,因为neo4j环境需要JDK,过程比较多,截图如下: 安装JDK网址 https://www.oracle.com/java/technologies/downloads winR,输入cmd,再输入j…

Leetcode70. 爬楼梯(动态规划)

Leetcode原题 Leetcode70. 爬楼梯 标签 记忆化搜索 | 数学 | 动态规划 题目描述 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢?示例 1:输入:n 2 输出:2 解…

数据分析之POWER Piovt透视表分析与KPI设置

将几个数据表之间进行关联 生成数据透视表 超级透视表这里的字段包含子字段 这三个月份在前面的解决办法 1.选中这三个月份,鼠标可移动的时候移动到后面 2.在原数据进行修改 添加列获取月份,借助month的函数双击日期 选择月份这列----按列排序-----选择月…

C++ 控制语句(一)

一 顺序结构 程序的基本结构有三种: 顺序结构、分支结构、循环结构 大量的实际问题需要通过各种控制流程来解决。 1.1 顺序结构 1.2 简单语句和复合语句 二 循环 2.1 for循环 语句流程图 注意:使用for语句的灵活性 三 while语句 四 do while语句

【LLM多模态】Cogvlm图生文模型结构和训练流程

note Cogvlm的亮点: 当前主流的浅层对齐方法不佳在于视觉和语言信息之间缺乏深度融合,而cogvlm在attention和FFN layers引入一个可训练的视觉专家模块,将图像特征与文本特征分别处理,并在每一层中使用新的QKV矩阵和MLP层。通过引…

【LaTeX】7实现章节跳转

使用 LaTeX 实现章节跳转 写在最前面1. 引入 hyperref 包2. 标记章节3. 引用章节示例代码注意 小技巧总结 🌈你好呀!我是 是Yu欸 🌌 2024每日百字篆刻时光,感谢你的陪伴与支持 ~ 🚀 欢迎一起踏上探险之旅,…

Vue 3中ref和reactive的区别

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

碳课堂|什么是碳资产?企业如何进行碳资产管理?

碳资产是绿色资产的重要类别,在全球气候变化日益严峻的背景下备受关注。在“双碳”目标下,碳资产管理是企业层面实现碳减排目标和低碳转型的关键。 一、什么是碳资产? 碳资产是以碳减排为基础的资产,是企业为了积极应对气候变化&…

就业班 第二阶段 2401--3.25 day5 mycat读写分离

[TOC] 启动并更改临时密码 [rootmysql1~]# systemctl start mysqld && passwdgrep password /var/log/mysqld.log | awk END{ print $NF} && mysqladmin -p"$passwd" password Qwer123..; MyCAT读写分离 Mycat 是一个开源的数据库系统,但…

遇到了问题,Firepower 2140配置带外IP时报错 commit-buffer failed

onsite we have a cisco firepower 2140 device which run ASA as we try to modify the 2140 OOB mgmt ip by CLI, we got an error why ? 经过查询发现,需要进入ASA里面打上以下这条命令,并重启ASA 1 修改模式并重启 ciscoasa# configure termina…

idea使用git笔记

1.创建分支和切换分支 创建分支 切换分支 2.把新创建的分支提交到远程服务器上(注:如果没有提交的,随便找个文件修改再提交) (1)切换到要提交的分支,add (2)commit (3)push 3.在自己分支修改代码及提交到自己的远…