神经网络反向传播算法

今天我们来看一下神经网络中的反向传播算法,之前介绍了梯度下降与正向传播~       神经网络的反向传播

专栏:💎实战PyTorch💎

反向传播算法(Back Propagation,简称BP)是一种用于训练神经网络的算法。 

反向传播算法是神经网络中非常重要的一个概念,它由Rumelhart、Hinton和Williams于1986年提出。这种算法基于梯度下降法来优化误差函数,利用了神经网络的层次结构来有效地计算梯度,从而更新网络中的权重和偏置。

基本工作流程:

  1. 通过正向传播得到误差,所谓正向传播指的是数据从输入到输出层,经过层层计算得到预测值,并利用损失函数得到预测值和真实值之前的误差。
  2. 通过反向传播把误差传递给模型的参数,从而对网络参数进行适当的调整,缩小预测值和真实值之间的误差。
  3. 反向传播算法是利用链式法则进行梯度求解,然后进行参数更新。对于复杂的复合函数,我们将其拆分为一系列的加减乘除或指数,对数,三角函数等初等函数,通过链式法则完成复合函数的求导。

我们通过一个例子来简单理解下 BP 算法进行网络参数更新的过程🧧:

如图我们在最下边输入两个维度的值进入神经网络:0.05、0.1 ,经过两个隐藏层(每层两个神经元),每个神经元有两个值,左边为输入值,右边是经过激活函数后的输出值;经过这个神经网络后的输出值为:m1、m2,实际值为0.01、0.99 🌠

设置的初始权重w1,w2,...w8分别为0.15、0.20、0.25、0.30、0.30、0.35、0.55、0.60

我们通过计算得到损失函数Error = 1/2 ((m1- target1)2 + (m2 - target2)2) = 0.2988

w5和w7均可以通过求三次导来求梯度,而w1,w3则不能直接通过L降序求导,我们需要求从L到m1,m1到o1,o1到k1,k1到h1,h1到w1:

由于w1是输出两个方向分别到o1和o2,所以是两个方向的梯度求和~

我们也发现所以激活函数都是要可微的~

其他的网络参数更新过程和上面的求导过程是一样的,这里就不过多赘述,我们直接看一下代码。

反向传播代码 

我们先来回顾一些Python中类的一些小细节:

🌈在Python中,使用super()函数可以调用父类的方法。这在子类中重写父类方法时非常有用,因为它允许你调用父类的实现,而不是完全覆盖它

class Parent:def __init__(self):print("Parent init")class Child(Parent):def __init__(self):super().__init__()print("Child init")c = Child()# 输出
Parent init
Child init

🌈当我们创建一个Child类的实例时,它会首先调用Parent类的__init__方法(通过super().__init__()),然后执行Child类的__init__方法,与类的__init__方法(构造方法)对应的类关闭时自动调用的方法是__del__方法。对象不再被使用时,Python解释器会自动调用这个方法。通常在这个方法中进行一些清理工作,比如释放资源、关闭文件等。

反向传播实现

import torch
import torch.nn as nn
import torch.optim as optimclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.linear1 = nn.Linear(2, 2)self.linear2 = nn.Linear(2, 2)# 网络参数初始化w1/w2/w3/w4self.linear1.weight.data = torch.tensor([[0.15, 0.20], [0.25, 0.30]])# w5/w6/w7/w8self.linear2.weight.data = torch.tensor([[0.40, 0.45], [0.50, 0.55]])# 截距bself.linear1.bias.data = torch.tensor([0.35, 0.35])self.linear2.bias.data = torch.tensor([0.60, 0.60])# 定义前向传播的行径def forward(self, x):x = self.linear1(x)x = torch.sigmoid(x)x = self.linear2(x)x = torch.sigmoid(x)return xif __name__ == '__main__':inputs = torch.tensor([[0.05, 0.10]])target = torch.tensor([[0.01, 0.99]])# 获得网络输出值net = Net()output = net(inputs)# print(output)  # tensor([[0.7514, 0.7729]], grad_fn=<SigmoidBackward>)# 计算误差loss = torch.sum((output - target) ** 2) / 2# print(loss)  # tensor(0.2984, grad_fn=<DivBackward0>)# 优化方法optimizer = optim.SGD(net.parameters(), lr=0.5)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 打印 w5、w7、w1 的梯度值print(net.linear1.weight.grad.data)# tensor([[0.0004, 0.0009],#         [0.0005, 0.0010]])print(net.linear2.weight.grad.data)# tensor([[ 0.0822,  0.0827],#         [-0.0226, -0.0227]])# 打印网络参数optimizer.step()print(net.state_dict())# OrderedDict([('linear1.weight', tensor([[0.1498, 0.1996], [0.2498, 0.2995]])),#              ('linear1.bias', tensor([0.3456, 0.3450])),#              ('linear2.weight', tensor([[0.3589, 0.4087], [0.5113, 0.5614]])),#              ('linear2.bias', tensor([0.5308, 0.6190]))])
  • optimizer.step() 相当于是将w和b所有参数更新一步的过程

