11-08 周三 图解机器学习之实现逻辑异或,理解输出层误差和隐藏层误差项和动量因子

11-08 周三 图解机器学习之实现逻辑异或,理解输出层误差和隐藏层误差项
时间版本修改人描述
2023年11月8日14:36:36V0.1宋全恒新建文档

简介

 最近笔者完成了《图解机器学习》这本书的阅读,由于最近深度学习网络大行其是,所以也想要好好的弄清楚神经网络的工作原理。比如说训练、比如说验证,比如说权重更新,之前也曾经写过两个博客来描述感知机和BP算法示意。

  • 10-09 周一 图解机器学习之深度学习感知机学习
  • 11-06 周一 神经网络之前向传播和反向传播代码实战

 反向传播这个博客里主要通过一个样本,来不断的更新参数,但实际的神经网络结构是不会像博客中name简单的,因此还是需要给出一个计算公式的。在阅读图解机器学习P169页,如下代码时,自己没有看懂:

        # 计算输出层误差for k in range(self.no):error = targets[k] - self.ao[k]output_deltas[k] = dsigmoid(self.ao[k]) * error# 计算隐藏层的误差hidden_deltas = [0.0]*self.nhfor j in range(self.nh):error = 0for k in range(self.no):error = error + output_deltas[k] * self.wo[j][k]hidden_deltas[j] = dsigmoid(self.ah[j]) * error# 更新输出层权重for j in range(self.nh):for k in range(self.no):change = output_deltas[k]*self.ah[j]self.wo[j][k] = self.wo[j][k] + N*change + M * self.co[j][k]

 上述在计算过程中求出了输出层误差和隐藏层误差项。如何理解这个代码片段呢?

完整代码

import math
import random
import string
random.seed(0)# 生成区间[a, b)内的随机数
def rand(a, b):return (b - a) *random.random() + a# 生成I*J大小的矩阵, 默认零矩阵
def makeMatrix(I, J, fill=0.0):m = []for i in range(I):m.append([fill]*J)return m# 函数 sigmoid, 采用tanh函数, 比起标准的1/(1+exp(-x))更好
def sigmoid(x):return math.tanh(x)# 函数sigmoid的派生函数 tanh(x)' = 1 - tanh(x)^2
def dsigmoid(x):return 1.0 - x**2class BPNeuralNet:'''建立三层反向传播神经网络'''def __init__(self, ni, nh, no) -> None:self.ni = ni + 1self.nh = nhself.no = no# 激活神经网络的所有节点self.ai = [1.0]* self.niself.ah = [1.0]*self.nhself.ao = [1.0]* self.no# 建立权重矩阵self.wi = makeMatrix(self.ni, self.nh)self.wo = makeMatrix(self.nh, self.no)# 设为随机值for i in range(self.ni):for j in range(self.nh):self.wi[i][j] = rand(-0.2, 0.2)for i in range(self.nh):for j in range(self.no):self.wo[i][j] = rand(-2.0, 2.0)# 建立动量因子self.ci = makeMatrix(self.ni, self.nh)self.co = makeMatrix(self.nh, self.no)# 前向传播,得到预计的输出。# 各个神经元的输出分别位于self.ah 和self.ao# inputs 代表一个样本def fp(self, inputs):if len(inputs) != self.ni -1:raise ValueError('与输入层节点数不符错误!')for i in range(self.ni-1):self.ai[i] = inputs[i]for j in range(self.nh):sum = 0.0for i in range(self.ni):sum += self.ai[i]* self.wi[i][j]self.ah[j] = sigmoid(sum)# 激活输出层for j in range(self.no):sum = 0for i in range(self.nh):sum += self.ah[i]*self.wo[i][j]self.ao[j] = sigmoid(sum)return self.ao[:]# N 学习速率 learning factor# M 动量因子 momentum factor# 基本思路是直接求出每个神经元的误差def back_propagate(self, targets, N, M):'''反向传播'''if len(targets) != self.no:raise ValueError("与输出层节点数不符!")output_deltas = [0.0] * self.no# 计算输出层误差for k in range(self.no):error = targets[k] - self.ao[k]output_deltas[k] = dsigmoid(self.ao[k]) * error# 计算隐藏层的误差hidden_deltas = [0.0]*self.nhfor j in range(self.nh):error = 0for k in range(self.no):error = error + output_deltas[k] * self.wo[j][k]hidden_deltas[j] = dsigmoid(self.ah[j]) * error# 更新输出层权重for j in range(self.nh):for k in range(self.no):change = output_deltas[k]*self.ah[j]self.wo[j][k] = self.wo[j][k] + N*change + M * self.co[j][k]self.co[j][k] = change# 更新输入层权重for i in range(self.ni):for j in range(self.nh):change=hidden_deltas[j]*self.ai[i]self.wi[i][j] += N * change + M * self.ci[i][j]self.ci[i][j] = changeerror = 0.0for k in range(len(targets)):error = error + 0.5*(targets[k]-self.ao[k])**2return errordef test(self, patterns):for p in patterns:print(p[0], '->', self.fp(p[0]))def weights(self):print('输入层权重')for i in range(self.ni):print(self.wi[i])print()print("输出层权重")for j in range(self.nh):print(self.wo[j])def train(self, patterns, iterations=100000, N =0.5, M=0.1):for i in range(iterations):error = 0.0for p in patterns:inputs = p[0]targets = p[1]self.fp(inputs)error = error + self.back_propagate(targets, N, M)if i % 100 ==0:print('计算误差的值是: %-.5f'%error)def trainprog():# BP神经网络学习逻辑异或pat = [[[0, 0], [0]],[[0, 1], [1]],[[1, 0], [1]],[[1, 1], [0]]]# 创建一个神经网络,输入层两个节点, 输出层两个节点,输出层一个节点:net = BPNeuralNet(2, 3, 1)net.train(pat)# 测试训练的成果net.test(pat)if __name__ == '__main__':trainprog()

 上述代码在理解上并不复杂,主要是通过三层神经网络来拟合逻辑异或运算。采用的是个案更新的策略来更新权重参数。

