昇思MindSpore进阶教程-优化器

大家好,我是刘明,明志科技创始人,华为昇思MindSpore布道师。
技术上主攻前端开发、鸿蒙开发和AI算法研究。
努力为大家带来持续的技术分享,如果你也喜欢我的文章,就点个关注吧

模型训练过程中,使用优化器更新网络参数,合适的优化器可以有效减少训练时间,提高模型性能。

最基本的优化器是随机梯度下降算法(SGD),很多优化器在SGD的基础上进行了改进,以实现目标函数能更快速更有效地收敛到全局最优点。MindSpore中的nn模块提供了常用的优化器,如nn.SGD、nn.Adam、nn.Momentum等。本章主要介绍如何配置MindSpore提供的优化器以及如何自定义优化器。

在这里插入图片描述

nn.optim

配置优化器

参数配置

在构建优化器实例时,需要通过优化器参数params配置模型网络中要训练和更新的权重。Parameter中包含了一个requires_grad的布尔型的类属性,用于表示模型中的网络参数是否需要进行更新。

网络中大部分参数的requires_grad默认值为True,少部分默认值为False,例如BatchNorm中的moving_mean和moving_variance。

MindSpore中的trainable_params方法会屏蔽掉Parameter中requires_grad为False的属性,在为优化器配置 params 入参时,可使用net.trainable_params()方法来指定需要优化和更新的网络参数。

import numpy as np
import mindspore
from mindspore import nn, ops
from mindspore import Tensor, Parameterclass Net(nn.Cell):def __init__(self):super().__init__()self.conv = nn.Conv2d(1, 6, 5, pad_mode="valid")self.param = Parameter(Tensor(np.array([1.0], np.float32)), 'param')def construct(self, x):x = self.conv(x)x = x * self.paramout = ops.matmul(x, x)return outnet = Net()# 配置优化器需要更新的参数
optim = nn.Adam(params=net.trainable_params())
print(net.trainable_params())

用户可以手动修改网络权重中 Parameter 的 requires_grad 属性的默认值,来决定哪些参数需要更新。

如下例所示,使用 net.get_parameters() 方法获取网络中所有参数,并手动修改巻积参数的 requires_grad 属性为False,训练过程中将只对非卷积参数进行更新。

conv_params = [param for param in net.get_parameters() if 'conv' in param.name]
for conv_param in conv_params:conv_param.requires_grad = False
print(net.trainable_params())
optim = nn.Adam(params=net.trainable_params())
学习率

学习率作为机器学习及深度学习中常见的超参,对目标函数能否收敛到局部最小值及何时收敛到最小值有重要影响。学习率过大容易导致目标函数波动较大,难以收敛到最优值,太小则会导致收敛过程耗时过长。除了设置固定学习率,MindSpore还支持设置动态学习率,这些方法在深度学习网络中能明显提升收敛效率。
固定学习率:
使用固定学习率时,优化器传入的learning_rate为浮点类型或标量Tensor。

以nn.Momentum为例,固定学习率为0.01,示例如下:

# 设置学习率为0.01
optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)

动态学习率:
mindspore.nn提供了动态学习率的模块,分为Dynamic LR函数和LearningRateSchedule类。其中Dynamic LR函数会预先生成长度为total_step的学习率列表,将列表传入优化器中使用,训练过程中,第i步使用第i个学习率的值作为当前step的学习率,其中total_step的设置值不能小于训练的总步数;LearningRateSchedule类将实例传递给优化器,优化器根据当前step计算得到当前的学习率。

运行中修改优化器参数

运行中修改学习率

mindspore.experimental.optim.Optimizer 中学习率为 Parameter,除通过上述动态学习率模块 mindspore.experimental.optim.lr_scheduler 动态修改学习率,也支持使用 assign 赋值的方式修改学习率。

例如下述样例,在训练step中,设置如果损失值相比上一个step变化小于0.1,将优化器第1个参数组的学习率调整至0.01:

