《昇思25天学习打卡营第6天 | 函数式自动微分》

《昇思25天学习打卡营第6天 | 函数式自动微分》

目录

  • 《昇思25天学习打卡营第6天 | 函数式自动微分》
    • 函数式自动微分
    • 简单的单层线性变换模型
    • 函数与计算图
    • 微分函数与梯度计算
    • Stop Gradient

函数式自动微分

神经网络的训练主要使用反向传播算法,模型预测值(logits)与正确标签(label)送入损失函数(loss function)获得loss,然后进行反向传播计算,求得梯度(gradients),最终更新至模型参数(parameters)。自动微分能够计算可导函数在某点处的导数值,是反向传播算法的一般化。自动微分主要解决的问题是将一个复杂的数学运算分解为一系列简单的基本运算,该功能对用户屏蔽了大量的求导细节和过程,大大降低了框架的使用门槛。

MindSpore使用函数式自动微分的设计理念,提供更接近于数学语义的自动微分接口grad和value_and_grad。

简单的单层线性变换模型

我们通过学习使用一个简单的单层线性变换模型来了解

import numpy as np
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter

函数与计算图

计算图是用图论语言表示数学函数的一种方式,也是深度学习框架表达神经网络模型的统一方法。我们将根据下面的计算图构造计算函数和神经网络。

compute-graph
在这个模型中, 𝑥
为输入, 𝑦
为正确值, 𝑤
和 𝑏
是我们需要优化的参数。

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias

我们根据计算图描述的计算过程,构造计算函数。 其中,binary_cross_entropy_with_logits 是一个损失函数,计算预测值和目标值之间的二值交叉熵损失。

def function(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss

执行计算函数,可以获得计算的loss值。

loss = function(x, y, w, b)
print(loss)

Tensor(shape=[], dtype=Float32, value= 0.914285)

微分函数与梯度计算

为了优化模型参数,需要求参数对loss的导数: ∂ loss ⁡ ∂ w \frac{\partial \operatorname{loss}}{\partial w} wloss ∂ loss ⁡ ∂ b \frac{\partial \operatorname{loss}}{\partial b} bloss,此时我们调用mindspore.grad函数,来获得function的微分函数。

这里使用了grad函数的两个入参,分别为:

  • fn:待求导的函数。
  • grad_position:指定求导输入位置的索引。

由于我们对 w w w b b b求导,因此配置其在function入参对应的位置(2, 3)

使用grad获得微分函数是一种函数变换,即输入为函数,输出也为函数。

grad_fn = mindspore.grad(function, (2, 3))

执行微分函数,即可获得 𝑤
、 𝑏
对应的梯度。

grads = grad_fn(x, y, w, b)
print(grads)

(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01],
[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01],
[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01],
[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01],
[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01]]),
Tensor(shape=[3], dtype=Float32, value= [ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01]))

Stop Gradient

通常情况下,求导时会求loss对参数的导数,因此函数的输出只有loss一项。当我们希望函数输出多项时,微分函数会求所有输出项对参数的导数。此时如果想实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响,需要用到Stop Gradient操作。

这里我们将function改为同时输出loss和z的function_with_logits,获得微分函数并执行。

def function_with_logits(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss, z
grad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00],
[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00],
[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00],
[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00],
[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00]]),
Tensor(shape=[3], dtype=Float32, value= [ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00]))

可以看到求得 w w w b b b对应的梯度值发生了变化。此时如果想要屏蔽掉z对梯度的影响,即仍只求参数对loss的导数,可以使用ops.stop_gradient接口,将梯度在此处截断。我们将function实现加入stop_gradient,并执行。