权重更新

基础知识

 多层前馈神经网络。

 一个示例: 奶酪是否喜爱。

 为此我们构建一个神经网络:

激活函数

  激活函数

  • 指数函数
  • sigmoid
  • 逻辑回归

 校正因子的概念如下:

 权重更新的策略有多种:

  • 个案更新 case-based 更容易得到准确的结果。
  • 批量更新 batch 优点就是比较快,加速。

迭代终止条件

 迭代终止条件:

  • 当权重和偏置差异与上一次非常小
  • 误差达到之前设置的阈值
  • 运行次数

存疑代码

        # 计算输出层误差for k in range(self.no):error = targets[k] - self.ao[k]output_deltas[k] = dsigmoid(self.ao[k]) * error# 计算隐藏层的误差hidden_deltas = [0.0]*self.nhfor j in range(self.nh):error = 0for k in range(self.no):error = error + output_deltas[k] * self.wo[j][k]hidden_deltas[j] = dsigmoid(self.ah[j]) * error# 更新输出层权重for j in range(self.nh):for k in range(self.no):change = output_deltas[k]*self.ah[j]self.wo[j][k] = self.wo[j][k] + N*change + M * self.co[j][k]self.co[j][k] = change# 更新输入层权重for i in range(self.ni):for j in range(self.nh):change=hidden_deltas[j]*self.ai[i]self.wi[i][j] += N * change + M * self.ci[i][j]self.ci[i][j] = change

 上述分别计算出了输出层的误差项和输出层的误差项。按照上述代码理解,前两个for循环用于计算误差项,后两个循环用来更新权重,顺序从后向前,这也是反向传播得名的由来。关键是为什么输出层的误差是这么得来的呢?

 参考 BP神经网络-第6集 反向传播误差,调整全部权重,这对于理解是非常关键的。

 我们以同样的方式,就可以得到每个神经元的误差。如下图

 可以采用矩阵相乘的方法

 权重通过矩阵乘表示。

gpt辅助理解

 自己还是无法理解,但感觉输出层的误差项与选用的损失函数密切相关,因此,笔者询问了GPT,得到了如下的结果:

  • 为什么要乘以激活函数的导数?
  • 交叉熵损失函数的输出层误差项
  • 均方差 输出层误差项:

 由此,我们可以得到如下的图示:

