自动微分autograd实践要点

目录

  • 定义Value
    • 手动定义每个 operator 的 `_backward()` 函数
    • 构建反向传播计算链

本文主要参考 反向传播和神经网络训练 · 大神Andrej Karpathy 的“神经网络从Zero到Hero 系列”之一,提炼一些精要,将反向传播的细节和要点展现出来

定义Value

第一步首先要定义Value,算子中需包含:data(Value 的数值), grad(Value 的梯度),_backward(反向传播函数,初始化为 None),_prev(需要依赖于它的Value,用于后面构建反向传播链):

class Value:def __init__(self, data, _children=(), _op='', label=''):self.data = dataself.grad = 0.0self._backward = lambda: Noneself._prev = set(_children)self._op = _opself.label = labeldef __repr__(self):return f"Value(data={self.data})"def __add__(self, other):out = Value(self.data + other.data, (self, other), '+')def _backward():self.grad += 1.0 * out.gradother.grad += 1.0 * out.gradout._backward = _backwardreturn outdef __mul__(self, other):out = Value(self.data * other.data, (self, other), '*')def _backward():self.grad += other.data * out.gradother.grad += self.data * out.gradout._backward = _backwardreturn outdef tanh(self):x = self.datat = (math.exp(2*x) - 1)/(math.exp(2*x) + 1)out = Value(t, (self, ), 'tanh')def _backward():self.grad += (1 - t**2) * out.gradout._backward = _backwardreturn outdef backward(self):topo = []visited = set()def build_topo(v):if v not in visited:visited.add(v)for child in v._prev:build_topo(child)topo.append(v)build_topo(self)self.grad = 1.0for node in reversed(topo):node._backward()

上述代码的核心在于:

  1. 每个算子的_backward 函数需要依次按算子进行手动定义
  2. 一个Valuebackward函数,是从当前Value开始,先将依赖于这个Value的所有Value按依赖顺序串起来,然后再从当前Value开始依次运行_backward()。这样只需要对一条链上的最后一个Value运行backward函数,就可以将这个链上的所有节点的grad全更新一次,即完成一次反向传播

例如:

a = Value(2.0, label='a')
b = Value(-3.0, label='b')
c = Value(10.0, label='c')
e = a*b; e.label = 'e'
d = e + c; d.label = 'd'
f = Value(-2.0, label='f')
L = d * f; L.label = 'L'

上述计算链条的图示为:
在这里插入图片描述


手动定义每个 operator 的 _backward() 函数

加法
不同算子的梯度值不一样,例如:对于加法:out = Value(self.data + other.data, (self, other), '+')outself. Data 求导,倒数值为1,因此其_backward() 函数定义为:

def _backward():self.grad += 1.0 * out.gradother.grad += 1.0 * out.grad

这里是 += 而不是 = 的原因是,有的时候某个node的Value,在前向传播时可能影响了不止一个Value,例如:

在这里插入图片描述
那么这里 a.grad 即要计算从 d 处来的反向传播,也要考虑从 e 处来的反向传播,因此是 +=

乘法
同理,对于乘法 out = Value(self.data * other.data, (self, other), '*')outself. Data 求导,倒数值为 other. Data,因此:

def _backward():self.grad += other.data * out.gradother.grad += self.data * out.grad

激活函数
tanh 激活函数为:
t a n h ( x ) = e x − e − x e x + e − x = e 2 x − 1 e 2 x + 1 tanh(x) = \frac{e^x -e^{-x}}{e^x + e^{-x}} = \frac{e^{2x} -1}{e^{2x} + 1} tanh(x)=ex+exexex=e2x+1e2x1
其倒数为 t a n h ′ ( x ) = 1 − ( e 2 x − 1 e 2 x + 1 ) 2 = 1 − ( t a n h ( x ) ) 2 tanh'(x) = 1 - (\frac{e^{2x} -1}{e^{2x} + 1})^2 = 1 - (tanh(x))^2 tanh(x)=1(e2x+1e2x1)2=1(tanh(x))2
因此对应的公式为

