【PyTorch】torch.optim介绍

文章目录

  • PyTorch torch.optim介绍
  • 1. torch.optim主要功能
  • 2. 常见的优化算法
    • 2.1 SGD(随机梯度下降)
    • 2.2 Momentum(带动量的SGD)
    • 2.3 Adam(自适应矩估计)
    • 2.4 RMSprop
    • 2.5 Adagrad
  • 3. 优化器的核心操作
    • 3.1 初始化优化器
    • 3.2 `optimizer.zero_grad()`
    • 3.3 `loss.backward()`
    • 3.4 `optimizer.step()`
    • 3.5 梯度裁剪
  • 4. 学习率调整(`lr_scheduler`)
    • 4.1 StepLR
    • 4.2 ReduceLROnPlateau
    • 4.3 ExponentialLR
  • 5. 完整的训练过程示例
  • 6. 总结

PyTorch torch.optim介绍

torch.optim 是 PyTorch 中用于优化神经网络模型参数的模块,它实现了多种常见的优化算法(如 SGD、Adam、RMSprop 等),通过计算损失函数对参数的梯度并根据梯度更新模型的权重。

1. torch.optim主要功能

  • 优化算法的实现:提供多种优化算法,如常见的 SGDAdamRMSprop 等,适用于不同类型的模型和任务。
  • 动态学习率调整:支持动态调整学习率的策略(如 lr_scheduler),在训练过程中提高效率。
  • 参数更新:通过计算梯度并更新模型的参数,优化器会优化模型的权重,以最小化损失函数。

2. 常见的优化算法

2.1 SGD(随机梯度下降)

SGD 是最经典的优化算法,适用于大多数简单的深度学习问题。它通过更新参数的方式,沿着负梯度方向逐步减小损失。

import torch
import torch.optim as optim
import torch.nn as nn# 假设定义了一个简单的神经网络模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(784, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 创建模型实例
model = SimpleNN()# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 假设训练数据
input_data = torch.randn(64, 784)  # 假设64个样本,每个样本784维
labels = torch.randint(0, 10, (64,))  # 64个标签,10个类别# 训练过程
for epoch in range(10):optimizer.zero_grad()  # 清除梯度output = model(input_data)  # 前向传播loss = loss_fn(output, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数print(f'Epoch [{epoch+1}/10], Loss: {loss.item()}')

2.2 Momentum(带动量的SGD)

Momentum 方法是在每次更新时加入前一步的梯度信息,这样能加速收敛并减少波动。

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

在上面的例子中,momentum=0.9 表示前一步更新的梯度贡献占 90%。

2.3 Adam(自适应矩估计)

Adam 是一种自适应优化算法,它结合了 MomentumRMSprop 的优点。Adam 会根据每个参数的均值和方差动态调整学习率。

optimizer = optim.Adam(model.parameters(), lr=0.001)

Adam 的优点:

  • 自适应学习率:每个参数都有自己的学习率。
  • 收敛速度快:通常在较少的训练步骤内能达到较好的效果。

2.4 RMSprop

RMSprop 是另一种自适应学习率的优化算法,特别适合处理循环神经网络(RNN)等任务。

optimizer = optim.RMSprop(model.parameters(), lr=0.01)

RMSprop 通过调整每个参数的学习率来避免某些参数更新过快或过慢。

2.5 Adagrad

Adagrad 是另一种自适应优化算法,它在每个参数的学习率上进行调整,使得稀疏数据的特征能够快速收敛。

optimizer = optim.Adagrad(model.parameters(), lr=0.01)

Adagrad 的主要特点是它对每个参数有独立的学习率,参数的更新根据梯度大小自适应调整。

3. 优化器的核心操作

3.1 初始化优化器

初始化优化器时,通常需要传入模型的参数和学习率。例如:

optimizer = optim.Adam(model.parameters(), lr=0.001)

model.parameters() 返回模型的所有可学习参数,lr=0.001 是优化器的学习率。

3.2 optimizer.zero_grad()

在每次更新参数前,需要清除之前的梯度,因为 PyTorch 中的梯度是累积的。可以使用 optimizer.zero_grad() 来清空梯度。

optimizer.zero_grad()

3.3 loss.backward()

计算反向传播,PyTorch 会根据损失函数的梯度自动计算每个参数的梯度。

loss.backward()

3.4 optimizer.step()

通过梯度信息更新模型的参数。调用 optimizer.step() 后,优化器会使用当前计算的梯度来更新模型的权重。

optimizer.step()

3.5 梯度裁剪

为了防止梯度爆炸问题,通常会进行梯度裁剪操作。可以使用 torch.nn.utils.clip_grad_norm_ 来对梯度进行裁剪。

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

这会将所有参数的梯度裁剪到最大范数 1.0