def function_stop_gradient(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss, ops.stop_gradient(z)
grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/367754.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JAVA每日作业day7.1-7.3小总结

ok了家人们前几天学了一些知识,接下来一起看看吧 一.API Java 的 API ( API: Application( 应用 ) Programming( 程序 ) Interface(接口 ) ) Java API 就是 JDK 中提供给我们使用的类,这些类将底层 的代码实现封装了起来&#x…

Linux多进程和多线程(四)进程间通讯-定时器信号和子进程退出信号

多进程(四) 定时器信号alarm()函数示例alarm()函数的限制定时器信号的实现原理setitimer()函数setitimer()和alarm()函数的区别 setitimer() old_value参数的示例 对比alarm()区别总结: 子进程退出信号 示例: 多进程(四) 定时器信号 SIGALRM 信号是用来通知进程…

新声创新20年:无线技术给助听器插上“娱乐”的翅膀

听力损失并非现代人的专利,古代人也会有听力损失。助听器距今发展已经有二百多年了,从当初单纯的声音放大器到如今的全数字时代助听器,助听器发生了翻天覆地的变化,现代助听器除了助听功能,还具有看电视,听…

微信小程序 调色板

注意:是在uniapp中直接使用的一个color-picker插件,改一下格式即可在微信小程序的原生代码中使用 https://github.com/KirisakiAria/we-color-picker 这是插件的地址,使用的话先把这个插件下载下来,找到src,在项目创…

FreeRTOS和UCOS操作系统使用笔记

FreeRTOS使用示例 UCOS使用示例 信号量使用 信号量访问共享资源区/ OS_SEMMY_SEM; //定义一个信号量,用于访问共享资源OSSemCreate ((OS_SEM* )&MY_SEM, //创建信号量,指向信号量(CPU_CHAR* )"MY_SEM", //信号量名字(OS_SEM_CTR )1, …

imx6ull/linux应用编程学习(8)PWM应用编程(基于正点)

1.应用层如何操控PWM: 与 LED 设备一样, PWM 同样也是通过 sysfs 方式进行操控,进入到/sys/class/pwm 目录下 这里列举出了 8 个以 pwmchipX(X 表示数字 0~7)命名的文件夹,这八个文件夹其实就对应了…

守护矿山安全生产:AI视频分析技术在煤矿领域的应用

随着人工智能(AI)技术的快速发展,其在煤矿行业的应用也日益广泛。AI视频智能分析技术作为其中的重要分支,为煤矿的安全生产、过程监测、效率提升和监管决策等提供了有力支持。 一、煤矿AI视频智能分析技术的概述 视频智慧煤矿AI…

数据库测试数据准备厂商 Snaplet 宣布停止运营

上周刚获知「数据库调优厂商 OtterTune 宣布停止运营」。而今天下班前,同事又突然刷到另一家海外数据库工具商 Snaplet 也停止运营了。Snaplet 主要帮助开发团队在数据库中生成仿真度高且合规的测试数据。我们在年初还撰文介绍过它「告别手搓!Postgres 一…

Continual Test-Time Domain Adaptation--论文笔记

论文笔记 资料 1.代码地址 https://github.com/qinenergy/cotta 2.论文地址 https://arxiv.org/abs/2203.13591 3.数据集地址 论文摘要的翻译 TTA的目的是在不使用任何源数据的情况下,将源预先训练的模型适应到目标域。现有的工作主要考虑目标域是静态的情况…

Vue项目打包上线

Nginx 是一个高性能的开源HTTP和反向代理服务器,也是一个IMAP/POP3/SMTP代理服务器。它在设计上旨在处理高并发的请求,是一个轻量级、高效能的Web服务器和反向代理服务器,广泛用于提供静态资源、负载均衡、反向代理等功能。 1、下载nginx 2、…

探讨命令模式及其应用

目录 命令模式命令模式结构命令模式适用场景命令模式优缺点练手题目题目描述输入描述输出描述题解 命令模式 命令模式是一种行为设计模式, 它可将请求转换为一个包含与请求相关的所有信息的独立对象。 该转换让你能根据不同的请求将方法参数化、 延迟请求执行或将其…

【高级篇】第9章 Elasticsearch 监控与故障排查

9.1 引言 在现代数据驱动的应用架构中,Elasticsearch不仅是海量数据索引和搜索的核心,其稳定性和性能直接影响到整个业务链路的健康度。因此,建立有效的监控体系和掌握故障排查技能是每一位Elasticsearch高级专家的必备能力。 9.2 监控工具:洞察与优化的利器 在Elastics…

Rough.js在Vue3中生成随机蒙德里安风格的抽象艺术

本文由ScriptEcho平台提供技术支持 项目地址:传送门 Mondrian风格艺术生成器:用Vue和RoughJS创造抽象艺术 应用场景 Mondrian风格艺术以其大胆的色彩块和简单的几何形状而闻名。这种风格可以应用于各种设计项目,包括海报、插图和网页设计…

基于Web技术的教育辅助系统设计与实现(SpringBoot MySQL)+文档

💗博主介绍💗:✌在职Java研发工程师、专注于程序设计、源码分享、技术交流、专注于Java技术领域和毕业设计✌ 温馨提示:文末有 CSDN 平台官方提供的老师 Wechat / QQ 名片 :) Java精品实战案例《700套》 2025最新毕业设计选题推荐…

Markdown+VSCODE实现最完美流畅写作体验

​下载VSCODE软件 安装插件 Markdown All in One :支持markdown的语言的; Markdown Preview Enhanced :观看写出来文档的效果; Paste IMage :添加图片的 Code Spell Checker检查英文单词错误; 基础语法 标题 #一个…

AI 会淘汰程序员吗?

前言 前些日子看过一篇文章,说国外一位拥有 19 年编码经验、会 100% 手写代码的程序员被企业解雇了,因为他的竞争对手,一位仅有 4 年经验、却善于使用 Copilot、GPT-4 的后辈,生产力比他更高,成本比他更低&#xff0c…

西南交通大学【算法分析与设计实验3】

实验3.3 任务分配问题 实验目的 (1)理解穷举法典型算法的求解过程。 (2)学习穷举法的时间复杂度分析方法,并通过实验验证算法的执行效率。 (3)学会如何利用穷举法求解具体问题,了…

按是否手工执行测试的角度划分:手工测试、自动化测试

1.手工测试(Manual testing) 手工测试是由人一个一个的输入用例,然后观察结果,和机器测试相对应,属于比较原始但是必须的一个步骤。 由专门的测试人员从用户视角来验证软件是否满足设计要求的行为。 更适用针对深度…

哈希表 | 哈希查找 | 哈希函数 | 数据结构 | 大话数据结构 | Java

🙋大家好!我是毛毛张! 🌈个人首页: 神马都会亿点点的毛毛张 📌毛毛张今天分享的内容🖆是数据结构中的哈希表,毛毛张主要是依据《大话数据结构📖》的内容来进行整理,不…

程序化交易广告及其应用

什么是程序化交易广告? 程序化交易广告是以实时竞价技术即RTB(real-time bidding)为核心的广告交易方式。说到这里,你可能会有疑问:像百度搜索关键词广告还有百度网盟的广告,不也是CPC实时竞价的吗&#x…