def tanh(self):x = self.datat = (math.exp(2*x) - 1)/(math.exp(2*x) + 1)out = Value(t, (self, ), 'tanh')def _backward():self.grad += (1 - t**2) * out.grad

构建反向传播计算链

以上图为例:

在这里插入图片描述
反向传播就是从 L 开始从右往左依次调用各个Node的 _backward(),因此链条构建的方式类似于树的遍历,从根节点开始往逐渐添加叶节点:

def backward(self):topo = []visited = set()def build_topo(v):if v not in visited:visited.add(v)for child in v._prev:build_topo(child)topo.append(v)build_topo(self)self.grad = 1.0for node in reversed(topo):node._backward()

这里根节点对自身进行求导,倒数值都为1,所以需要设置self.grad = 1.0。最后只需运行一次 L.backward() 就可以把所有Node的梯度全更新一遍,以下是运行一次L.backward()后的结果:

在这里插入图片描述


Reference:

  1. 反向传播和神经网络训练 · 大神Andrej Karpathy 的“神经网络从Zero到Hero 系列”之一

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/405691.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

传知代码-自动化细胞核分割与特征分析(论文复现)

代码以及视频讲解 本文所涉及所有资源均在传知代码平台可获取 引言 细胞核分割和分类在医学研究和临床诊断中具有重要意义。精准的细胞核分割能够帮助医生更好地识别和分析细胞核的形态学特征,从而辅助疾病诊断、癌症检测以及药物研发。HoverNet是一种基于深度学…

【GitLab】使用 Docker engine安装 GitLab 2: gitlab-ce:17.3.0-ce.0 拉取

ce版本必须配置代理。 极狐版本可以直接pull 社区版GitLab不支持Alibaba Cloud Linux 3,本操作以Ubuntu/Debian系统为例进行说明,其他操作系统安装说明,请参见安装社区版GitLab。 docker 环境重启 sudo systemctl daemon-reload sudo systemctl restart docker脚本安装 安裝…

苹果手机微信聊天记录删除了怎么恢复?

在日常使用手机的过程中,我们经常会遇到误删微信聊天记录的情况,尤其是对于那些重要的对话记录,一旦丢失可能会带来不小的困扰。今天,我们就来探讨一下如何在苹果手机上恢复被删除的微信聊天记录。 一、利用第三方数据恢复工具 对…

拓客工具,助你多维度筛选客源!

随着大数据与人工智能技术的飞速发展,企业拓客的方式也迎来了前所未有的变革。在这里将分享如何利用拓客工具,在任意行业中精实现高效拓客。 一、高级搜索:最新企业! 传统的客户开发方式往往依赖于广撒网式的营销手段,…

Qt实现tcp协议