4. 学习率调整(lr_scheduler

PyTorch 提供了多个学习率调整策略,可以帮助在训练过程中动态调整学习率,以便模型更好地收敛。

4.1 StepLR

StepLR 会在每隔一定步数后降低学习率,通常用于训练时逐渐减小学习率,防止过拟合。

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

step_size 个 epoch 后,学习率会乘以 gamma,例如每 10 个 epoch 后学习率会变为原来的 0.1。

4.2 ReduceLROnPlateau

ReduceLROnPlateau 根据验证集的性能来调整学习率。如果模型在一定的 epoch 内未能改善,学习率就会减小。

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.1)
  • patience=5:如果验证损失在 5 个 epoch 内没有下降,学习率就会减少。
  • factor=0.1:每次减少学习率时,将其乘以 0.1

4.3 ExponentialLR

ExponentialLR 通过指数衰减来调整学习率。

scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

在每个 epoch 后,学习率会乘以 gamma=0.99,实现指数衰减。

5. 完整的训练过程示例

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F# 定义一个简单的神经网络模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(784, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = F.relu(self.fc1(x))x = self.fc2(x)return x# 创建模型实例
model = SimpleNN()# 定义损失函数
loss_fn = nn.CrossEntropyLoss()# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)# 定义学习率调度器
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)# 模拟训练过程
for epoch in range(20):optimizer.zero_grad()  # 清空梯度input_data = torch.randn(64, 784)  # 假设的输入数据labels = torch.randint(0, 10, (64,))  # 假设的标签outputs = model(input_data)loss = loss_fn(outputs, labels)loss.backward()  # 反向传播optimizer.step()  # 更新参数# 每5个epoch调整一次学习率scheduler.step()print(f'Epoch [{epoch+1}/20], Loss: {loss.item()}, Learning Rate: {optimizer.param_groups[0]["lr"]}')

6. 总结

  • 优化器:PyTorch 提供了多种优化算法,如 SGD、Adam、RMSprop、Adagrad 等。根据任务选择合适的优化器。
  • 学习率调整torch.optim.lr_scheduler 提供了多种动态调整

学习率的策略,帮助模型更好地收敛。

  • 梯度裁剪:防止梯度爆炸,保证训练过程的稳定性。

通过合理的优化器选择和学习率调整,可以大大提高模型的训练效率和性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/19292.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算四个锚点TOA定位中GDOP的详细步骤和MATLAB例程

该MATLAB代码演示了在三维空间中,使用四个锚点的TOA(到达时间)定位技术计算几何精度衰减因子(GDOP)的过程。如需帮助,或有导航、定位滤波相关的代码定制需求,请联系作者 文章目录 DOP计算原理MATLAB例程运行结果示例关键点说明扩展方向另有文章: 多锚点Wi-Fi定位和基站…

基于Spring Boot+Vue的宠物服务管理系统(源码+文档)

项目简介 宠物服务管理系统实现了以下功能: 基于Spring BootVue的宠物服务管理系统的主要使用者分为用户管理模块,由于系统运行在互联网络中,一些游客或者病毒恶意进行注册,产生大量的垃圾用户信息,管理员可以对这些…

jenkins服务启动-排错

服务状态为active (exited) 且进程不在 查看/etc/rc.d/init.d/jenkins配置 获取配置参数 [rootfy-jenkins-prod jenkins]# cat /etc/rc.d/init.d/jenkins | grep -v #JENKINS_WAR"/usr/lib/jenkins/jenkins.war" test -r "$JENKINS_WAR" || { echo "…

vue3 分析总结响应式丢失问题原因(二)

上一篇文件理解了响应式对象应用原理了。公式: 响应式对象 代理 触发器。 但是实际使用结果和预期还是不一致。具体现象是数据修改了,但是并没有实现响应式更新界面。即出现了响应式丢失现象。 一、什么情况下对象的响应式会丢失? 一般网…

【网络】协议与网络版计算器

协议与网络版计算器 文章目录 1.协议的概念 1.1序列化与反序列化 2.网络版计算器 2.1封装套接字2.2协议定制 2.2.1Jsoncpp2.2.2报文处理 2.3会话层:TcpServer2.4应用层:Calculate2.5表示层:Service2.6应用层、表示层和会话层->应用层 …

C# 添加图标

一、前言 为应用程序添加图标是优化用户界面、提升应用辨识度的重要操作。合适的图标能帮助用户快速识别和区分不同应用,增强应用的易用性和专业性。 本指南旨在为你提供详细、易懂的步骤,教你如何为应用程序的窗体添加图标。从图标素材的获取到具体的…

使用新版本golang项目中goyacc依赖问题的处理

