《昇思25天学习打卡营第7天|函数式自动微分》

文章目录

  • 今日所学:
  • 一、函数与计算图
  • 二、微分函数与梯度计算
  • 三、Stop Gradient
  • 四、Auxiliary data
  • 五、神经网络梯度计算
  • 总结


今日所学:

今天我学习了神经网络训练的核心原理,主要是反向传播算法。这个过程包括将模型预测值(logits)和正确标签(label)输入到损失函数(loss function)中计算loss,然后通过反向传播算法计算梯度(gradients),最终更新模型参数(parameters)。自动微分技术能够在某点计算可导函数的导数值,是反向传播算法的一个广义实现。它的主要作用是将复杂的数学运算分解为一系列简单的基本运算,从而屏蔽了大量求导的细节和过程,显著降低了使用深度学习框架的门槛。

MindSpore采用函数式自动微分的设计理念,提供了更接近数学语义的自动微分接口,例如grad和value_and_grad。为了更好地理解这些概念,我还学习了如何使用一个简单的单层线性变换模型进行实践。


一、函数与计算图

MindSpore之前的还不熟悉的相关内容可以见:《昇思25天学习打卡营第1天|基本介绍》

计算图是一种借助图论来描绘数学函数的一种方法,同时也是深度学习框架用以表达神经网络模型的通用方式。以下,我们将以此计算图为基础,来构建计算函数和神经网络:
在这里插入图片描述
在本节所学的这个模型中,𝑥为输入,𝑦为正确值,𝑤和𝑏是我们需要优化的参数,根据计算图描述的计算过程,构造计算函数,执行计算函数,可以获得计算的loss值,代码与结果如下所示:

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # biasdef function(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return lossloss = function(x, y, w, b)
print(loss)

结果如下:

Tensor(shape=[], dtype=Float32, value= 0.914285)

二、微分函数与梯度计算

在之后学习内容中为了优化模型参数,需要求参数对loss的导数:

∂loss∂𝑤

∂loss∂𝑏

此时我们调用mindspore.grad函数,来获得function的微分函数。其中grad函数的两个入参,分别为fn(待求导的函数)与grad_position(指定求导输入位置的索引),代码如下:

grad_fn = mindspore.grad(function, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

结果如下:

在这里插入图片描述

使用grad获得微分函数是一种函数变换,即输入为函数,输出也为函数。

三、Stop Gradient

在常规的情况下,求导操作主要是计算loss相对于参数的导数,由此,函数的输出仅有loss一项。然而,当我们期望函数有多项输出时,微分函数将会计算所有输出项相对于参数的导数。在这种情况下,如果我们希望实现特定输出项的梯度截断,或者需要消除某个Tensor对梯度的影响,那么我们将需要使用Stop Gradient操作。在这里,我们会将function改造成同时输出loss和z的function_with_logits,并获取微分函数以供执行。

如果想要屏蔽掉z对梯度的影响,即仍只求参数对loss的导数,可以使用ops.stop_gradient接口,将梯度在此处截断。

代码如下:

def function_with_logits(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss, zgrad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)def function_stop_gradient(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss, ops.stop_gradient(z)grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

截断前结果:
在这里插入图片描述

截断后结果:

在这里插入图片描述

四、Auxiliary data

我深入理解了Auxiliary data(辅助数据)的概念和应用。我明白了辅助数据其实就是函数的非主要输出项。在实际应用中,我们常将函数的主要输出设为loss,而其它的所有输出则被视为辅助数据。对于grad和value_and_grad函数,我享受到了has_aux参数带来的便利。当将其设为True,它就能自动实现之前需要手动添加的stop_gradient操作。这种设计巧妙地使我在返回辅助数据的同时,不受梯度计算的任何影响。

在后续的实践中,我继续使用了function_with_logits,并设置了has_aux为True进行操作。整个过程顺畅无比,加深了我对这一主题的理解。我会持续探索,并将这些知识应用到更广泛的场景中去:

grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)grads, (z,) = grad_fn(x, y, w, b)
print(grads, z)

结果如下:

在这里插入图片描述

五、神经网络梯度计算

前面章节已经讲述了网络构建,还不了解的可见这篇文章:《昇思25天学习打卡营第6天|网络构建》

接下来,我深入了解了如何通过Cell去构造神经网络,以及利用函数式自动微分来实现反向传播的过程。我首先继承了nn.Cell来构建单层线性变换神经网络。有意思的是,这个过程中我直接使用了之前的 𝑤 和 𝑏 来作为模型参数。这种做法完全打破了我早前的理解,让我认识到原来我们可以直接使用现有的参数以节约时间和计算资源。我将这些参数用mindspore.Parameter包装起来作为内部属性,并在construct内实现了与之前一样的Tensor操作:

# Define model
class Network(nn.Cell):def __init__(self):super().__init__()self.w = wself.b = bdef construct(self, x):z = ops.matmul(x, self.w) + self.breturn z# Instantiate model
model = Network()
# Instantiate loss function
loss_fn = nn.BCEWithLogitsLoss()# Define forward function
def forward_fn(x, y):z = model(x)loss = loss_fn(z, y)return lossgrad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())loss, grads = grad_fn(x, y)
print(grads)