在计算h1节点的误差项时,输出层两个误差项以w7 和 w8进行作用,进而可以得到h1神经元的误差项:

errorh1=w7*e1 + w8 * e2。 依次可以得到h1, h2, h3神经元的误差项损失。

动量因子

 代码片段,其实整体贯彻了P166 图解机器学习图13.10,

 相互映照,也可以通过代码来理解上述的过程:

 代码参见 11-06 周一 神经网络之前向传播和反向传播代码实战

总结

 这部分的代码片段比之前的全部手动计算权重更新的过程复杂一些,因为抽象出了输出层误差项和隐藏层误差项,代码的抽象知识更加复杂了,但 BP神经网络-第6集 反向传播误差,调整全部权重则直接给出了误差项的矩阵乘表示,而这种方式,应该也是机器学习库中默认使用的方式吧。总之,这篇文章试图解释《图解机器学习》中第13章深度学习网络的代码,弄清楚其中权重更新的方式,包括为什么使用动量因子进行更新这种优化技术。希望读者能够读懂,进而在自己的工程实践中使用深度学习解决自己的问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/185084.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Unity】零基础实现塔防游戏中敌人沿固定路径移动的功能

目录 场景搭建 烘焙(Bake) 敌人动作控制 脚本实现 我们知道,在一些塔防小游戏中,敌人往往会沿着给定的一条路径移动,我们在条路的路边会布置防御设施,攻击消灭敌人,阻止敌人到达终点。 场景搭建 我们首先新建一个…

Leetcode 第 369 场周赛题解

Leetcode 第 369 场周赛题解 Leetcode 第 369 场周赛题解题目1:2917. 找出数组中的 K-or 值思路代码复杂度分析 题目2:2918. 数组的最小相等和思路代码复杂度分析 题目3:2919. 使数组变美的最小增量运算数思路代码复杂度分析 题目4&#xff1…

Git 行结束符:LF will be replaced by CRLF the next time Git touches it问题解决指南

🌷🍁 博主猫头虎 带您 Go to New World.✨🍁 🦄 博客首页——猫头虎的博客🎐 🐳《面试题大全专栏》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺 &a…

ruoyi-vue前端数据字典值引用与回显(列表中回显,多选框回显)

1. 列表中回显&#xff1a; 代码&#xff1a; <el-table v-if"refreshTable" v-loading"loading" :data"deptList" row-key"deptId" :default-expand-all"isExpandAll" :tree-props"{children: children, hasChil…

BP神经网络的数据分类——语音特征信号分类

大家好&#xff0c;我是带我去滑雪&#xff01; BP神经网络&#xff0c;也称为反向传播神经网络&#xff0c;是一种常用于分类和回归任务的人工神经网络&#xff08;ANN&#xff09;类型。它是一种前馈神经网络&#xff0c;通常包括输入层、一个或多个隐藏层和输出层。BP神经网…

虚幻C+++基础 day2

角色移动与视角控制 Character类与相关API 创建Character子类MainPlayer.h // Fill out your copyright notice in the Description page of Project Settings.#pragma once#include "CoreMinimal.h" #include "GameFramework/Character.h" #include &q…

遭受网络攻击泄露了101GB数据

臭名昭著的BlackCat/ALPHV勒索软件团伙声称对另一个组织发起了攻击。今天轮到意大利-法国科西嘉-费里斯公司发现自己正在与勒索软件作斗争。 BlackCat 在其数据泄露网站上报告称&#xff0c;该公司是网络攻击的受害者&#xff0c;并发布了从该公司 IT 基础设施中泄露的一系列样…

javaEE进阶

Cookie 是可以伪造的,比如说学生证是可以伪造的 Session 是不可以伪造的,这是学校系统记录在册的 如何获取 Cookie 我们先用 Servlet 原生的获取 cookie 的方式 我们在浏览器进行访问 但是实际上目前是没有 cookie 的,我们按 F12 进行添加 然后再重新访问,就能在 idea 看到 …

nginx下载安装和日志切割

目录 一、nginx安装配置 1.nginx版本 2.nginx安装配置 3.查看安装后的nginx 4.配置PATH变量 二、日志切割 1.给当前日志文件重命名 2.等待 3.写bash脚本 4.查看日志结果 5.加入crontab定时任务 结语 一、nginx安装配置 1.nginx版本 nginx如今分为商业版&#xff0…