背景 最近项目使用中有用到go mod 和 goyacc工具。goyacc涉及到编译原理的词法分析,文法分析等功能,可以用来生成基于golang的语法分析文件。本期是记录一个使用中遇到的依赖相关的问题。因为用到goyacc,需要生成goyacc的可执行文件。 而项目…

WPS的AI助手进化跟踪(灵犀+插件)

Ver V0.0 250216: 如何给WPS安装插件用以支持其他大模型LLM V0.1 250217: WPS的灵犀AI现在是DeepSeek R1(可能是全参数671B) 前言 WPS也有内置的AI,叫灵犀,之前应是自已的LLM模型,只能说是属于“能用,有好过无”,所…

计算机视觉:卷积神经网络(CNN)基本概念(一)

第一章:计算机视觉中图像的基础认知 第二章:计算机视觉:卷积神经网络(CNN)基本概念(一) 第三章:计算机视觉:卷积神经网络(CNN)基本概念(二) 第四章:搭建一个经典的LeNet5神经网络 一、引言 卷积神经网络&…

rabbitmq详解

有需要的直接看狂神的视频,讲得很好 简介 RabbitMQ 是一个开源的 消息队列中间件,实现了 AMQP(Advanced Message Queuing Protocol,先进消息队列协议)。它允许 应用程序、服务、系统之间异步地传递消息,并…

moveable 一个可实现前端海报编辑器的 js 库

目录 缘由-胡扯本文实验环境通用流程1.基础移动1.1 基础代码1.1.1 data-* 解释 1.2 操作元素创建1.3 css 修饰1.4 cdn 引入1.5 js 实现元素可移动1.6 图片拖拽2.缩放3.旋转4.裁剪 懒得改文案了,海报编辑器换方案了,如果后面用别的再更。 缘由-胡扯 导火…

计算机视觉中图像的基础认知

第一章:计算机视觉中图像的基础认知 第二章:计算机视觉:卷积神经网络(CNN)基本概念(一) 第三章:计算机视觉:卷积神经网络(CNN)基本概念(二) 第四章:搭建一个经典的LeNet5神经网络 一、图像/视频的基本属性…

java八股文-mysql

1. 索引 1.1 什么是索引 索引(index)是帮助Mysql高效获取数据的数据结构(有序).提高数据的检索效率,降低数据库的IO成本(不需要全表扫描).通过索引列对数据进行排序,降低数据排序成本,降低了CPU的消耗. 1.2 mysql索引使用的B树? 1. 没有使用二叉树,最坏情况o&…

Next.js【详解】CSS 样式方案

全局样式 Global CSS 默认已创建,即 src\app\globals.css,可根据需要修改 默认在全局布局中导入 src\app\layout.tsx import "./globals.css";组件样式 CSS Modules 新建文件 src\app\test\styles.module.css .red {color: red;}导入目标页面…

彻底解决Idea控制台中文乱码问题

中文乱码我相信每一个程序员都会遇到这种问题。 但有时候我们按照网上教程去设置,确实编码好了,但是有时候按照教程来却没能达到我们的预期。 在此之前我将所有编码都设置成了UTF-8,文件编码,项目编码,尝试(最终不需要…

[实现Rpc] 客户端划分 | 框架设计 | common类的实现

目录 3. 客户端模块划分 3.1 Network模块 3.2 Protocol模块 3.3 Dispatcher模块 3.4 Requestor模块 3.5 RpcCaller模块 3.6 Publish-Subscribe模块 3.7 Registry-Discovery模块 3.8 Client模块 4. 框架设计 4.1 抽象层 4.2 具象层 4.3 业务层 ⭕4.4 整体设计框架…

Java里ArrayList和LinkedList有什么区别?

大家好,我是锋哥。今天分享关于【Java里ArrayList和LinkedList有什么区别?】面试题。希望对大家有帮助; Java里ArrayList和LinkedList有什么区别? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 ArrayList 和 LinkedL…

【Java】分布式锁Redis和Redisson

https://blog.csdn.net/weixin_44606481/article/details/134373900 https://www.bilibili.com/video/BV1nW421R7qJ Redis锁机制一般是由 setnx 命令实现,set if not exists,语法setnx key value,将key设置值为value,如果key不存在…

c++TinML转html

cTinML转html 前言解析解释转译html类定义开头html 结果这是最终效果(部分): ![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/6cf6c3e3c821446a84ae542bcc2652d4.png) 前言 在python.tkinter设计标记语言(转译2-html)中提到了将Ti…

2.2 反向传播:神经网络如何“学习“?

一、神经网络就像小学生 想象一个刚学算术的小学生,老师每天布置练习题,学生根据例题尝试解题,老师批改后指出错误。神经网络的学习过程与此相似: 输入层:相当于练习题(如数字图片)输出层&…