🌈关于nn.Linear的使用

import torch
import torch.nn.functional as F
import torch.nn as nn# 均匀分布随机初始化linear = nn.Linear(5, 3)
# 从0-1均匀分布产生参数
nn.init.uniform_(linear.weight)
print(linear.weight.data)

nn.Linear是PyTorch中用于创建线性层的类,也被称为全连接层。它的主要作用是将输入数据与权重矩阵相乘并加上偏置,然后通常会通过一个非线性激活函数进行转换。 

  1. 在函数内部,创建一个线性层,输入维度为5,输出维度为3;
  2. 使用nn.init.uniform_()函数对线性层的权重进行均匀分布随机初始化;
  3. 打印线性层的权重数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/317086.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JVM支持的可配置参数查看和分类

JVM参数大致可以分为三类: 标注指令:-开头。 这些是所有的HotSpot都支持的参数。可以用java-help 打印出来。 非标准指令: -X开头。 这些指令通常是跟特定的HotSpot版本对应的。可以用java -X打印出来。 不稳定参数: -XX 开头。 这一类参数是跟特定HotSpot版本对应的&#x…

[Java、Android面试]_24_Compose为什么绘制要比XML快?(高频问答)

欢迎查看合集&#xff1a; Java、Android面试高频系列文章合集 本人今年参加了很多面试&#xff0c;也有幸拿到了一些大厂的offer&#xff0c;整理了众多面试资料&#xff0c;后续还会分享众多面试资料。 整理成了面试系列&#xff0c;由于时间有限&#xff0c;每天整理一点&am…

常见公式的几何解释

本文旨在深入探讨常见数学公式的几何意义&#xff0c;通过直观的图形和解释&#xff0c;帮助读者更好地理解并掌握这些公式的本质。文章首先概述了公式与几何图形之间的紧密联系&#xff0c;然后选取了几个典型的数学公式&#xff0c;进行详细解析。每个公式都将配以相应的几何…

Linux操作系统·进程管理

一、什么是进程 1.作业和进程的概念 Linux是一个多用户多任务的操作系统。多用户是指多个用户可以在同一时间使用计算机系统&#xff1b;多任务是指Linux可以同时执行几个任务&#xff0c;它可以在还未执行完一个任务时又执行另一项任务。为了完成这些任务&#xff0c;系统上…

数据库基础--MySQL简介以及基础MySQL操作

数据库概述 数据库&#xff08;DATABASE&#xff0c;简称DB&#xff09; 定义:是按照数据结构来组织、存储和管理数据的仓库.保存有组织的数据的容器(通常是一个文件或一组文件) 数据库管理系统(Database Management System,简称DBMS) 专门用于管理数据库的计算机系统软件;…

【补充】图神经网络前传——图论

本文作为对图神经网络的补充。主要内容是看书。 仅包含Introduction to Graph Theory前五章以及其他相关书籍的相关内容&#xff08;如果后续在实践中发现前五章不够&#xff0c;会补上剩余内容&#xff09; 引入 什么是图&#xff1f; 如上图所示的路线图和电路图都可以使用…

【Spring Cloud】服务容错中间件Sentinel入门

文章目录 什么是 SentinelSentinel 具有以下特征&#xff1a;Sentinel分为两个部分: 安装 Sentinel 控制台下载jar包&#xff0c;解压到文件夹启动控制台访问了解控制台的使用原理 微服务集成 Sentinel添加依赖增加配置测试用例编写启动程序 实现接口限流总结 欢迎来到阿Q社区 …

【介绍下Unity编辑器扩展】

&#x1f308;个人主页: 程序员不想敲代码啊 &#x1f3c6;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f44d;点赞⭐评论⭐收藏 &#x1f91d;希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff0c;让我们共…

【docker】Spring Boot3.x 打包 Docker容器

Docker化Spring Boot应用 创建文件夹 demo mkdir democd demo创建Dockerfile # 两个 openjdk 二选一 #FROM openjdk:17-jre-alpineFROM eclipse-temurin:17MAINTAINER chengxuyuanshitang <chengxuyuanshitangXX.com>RUN mkdir -p /workspace/java/demoCOPY demo.ja…