net = Net()
loss_fn = nn.MAELoss()
optimizer = optim.Adam(net.trainable_params(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
last_step_loss = 0.1def forward_fn(data, label):logits = net(data)loss = loss_fn(logits, label)return lossgrad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)if ops.abs(loss - last_step_loss) < 0.1:ops.assign(optimizer.param_groups[1]["lr"], Tensor(0.01))return loss

运行中修改除lr以外的优化器参数

下述样例,在训练step中,设置如果损失值相比上一个step变化小于0.1,将优化器第1个参数组的 weight_decay 调整至0.02:

net = Net()
loss_fn = nn.MAELoss()
optimizer = optim.Adam(net.trainable_params(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
last_step_loss = 0.1def forward_fn(data, label):logits = net(data)loss = loss_fn(logits, label)return lossgrad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)if ops.abs(loss - last_step_loss) < 0.1:optimizer.param_groups[1]["weight_decay"] = 0.02return loss

自定义优化器

与上述自定义优化器方式相同,自定义优化器时也可以继承优化器基类experimental.optim.Optimizer,并重写__init__方法和construct方法以自行设定参数更新策略。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/432236.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

BACnet MS/TP协议解析(3)

1、MS/TP帧格式 例如数据&#xff08;hex&#xff09;&#xff1a;55 FF 01 03 02 00 00 D7 0x550xff0x010x030x020x000x000xD7BACnet数据BACnet数据CRC帧头帧类型目的地址源地址BACnet数据长度&#xff0c;大端CRC 2、帧类型 帧类型目前定义为 0-7&#xff0c;8-127 为 AS…

【Unity踩坑】Textmesh Pro是否需要加入Version Control?

问题&#xff1a;如果Unity 项目中用到了Textmesh pro&#xff0c;相关的文件是否也需要签入呢&#xff1f; 回答&#xff1a; 在使用 Unity 的 Version Control&#xff08;例如 Plastic SCM 或 Git&#xff09;时&#xff0c;如果你的项目中使用了 TextMesh Pro&#xff0c…

TCN预测 | MATLAB实现TCN时间卷积神经网络多输入单输出回归预测

TCN预测 | MATLAB实现TCN时间卷积神经网络多输入单输出回归预测 目录 TCN预测 | MATLAB实现TCN时间卷积神经网络多输入单输出回归预测预测效果基本介绍模型描述程序设计参考资料预测效果

武汉正向科技 格雷母线检测方式 :车检,地检

正向科技|格雷母线原理运用-车检&#xff0c;地检 地上检测方式 地址编码器和天线箱安装在移动站上&#xff0c;通过天线箱发射地址信号&#xff0c;地址解码器安装在固定站&#xff08;地面&#xff09;上&#xff0c;在固定站完成地址检测。 车上检测方式 地址编码器安装在…

MySQL Mail服务器集成:如何配置发送邮件?

MySQL Mail插件使用指南&#xff1f;怎么优化 MySQL发邮件性能&#xff1f; MySQL Mail服务器的集成&#xff0c;使得数据库可以直接触发邮件发送&#xff0c;极大地简化了应用架构。AokSend将详细介绍如何配置MySQL Mail服务器&#xff0c;以实现邮件发送功能。 MySQL Mail&…

SegFormer网络结构的学习和重构

因为太多的博客并没有深入理解,本文是自己学习后加入自己深入理解的总结记录&#xff0c;方便自己以后查看。 segformer中encoder、decoder的详解。 学习前言 一起来学习Segformer的原理,如果有用的话&#xff0c;请记得点赞关注哦。 一、Segformer的网络结构图 网络结构&…

JavaWeb 12.Tomcat10

希望明天能出太阳 或者如果没有太阳的话 希望我能变得更加阳光一点 —— 24.9.25 一、常见的JavaWeb服务器 Web服务器通常由硬件和软件共同构成 硬件&#xff1a;电脑&#xff0c;提供服务供其他客户电脑访问 软件&#xff1a;电脑上安装的服务器软件&#xff0c;安装后能提…

TIOBE 编程指数 9 月排行榜公布 VB.Net第七

原文地址&#xff1a;百度安全验证 IT之家 9 月 8 日消息&#xff0c;TIOBE 编程社区指数是一个衡量编程语言受欢迎程度的指标&#xff0c;评判的依据来自世界范围内的工程师、课程、供应商及搜索引擎&#xff0c;今天 TIOBE 官网公布了 2024 年 9 月的编程语言排行榜&#xf…

介绍 Agent Q:迎接下一代 AI 自动化助手

引言 在科技领域&#xff0c;随着人工智能的不断进步&#xff0c;自动化工具日益成为提升效率的重要手段。今天&#xff0c;我将向大家介绍一款名为 Agent Q 的 AI 自动化助手。这款工具不仅能够完成复杂的任务&#xff0c;还支持交互式命令行操作&#xff0c;使得用户体验更为…

飞驰云联亮相电子半导体数智化年会 获”数据交换领域最佳厂商”

2024年9月20日&#xff0c;“2024第二届电子半导体/智能制造数智化年会暨品牌出海论坛”于上海隆重开幕&#xff0c;Ftrans飞驰云联作为国内领先的数据安全交换厂商&#xff0c;应邀携半导体全场景产品和解决方案亮相此次峰会。会上进行了“智象奖”评选&#xff0c;Ftrans飞驰…

java并发之并发关键字

并发关键字 关键字一&#xff1a;volatile 可以这样说&#xff0c;volatile 关键字是 Java 虚拟机提供的轻量级的同步机制。 功能 volatile 有 2 个主要功能&#xff1a; 可见性。一个线程对共享变量的修改&#xff0c;其他线程能够立即得知这个修改。普通变量不能做到这一点&a…

从零开始学习PX4源码5(遥控器数据)

#目录 文章目录 摘要1.PX4 遥控器控制整体流程2.PX4 遥控器输入程序3.PX4 遥控器数据外部调用接口4.PX4 遥控器手动(姿态控制)变量5.遥控器数据整体流程摘要 本节主要记录PX4代码中如何获取遥控器数据,遥控器数据如何被外界调用的过程,欢迎批评指正。 1.PX4 遥控器控制整…

JAVA-StringBuilder和StringBuffer

一、认识String类 1.认识 String在Java中是字符串类型&#xff0c;但与其他类型不同。它是一个类&#xff0c;可以创建对象的类。与int、char等自待类型有些许不同。但它仍然是java提供的一种类型。 类中有4个属性&#xff0c;这里主要认识一下value属性。它是实际存放字符串…

2024 IDEA软件 部署tomcat 十二步 运行web页面(html类似的)(中英文对照版本)新手小白易上手

目录 一、准备工作&#xff08;三必备&#xff09;&#xff1a; 1、自己的web项目 2、idea软件&#xff08;我是2023.1.2版本&#xff09; 3、tomcat X.X版本 二 、正式开始步骤&#xff0c;不废话&#xff01;&#xff01; 1、 点击菜单栏中 “File”&#xff08;文件&…

NASA数据集:ATLAS/ICESat-2 L3A 海洋地表高度 V006

ATLAS/ICESat-2 L3A Ocean Surface Height V006 目录 简介 摘要 代码 引用 网址推荐 0代码在线构建地图应用 机器学习 简介 该数据集&#xff08;ATL12&#xff09;包含全球开阔洋&#xff08;包括无冰季节冰区和近海岸地区&#xff09;的沿轨海面高度。 还提供了高度…

C++ 9.25

手动实现栈、和队列 stack #include <iostream> using namespace std; class Stack { private: int* arr; // 存储栈元素的数组 int top; // 栈顶索引 int capacity; // 栈的容量 public: Stack(int size) { arr new int[size]; c…

FLStudio21Mac版flstudio v21.2.1.3430简体中文版下载(含Win/Mac)

给大家介绍了许多FL21版本&#xff0c;今天给大家介绍一款FL Studio21Mac版本&#xff0c;如果是Mac电脑的朋友请千万不要错过&#xff0c;当然我也不会忽略掉Win系统的FL&#xff0c;链接我会放在文章&#xff0c;供大家下载与分享&#xff0c;如果有其他问题&#xff0c;欢迎…

基于Python大数据的音乐推荐及数据分析可视化系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码 精品专栏&#xff1a;Java精选实战项目…

【吊打面试官系列-MySQL面试题】MySQL 数据库作发布系统的存储,一天五万条以上的增量,预计运维三年,怎么优化?

大家好&#xff0c;我是锋哥。今天分享关于【MySQL 数据库作发布系统的存储&#xff0c;一天五万条以上的增量&#xff0c;预计运维三年,怎么优化&#xff1f;】面试题&#xff0c;希望对大家有帮助&#xff1b; MySQL 数据库作发布系统的存储&#xff0c;一天五万条以上的增量…

二模--解题--101-110

文章目录 10.沟通管理101、 [单选] 在项目执行阶段&#xff0c;项目经理意识到项目干系人一直延迟答复敏感性电子邮件。项目经理应该怎么做&#xff1f; 4.整合管理102、 [单选] 在编制项目章程用于批准时&#xff0c;项目经理发现有两名干系人对关键可交付成果的期望有冲突。若…