结果如下:
在这里插入图片描述

可以看出,执行微分函数后的梯度值和前文function求得的梯度值一致。

在这里插入图片描述

总结

在今天的学习中,我深入理解了神经网络训练的核心原理,包括反向传播算法和如何利用自动微分技术来计算梯度并更新模型参数。我也学习了如何使用MindSpore框架的函数式自动微分接口来进行实践,并利用计算图进行模型参数优化。此外,我理解了Stop Gradient操作和辅助数据对梯度计算的影响,以及如何在神经网络的梯度计算中有效利用它们。通过理论学习和实践操作,我对这些概念有了更深入的理解,期待在明天的学习中继续进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/367856.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【设计模式】行为型-状态模式

在变幻的时光中,状态如诗篇般细腻流转。 文章目录 一、可调节的灯光二、状态模式三、状态模式的核心组件四、运用状态模式五、状态模式的应用场景六、小结推荐阅读 一、可调节的灯光 场景假设:我们有一个电灯,它可以被打开和关闭。用户可以…

大模型与机器人精彩碰撞-7月5日晚上八点不见不散!

在瞬息万变的科技时代,新兴人工智能和机器人技术的结合正在引领新一轮的创新浪潮。你是否想成为未来科技的领航者?你是否想了解最前沿的AI与机器人技术?行麦科技重磅推出的“AIGC时代的生存法则”AI系列课,将为你揭开大模型与机器…

redis学习(001 介绍)

黑马程序员Redis入门到实战教程,深度透析redis底层原理redis分布式锁企业解决方案黑马点评实战项目 总时长 42:48:00 共175P 此文章包含第1p-第p4的内容 文章目录 介绍差异对比事务区别 认识redis 介绍 两种键值对方式对比 差异对比 事务区别 认识redis

Paimon 在汽车之家的业务实践

汽车之家基于Paimon的实践 摘要:本文分享自汽车之家的王刚、范文、李乾⽼师。介绍了汽车之家基于 Paimon 的一些实践,和一些背景。内容主要为以下四部分: 一、背景 二、业务实践 三、paimon 优化实践 四、未来规划 一、背景 在使用Paimon之前…

Java---Mybatis详解二

雄鹰展翅凌空飞, 大江奔流不回头。 壮志未酬心未老, 豪情万丈任遨游。 巍巍高山攀顶峰, 滔滔黄河入海流。 风云变幻凭君舞, 踏遍天涯尽逍遥。 目录 一,环境准备 二,删除 三,删除(预编译SQL) 为什…

无法定位程序输入点Z9 qt assertPKcS0i于动态链接库F:\code\projects\06_algorithm\main.exe

解决方法: 这个报错,是因为程序在运行时没要找到所需的dll库,如果把这个程序方法中对应库的目录下执行,则可正常执行。即使将图中mingw_64\bin 环境变量上移到msvc2022_64\bin 之前也不可以。 最终的解决方法是在makefile中设置环…

RealMAN:大规模真实录制且经过注释的麦克风阵列数据集

