bp(back propagation)

文章目录

  • 定义
  • 过程
    • 前向传播计算过程
    • 计算损失函数(采用均方误差MSE)
    • 反向传播误差(链式法则)
    • 计算梯度
    • 更新参数
  • 简单实例

定义

反向传播全名是反向传播误差算法(Backpropagation),是一种监督学习方法,用于调整神经网络中权重参数,以最小化模型的预测误差。反向传播通过计算损失函数相对于网络参数的梯度,然后使用梯度下降等优化算法来更新参数,从而提高模型的性能

核心思想是从网络的输出层向输入层传播误差信号,然后根据这些误差信号来更新网络中的权重和偏差,以使模型的输出更接近实际目标

过程

前向传播:将输入数据输入到网络中,通过神经网络的前向传播计算出每个神经元的输出结果,直到输出层输出最终的结果。
计算损失函数:将神经网络输出的结果与真实标签进行比较,计算出损失函数的值。
反向传播误差:从输出层开始,计算每个神经元的误差,然后向前计算每个神经元的误差,直到计算出输入层每个神经元的误差。具体来说,我们首先计算输出层的误差,然后反向传播到前一层隐藏层,计算隐藏层的误差,并将误差反向传播到更早的层,直到计算出输入层的误差。这个过程可以使用链式法则来计算。
计算梯度:根据误差计算每个神经元的梯度,即损失函数对每个神经元权重和偏差的偏导数。
更新参数:使用梯度下降法或其他优化方法来更新每个神经元的权重和偏差,使得损失函数的值最小化。

设置初始值如图:
 目标:给出输入数据i1,i2(0.05和0.10),使输出尽可能与原始输出o1,o2(0.01和0.99)接近

神经网络中信息传递的基本原理:
在神经网络中,前一层神经元的输出通常作为下一层神经元的输入

前向传播计算过程

y=wx+b
例如上图中h1的输入input1 = w1i1+w2i2+b1
假设采用的sigmoid激活函数,则输出output1 = 1/(1+e^-input1)
而这个输出output1又会作为下一层神经元o1、o2的输入
由此我们可以轻松计算出最后的输出

计算损失函数(采用均方误差MSE)

将前向传播计算得到的输出与原始输出比较
MSE = Σ(预测值 - 实际观测值)^2 / 观测值的数量

反向传播误差(链式法则)

从输出层开始,计算每个神经元的误差,然后向前计算每个神经元的误差,直到计算出输入层每个神经元的误差

计算梯度

当前层的梯度 = 当前层输出的梯度 × 当前层输入的梯度
以b站某up主讲解中的一个片段为例:
其中绿色字体代表的是输入值,红色字体为反向传播的梯度

我们以最左上角w0,x0那为例,假设得到的是y0,(再次提醒,上面绿色字体为前向传播计算的值)
则x0的梯度为:y0对x0求导再乘以0.2(当前层输出的梯度),最后得到0.40(至于0.39可能是笔误)
同理w0的梯度:y0对w0求导再乘以0.2,即-1.0*0.2 = 0.20

更新参数

设置一个学习率,根据之前计算的梯度就可以进行参数权重更新了

简单实例