Android 11 裁剪系统显示区域(适配异形屏)

概述 在显示技术中&#xff0c;"OverScan"&#xff08;超扫描&#xff09;是一种调整显示图像边界的技术。通常情况下&#xff0c;OverScan 会在显示屏的边缘周围裁剪一小部分图像。这种裁剪是为了确保显示内容在屏幕上的完整可见性&#xff0c;尤其是在老式电视或投…

C++入门基础(二)

目录 缺省参数缺省参数概念缺省参数分类全缺省参数半缺省参数声明与定义分离 缺省参数的应用 函数重载函数重载概念例子1 参数类型不同例子2 参数的个数不同例子3 参数的顺序不同 C支持函数重载的原理--名字修饰(name Mangling) 感谢各位大佬对我的支持,如果我的文章对你有用,欢…

Visual Studio导入libtorch(Cuda版)

Visual Studio导入libtorch&#xff08;Cuda版&#xff09; 一、安装 官网&#xff1a;https://pytorch.org/get-started/locally/ 相应地选择并下载 二、环境变量配置 解压zip&#xff0c;得到libtorch文件夹&#xff0c;将libtorch\lib和libtorch\bin对应路径添加到系统环…

使 Elasticsearch 和 Lucene 成为最佳向量数据库:速度提高 8 倍,效率提高 32 倍

作者&#xff1a;来自 Elastic Mayya Sharipova, Benjamin Trent, Jim Ferenczi Elasticsearch 和 Lucene 成绩单&#xff1a;值得注意的速度和效率投资 我们 Elastic 的使命是将 Apache Lucene 打造成最佳的向量数据库&#xff0c;并继续提升 Elasticsearch 作为搜索和 RAG&a…

【JVM】简述类加载器及双亲委派机制

双亲委派模型&#xff0c;是加载class文件的一种机制。在介绍双亲委派模型之前&#xff0c;我需要先介绍几种类加载器&#xff08;Class Loader&#xff09;。 1&#xff0c;类加载器 Bootstrap&#xff0c;加载lib/rt.jar&#xff0c;charset.jar等中的核心类&#xff0c;由…

JWT是什么?如何使用?

JWT是什么&#xff1f;如何使用&#xff1f; 前言什么是JWT&#xff1f;概念工作方式JWT的组成HeaderPayloadSignatrue 实战引入依赖自定义注解定义实体类定义一个JWT工具类业务校验并生成token定义拦截器配置拦截器定义接口方法并添加注解开始验证 使用场景注意事项 JWT与传统…

ASR语音转录Prompt优化

ASR语音转录Prompt优化 一、前言 在ASR转录的时候&#xff0c;我们能很明显的感受到有时候语音识别不是很准确&#xff0c;这过程中常见的文本错误主要可以归纳为以下几类&#xff1a; 同音错误&#xff08;Homophone Errors&#xff09; 同音错误发生在不同词语发音相似或相…

用Excel做一个功能完备的仓库管理系统

1 基本设计思路 用到的Excel技术&#xff1a;sumif, vlookup, 表格(table)。基本思路&#xff1a;在有基础的商品、仓库等信息的情况下&#xff0c;对商品的每一个操作都有对应的单据&#xff0c;然后再汇总统计。标识&#xff1a;为了在不同的维度统计数量&#xff0c;各单据…

谷粒商城实战(020 RabbitMQ-消息确认)

Java项目《谷粒商城》架构师级Java项目实战&#xff0c;对标阿里P6-P7&#xff0c;全网最强 总时长 104:45:00 共408P 此文章包含第258p-第p261的内容 消息确认 生产者 publishers 消费者 consumers 设置配置类 调用api 控制台 抵达brocker 代理 新版本ReturnCallbac…

matlab学习005-利用matlab设计滤波器

目录 一&#xff0c;含有多个频率成分的三角信号 1&#xff0c;以采样频率fs20KHz对信号采样&#xff0c; 画出信号的波形&#xff1b; 1&#xff09;前期基础 2&#xff09;波形图 3&#xff09;代码 2&#xff0c;选取合适的采样点数&#xff0c;利用DFT分析信号的…

Baidu Comate:“AI +”让软件研发更高效更安全

4月27日&#xff0c;百度副总裁陈洋出席由全国工商联主办的第64届德胜门大讲堂&#xff0c;并发表了《深化大模型技术创新与应用落地&#xff0c;护航大模型产业平稳健康发展》主题演讲。陈洋表示&#xff0c;“人工智能”成为催生新质生产力的重要引擎&#xff0c;对于企业而言…