SpringBoot定时任务打成jar 引入到新的项目中后并自动执行

一、springBoot开发定时任务 ①&#xff1a;连接数据库实现新增功能 1. 引入依赖 <dependency><groupId>org.projectlombok</groupId><artifactId>lombok</artifactId><optional>true</optional> </dependency> <dependen…

数据库索引详解

目录 第一章、快速了解索引1.1&#xff09;索引是什么1.2&#xff09;为什么使用索引1.3&#xff09;快速上手创建简单索引 第二章、索引分类2.1&#xff09;按数据结构分类2.1.1&#xff09;树型数据结构的索引①二叉树②B树③B 树&#xff1a;B 树的升级版 2.1.2&#xff09;…

Unity地面交互效果——4、制作地面凹陷轨迹

大家好&#xff0c;我是阿赵。   上一篇介绍了曲面细分着色器的基本用法和思路&#xff0c;这一篇在曲面细分的基础上&#xff0c;制作地面凹陷的轨迹效果。 一、思路分析 这次需要达到的效果是这样的&#xff1a; 从效果上看&#xff0c;这个凹陷在地面下的轨迹&#xff0…

RabbitMQ 死信队列

在MQ中&#xff0c;当消息成为死信&#xff08;Dead message&#xff09;后&#xff0c;消息中间件可以将其从当前队列发送到另一个队列中&#xff0c;这个队列就是死信队列。而在RabbitMQ中&#xff0c;由于有交换机的概念&#xff0c;实际是将死信发送给了死信交换机&#xf…

【VUE+ elementUI 实现动态表头渲染】

VUE elementUI 实现动态表头渲染 1、定义 columns&#xff08;表头数据&#xff09; 和 dataList&#xff08;表格数据&#xff09; data() {return {loading: false,dataList: [{ name: 张三, sex: 男, age: 18 },{ name: 林琳, sex: 女, age: 20 },{ name: 王五, sex: 男, …

基于减法平均算法的无人机航迹规划-附代码

基于减法平均算法的无人机航迹规划 文章目录 基于减法平均算法的无人机航迹规划1.减法平均搜索算法2.无人机飞行环境建模3.无人机航迹规划建模4.实验结果4.1地图创建4.2 航迹规划 5.参考文献6.Matlab代码 摘要&#xff1a;本文主要介绍利用减法平均算法来优化无人机航迹规划。 …

SpringBoot整合Canal+RabbitMQ监听数据变更(对rabbit进行模块封装)

SpringBootCanal(监听MySQL的binlog)RabbitMQ&#xff08;处理保存变更记录&#xff09; 在SpringBoot中采用一种与业务代码解耦合的方式&#xff0c;来实现数据的变更记录&#xff0c;记录的内容是新数据&#xff0c;如果是更新操作还得有旧数据内容。 使用Canal来监听MySQL的…

open clip论文阅读摘要

看下open clip论文 Learning Transferable Visual Models From Natural Language Supervision These results suggest that the aggregate supervision accessible to modern pre-training methods within web-scale collections of text surpasses that of high-quality crowd…

基于React开发的chatgpt网页版(仿chatgpt)

在浏览github的时候发现了一个好玩的项目本项目&#xff0c;是github大神Yidadaa开发的chatgpt网页版&#xff0c;该开源项目是跨平台的&#xff0c;Web / PWA / Linux / Win / MacOS都可以访问。非常有意思&#xff0c;本人就部署了一套&#xff0c;喜欢的同学可以体验一番。 …

Python之字符串、正则表达式练习

目录 1、输出随机字符串2、货币的转换&#xff08;字符串 crr107&#xff09;3、凯撒加密&#xff08;book 实验 19&#xff09;4、字符替换5、检测字母或数字6、纠正字母7、输出英文中所有长度为3个字母的单词 1、输出随机字符串 编写程序&#xff0c;输出由英文字母大小写或…

wpf添加Halcon的窗口控件报错:下列控件已成功添加到工具箱中,但未在活动设计器中启用

报错截图如下&#xff1a; 注意一下新建工程的时候选择wpf应用而不是wpf应用程序。 添加成功的控件&#xff1a;