import numpy as np# 定义 Sigmoid 激活函数
def sigmoid(x):return 1 / (1 + np.exp(-x))# 定义 Sigmoid 激活函数的导数
def sigmoid_derivative(x):return x * (1 - x)# 创建神经网络类
class NeuralNetwork:def __init__(self, input_size, hidden_size, output_size):# 初始化权重self.weights_input_hidden = np.random.rand(input_size, hidden_size)self.weights_hidden_output = np.random.rand(hidden_size, output_size)def forward(self, X):# 前向传播self.hidden_input = np.dot(X, self.weights_input_hidden)self.hidden_output = sigmoid(self.hidden_input)self.output = np.dot(self.hidden_output, self.weights_hidden_output)return self.outputdef backward(self, X, y, output, learning_rate):# 反向传播error = y - outputdelta_output = errorerror_hidden = delta_output.dot(self.weights_hidden_output.T)delta_hidden = error_hidden * sigmoid_derivative(self.hidden_output)# 更新权重self.weights_hidden_output += self.hidden_output.T.dot(delta_output) * learning_rateself.weights_input_hidden += X.T.dot(delta_hidden) * learning_ratedef train(self, X, y, learning_rate, epochs):for i in range(epochs):output = self.forward(X)self.backward(X, y, output, learning_rate)if i % 1000 == 0:loss = np.mean(np.square(y - output))print(f"Epoch {i}, Loss: {loss}")# 示例数据
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([[0], [1], [1], [0]])# 创建神经网络
input_size = 2
hidden_size = 4
output_size = 1
learning_rate = 0.1
epochs = 10000nn = NeuralNetwork(input_size, hidden_size, output_size)# 训练神经网络
nn.train(X, y, learning_rate, epochs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/165596.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python二次开发Solidworks:修改实体尺寸

立方体原始尺寸:100mm100mm100mm 修改后尺寸:10mm100mm100mm import win32com.client as win32 import pythoncomdef bin_width(width):myDimension Part.Parameter("D1草图1")myDimension.SystemValue width def bin_length(length):myDime…

【吞噬星空】又被骂,罗峰杀人目无法纪,但官方留后手,增加审判戏份

Hello,小伙伴们,我是小郑继续为大家深度解析国漫吞噬星空资讯。 吞噬星空动画中,罗峰复仇的戏份,简直是帅翻了,尤其是秒杀阿特金三大巨头,让人看的也是相当的解气,相当的爽,一点都不拖沓&#x…

TCP为什么需要三次握手和四次挥手?

一、三次握手 三次握手(Three-way Handshake)其实就是指建立一个TCP连接时,需要客户端和服务器总共发送3个包 主要作用就是为了确认双方的接收能力和发送能力是否正常、指定自己的初始化序列号为后面的可靠性传送做准备 过程如下&#xff…

力扣每日一题51:N皇后问题

题目描述: 按照国际象棋的规则,皇后可以攻击与之处在同一行或同一列或同一斜线上的棋子。 n 皇后问题 研究的是如何将 n 个皇后放置在 nn 的棋盘上,并且使皇后彼此之间不能相互攻击。 给你一个整数 n ,返回所有不同的 n 皇后问…

bitbucket.org 用法

这个网站需要魔法,注册完成后添加厂库时间2023.10 图1 图2 第二张图 ,不要.gitignore文件 sourcetree 1,创建前端项目 npm create vitelatest 2.打开vscode创建本地Git 看到Git代提交的文件 sourcetree,新建 已存在的本地厂库 提交到Git 添…

linux基础IO

文章目录 前言一、基础IO1、文件预备知识1.1 文件类的系统调用接口1.2 复习c语言接口 2、文件类的系统调用接口2.1 open系统调用2.2 close系统调用2.3 write系统调用2.4 read系统调用 3、文件描述符3.1 文件描述符fd介绍3.2 文件描述符fd分配规则与重定向3.3 重定向原理3.4输入…

(八)vtk常用类的常用函数介绍(附带代码示例)

vtk中类的说明以及函数使用 https://vtk.org/doc/nightly/html/annotated.html 一、vtkObject派生类 1.vtkPoints 点 InsertNextPoint(double, double, double):插入点。 2.vtkCellArray 单元数组 InsertNextCell (vtkIdType npts, const vtkIdType *pts)&…

java与c++中的交换方法

最近在写算法的时候,遇到一个问题。 java中编写swap(交换)方法还需要传入一个数组,但是在c中则不需要。 可以看到,在没有传入数组进行交换数组元素的时候,交换前与交换后的值是一样的。 而在c中&#xff…

笔记:绘图进阶

主要功能: 双坐标轴多子图共用一个横坐标横坐标时间刻度设置(方便) # -*- coding: utf-8 -*- import numpy as np import pandas as pd import matplotlib.pyplot as plt import matplotlib.dates as mdatesif __name__ __main__:# 风速da…

开源软件-禅道Zentao

禅道Zentao 简介漏洞复现SQL注入漏洞**16.5****router.class.php SQL注入** **v18.0-v18.3****后台命令执行** 远程命令执行漏洞(RCE)后台命令执行 简介 是一款开源的项目管理软件,旨在帮助团队组织和管理他们的项目。Zentao提供了丰富的功能…

10. 机器学习-评测指标

Hi,你好。我是茶桁。 之前的课程中,我们学习了两个最重要的回归方法,一个线性回归,一个逻辑回归。也讲解了为什么学习机器学习要从逻辑回归和线性回归讲起。因为我们在解决问题的时候,有限选择简单的假设,越复杂的模型…

jvm 各个版本支持的参数

知道一些 jvm 调优参数,但是没有找到官网对应的文档,在网上的一些文章偶然发现,记录一下。 https://docs.oracle.com/en/java/javase/ 包含各个版本 jdk 8 分为 windows 和 unix 系统 https://docs.oracle.com/javase/8/docs/technotes/too…

【ChatGLM2-6B】nginx转发配置

背景 好不容易把ChatGLM2-6B大语言模型部署好了,使用streamlit方式启动起来了,终于可以愉快的玩耍了,然后想着申请一个域名,使用HTTPS协议访问,但实践过程中,发现这个大语言模型的nginx转发配置还是有点小…

STM32F4x之中断一

一、中断简介 中断概念:程序在运行过程中发生了外部或内部事件时,导致中断了正在执行的程序,让CPU转到外部或内部事件中去执行。 中断的作用:大量节约CPU资源,提高程序的效率,即避免重要事件被错过。 中断…

利用TypeScript 和 jsdom 库实现自动化抓取数据

以下是一个使用 TypeScript 和 jsdom 库的下载器程序,用于下载zhihu的内容。此程序使用了 duoip.cn/get_proxy 这段代码。 import { JSDOM } from jsdom; import { getProxy } from https://www.duoip.cn/get_proxy;const zhihuUrl https://www.zhihu.com;(async (…

NFT Insider112:The Sandbox聘请Apple高管担任其首席内容官,YGG 将在菲律宾举办Web3游戏峰会

引言:NFT Insider由NFT收藏组织WHALE Members、BeepCrypto联合出品,浓缩每周NFT新闻,为大家带来关于NFT最全面、最新鲜、最有价值的讯息。每期周报将从NFT市场数据,艺术新闻类,游戏新闻类,虚拟世界类&#…

Qt扫盲-QTextCodec理论总结

QTextCodec理论总结 一、概述二、编码支持三、使用四、创建自己的编解码器类 一、概述 QTextCodec 是Qt提供的一个管理字符串编码的功能,他可以在不同编码方式中来回转换,在文件读取的时候、格式编码转换的时候用处很大。Qt使用Unicode 编码来存储、绘制…

Aocoda-RC F405V2 FC(STM32F405RGT6 v.s. AT32F435RGT7) IO Definitions

[TOC](Aocoda-RC F405V2 FC(STM32F405RGT6 v.s. AT32F435RGT7) IO Definitions) 1. 源由 Aocoda-RC F405V2飞控支持betaflight/inav/Ardupilot固件,是一款固件兼容性非常不错的开源硬件。 之前我们对比过STM32F405RGT6 v.s. AT32F435RGT7 Comparison for Flight …

java中的容器(集合),HashMap底层原理,ArrayList、LinkedList、Vector区别,hashMap加载因子0.75原因

一、java中的容器 集合主要分为Collection和Map两大接口;Collection集合的子接口有List、Set;List集合的实现类有ArrayList底层是数组、LinkedList底层是双向非循环列表、Vector;Set集合的实现类有HashSet、TreeSet;Map集合的实现…

freeipa server副本同步中断,两主节点数据不一致

/var/log/messages 和/var/log/dirsrv/slapd-testhadoop-COM 日志都出现以下日志: If replication stops, the consumer may need to be reinitialized. [27/Jun/2023:05:15:09.469361922 0800] - ERR - NSMMReplicationPlugin - changelog program - repl_plugin_name_cl - a…