在深度学习驱动的多通道语音增强和声源定位系统的开发中,由于缺乏大规模的真实录制数据集,这些系统的训练在很大程度上依赖于房间脉冲响应(RIR)和多通道扩散噪声的模拟。然而,模拟数据和真实世界数据之间存在的声学失配…

QStringListModel 绑定到QListView

1.QStringListModel 绑定到listView,从而实现MV模型视图 2.通过QStringListModel的新增、删除、插入、上下移动,listView来展示出来 3.下移动一行,传入curRow2 的个人理解 布局 .h声明 private:QStringList m_strList;QStringListModel *m_m…

[译]Reactjs性能篇

英文有限,技术一般,海涵海涵,由于不是翻译出身,所以存在大量的瞎胡乱翻译的情况,信不过我的,请看原文~~ 原文地址:https://facebook.github.io/react/docs/advanced-per…

Servlet_Web小结

1.web开发概述 什么是服务器? 解释一:服务器就是一款软件,可以向其发送请求,服务器会做出一个响应. 可以在服务器中部署文件,让他人访问 解释二:也可以把运行服务器软件的计算机也可以称为服务器。 web开发: 指的是从网页中向后…

Android LayoutInflater 深度解析

在 Android 开发中,LayoutInflater 是一个非常重要的工具。它允许我们从 XML 布局文件中动态地创建 View 对象,从而使得 UI 的创建和管理更加灵活。本文将深入解析 android.view.LayoutInflater,包括它的基本用法、常见问题以及高级用法。 什…

idea xml ctrl+/ 注释格式不对齐

处理前 处理后 解决办法 取消这两个勾选

【UE5.3】笔记6-创建可自由控制Pawn类

搭建场景 搭建一个场景:包含地板、围墙。可以根据喜好加一些自发光的效果。 增加食物 创建食物蓝图类,在场景里放置一些食物以供我们player去吃掉获取分值。 创建可控制的layer 我们先右键创建一个蓝图继承自pawn类,起名BP_Player&#xf…

深度学习之半监督学习:一文梳理目标检测中的半监督学习策略

什么是半监督目标检测? 传统机器学习根据训练数据集中的标注情况,有着不同的场景,主要包括:监督学习、弱监督学习、弱半监督学习、半监督学习。由于目标检测任务的特殊性,在介绍半监督目标检测方法之前,我…

视频融合共享平台LntonCVS统一视频接入平台智慧安防应用方案

安防视频监控平台LntonCVS是一款拥有强大拓展性和灵活部署能力的综合管理平台。它支持多种主流标准协议,包括国标GB28181、RTSP/Onvif、RTMP等,同时兼容各厂家的私有协议和SDK,如海康Ehome、海大宇等。LntonCVS不仅具备传统安防视频监控功能&…

PHP电商系统开发指南最佳实践

电子商务系统开发的最佳实践包括:数据库设计:选择适合关系型数据库,优化数据结构,考虑表分区;安全:加密数据,防止 sql 注入,处理会话管理;用户界面:遵循 ux 原…

mysql-sql-第十四周

学习目标: sql 学习内容: 40.查询学过「哈哈」老师授课的同学的信息 Select * from students left join score on students.stunmscore.stunm where counm (select counm from teacher left join course on teacher.teanmcourse.teanm where teache…

【深度学习】Transformer

李宏毅深度学习笔记 https://blog.csdn.net/Tink1995/article/details/105080033 https://blog.csdn.net/leonardotu/article/details/135726696 https://blog.csdn.net/u012856866/article/details/129790077 Transformer 是一个基于自注意力的序列到序列模型,与基…

伺服调试三环讲解

在伺服调试过程中,有些项目要求不高,采用伺服自整定就可以调试好伺服,但有些项目对伺服有着比较高的要求,于是需要采取手动调试伺服参数,下面就介绍一下伺服三环参数的调试的方法。 三环指:电流环、速度环、位置环 带宽关系:电流环带宽>速度环带宽>位置环带宽 三环控…

C语言单链表的算法之插入节点

一:访问各个节点中的数据 (1)访问链表中的各个节点的有效数据,这个访问必须注意不能使用p、p1、p2,而只能使用phead (2)只能用头指针不能用各个节点自己的指针。因为在实际当中我们保存链表的时…