void Widget::readyRead_slot() {//读取服务器发来的数据QByteArray msg socket->readAll();QString str QString::fromLocal8Bit(msg);QStringList list str.split(:);if(list.at(0) userName){QString str2;for (int i 1; i < list.count(); i) {str2 list.at(i);…

作业8/21

client cpp #include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget), socket(new QTcpSocket(this)) {ui->setupUi(this); // 设置 UI 界面// 控件初始状态设置为禁用&#xff0c;防止未连接…

Linux timedatectl 命令

timedatectl 是 Linux 系统中用于查询和更改系统日期、时间和时区的工具&#xff0c;它特别适用于那些使用 systemd 作为系统和服务管理器的系统。语法格式为“timedatectl [参数]”。 发现电脑时间跟实际时间不符&#xff1f;如果只差几分钟的话&#xff0c;我们可以直接调整。…

tekton通过ceph挂载node_modules的时候报错failed to execute command: copying dir: symlink

分析&#xff1a; 如果ceph的mountPath和workingDir路径一致的话&#xff0c;就会报错。 解决&#xff1a;node_modules挂载到/workspace下&#xff0c;workingDir的代码mv到/workspace下进行构建。

MyBatis-Plus与PageHelper依赖的jsqlparser库冲突

问题 最近遇到的一个项目升级了SpringBoot到3.x版本了&#xff0c;同时也准备升级MyBatis-Plus&#xff0c;即使用如下依赖&#xff1a; <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-boot-starter</artifactId><…

初级python代码编程学习----简单的图形化闹钟小程序

我们来创建一个简单的图形化闹钟程序通常需要使用图形用户界面&#xff08;GUI&#xff09;库。以下是使用Python的Tkinter库创建一个基本闹钟程序的步骤&#xff1a; 环境准备 确保已安装Python。安装Tkinter库&#xff08;Python 3.8及以上版本自带Tkinter&#xff0c;无需…

【代码】Swan-Transformer 代码详解(待完成)

1. 局部注意力 Window Attention (W-MSA Module) class WindowAttention(nn.Module):r""" Window based multi-head self attention (W-MSA) module with relative position bias.It supports both of shifted and non-shifted window.Args:dim (int): Number…

用户画像实时标签数据处理流程图

背景 在用户画像中&#xff0c;有一类实时标签&#xff0c;我们既要它能够实时的对外提供数据统计&#xff0c;也要保存到大数据组件中用于后续的对数&#xff0c;圈选的逻辑&#xff0c;本文就看一下用户画像的实时标签的数据流转图 实时标签数据流转图 首先我们肯定是要使…

独立站PrestaShop安装

独立站PrestaShop安装 独立站PrestaShop安装系统需求下载PrestaShop浏览器下载命令行下载 解压PrestaShop创建数据库移动PrestaShop源码到web目录composer安装依赖包nginx配置访问域名进入安装页面选择语言许可协议系统兼容性店铺信息Content of your store系统配置数据库店铺安…

书生大模型学习笔记9 - LMDeploy 量化部署

LMDeploy 量化部署 InternLM 2.5 20b量化前部署W4A16 模型量化量化模型部署streamlit web InternLM 2.5 20b量化前部署 lmdeploy serve api_server \/root/learning/InternLM/XTuner/merged_20b \--model-format hf \--quant-policy 0 \--cache-max-entry-count 0.01\--server…

数据结构与算法——图

1、为什么要有图 1&#xff09;前面我们学习了线性表和树 2&#xff09;线性表局限于一个直接前驱和一个直接后继的关系 3&#xff09;树也只能有一个直接前驱就是父节点 4&#xff09;当我们需要表示多对多的关系时&#xff0c;这里我们就用到了图 图是一种数据结构&#xf…

支持2.4G频秒变符合GB42590的标准的飞行器【无人机GB42590发射端】

使用方法: 放在飞机 上&#xff0c;按键那一面需要朝上对着天空(因为GPS陶瓷天线在按键面)&#xff0c;支持基本ID&#xff0c;向量和系统包&#xff0c;电池容量240mAH充电1小时&#xff0c;使用时间大概2小时。 1.长按3秒开关机 2.开机红灯慢闪&#xff0c;只发射基本ID数据…

JavaScript_7_练习:随机抽奖案例

效果图 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>练习&#xff1a;随机抽奖案例</tit…

【后续更新】python搜集上海二手房数据

源码如下: import asyncio import aiohttp from lxml import etree import logging import datetime import openpyxlwb = openpyxl.Workbook() sheet = wb.active sheet.append([房源, 房子信息, 所在区域, 单价, 关注人数和发布时间, 标签]) logging.basicConfig(level=log…

GD32双路CAN踩坑记录

GD32双路CAN踩坑记录 目录 GD32双路CAN踩坑记录1 问题描述2 原因分析3 解决办法4 CAN配置参考代码 1 问题描述 GD32的CAN1无法进入接收中断&#xff0c;收不到数据。 注&#xff1a;MCU使用的是GD32E50x&#xff0c;其他型号不确定是否一样&#xff0c;本文只以GD32E50x举例说…

CTF中的换表类Crypto题目

目录 [安洵杯 2019]JustBase[SWPUCTF 2021 新生赛]traditional字符替换解密 [BJDCTF 2020]base??字符替换 --》 base64解密 [安洵杯 2019]JustBase VGhlIGdlbxvZ#kgbYgdGhlIEVhcnRoJ#Mgc#VyZmFjZSBpcyBkb!pbmF)ZWQgYnkgdGhlIHBhcnRpY#VsYXIgcHJvcGVydGllcyBvZiB#YXRlci$gUHJ…