【深度学习_TensorFlow】梯度下降

写在前面

一直不太理解梯度下降算法是什么意思,今天我们就解开它神秘的面纱


写在中间

线性回归方程


如果要求出一条直线,我们只需知道直线上的两个不重合的点,就可以通过解方程组来求出直线

但是,如果我们选取的这两个点不在直线上,而是存在误差(暂且称作观测误差),这样求出的直线就会和原直线相差很大,我们应该怎样做呢?首先肯定不能只通过两个点,就武断地求出这条直线。

在这里插入图片描述

我们通常尽可能多地使用分布在直线周围的点,也可能不存在一条直线完美的穿过所有采样点。那么,退而求其次,我们希望能找到一条比较“好”的位于采样点中间的直线。那么怎么衡量“好”与“不好”呢?一个很自然的想法就是,求出当前模型的所有采样点上的预测值𝑤𝑥(𝑖) + 𝑏与真实值𝑦(𝑖)之间的差的平方和作为总误差 L \mathcal{L} L,然后搜索一组参数 w ∗ , b ∗ w^{*},b^{*} w,b使得 L \mathcal{L} L最小,对应的直线就是我们要寻找的最优直线。

w ∗ , b ∗ = arg ⁡ min ⁡ w , b 1 n ∑ i = 1 n ( w x ( i ) + b − y ( i ) ) 2 w^*,b^*=\arg\min_{w,b}\frac{1}{n}\sum_{i=1}^{n}\bigl(wx^{(i)}+b-y^{(i)}\bigr)^2 w,b=argminw,bn1i=1n(wx(i)+by(i))2

最后再通过梯度下降法来不断优化参数 w ∗ , b ∗ w^{*},b^{*} w,b

有基础的小伙伴们可能知道求误差的方法其实就是均方误差函数,不懂得可以看这篇文章补充养分《误差函数》 ,我们这篇文章就侧重梯度下降。

梯度下降


函数的梯度定义为函数对各个自变量的偏导数组成的向量。不会的话,翻翻高等数学下册书。

举个例子,对于曲面函数𝑧 = 𝑓(𝑥, 𝑦),函数对自变量𝑥的偏导数记为 ∂ z ∂ x \frac{\partial z}{\partial x} xz,函数对自变量𝑦的偏导数记为 ∂ z ∂ y \frac{\partial z}{\partial y} yz,则梯度∇𝑓为向量 ( ∂ z ∂ x , ∂ z ∂ y ) ({\frac{\partial z}{\partial x}},{\frac{\partial z}{\partial y}}) (xz,yz),梯度的方向总是指向当前位置函数值增速最大的方向,函数曲面越陡峭,梯度的模也越大。

函数在各处的梯度方向∇𝑓总是指向函数值增大的方向,那么梯度的反方向−∇𝑓应指向函数值减少的方向。利用这一性质,我们只需要按照下式来更新参数,,其中𝜂用来缩放梯度向量,一般设置为某较小的值,如 0.01、0.001 等。

x ′ = x − η ⋅ d y d x x'=x-\eta\cdot\frac{\mathrm{d}y}{\mathrm{d}x} x=xηdxdy

结合上面的回归方程,我们就可对误差函数求偏导,以循环的方式更新参数 w , b w,b w,b

w ′ = w − η ∂ L ∂ w b ′ = b − η ∂ L ∂ b \begin{aligned}w'&=w-\eta\frac{\partial\mathcal{L}}{\partial w}\\\\b'&=b-\eta\frac{\partial\mathcal{L}}{\partial b}\end{aligned} wb=wηwL=bηbL

函数实现


计算过程都需要包裹在 with tf.GradientTape() as tape 上下文中,使得前向计算时能够保存计算图信息,方便自动求导操作。通过tape.gradient()函数求得网络参数到梯度信息,结果保存在 grads 列表变量中。

GradientTape()函数

GradientTape(persistent=False, watch_accessed_variables=True)

  • persistent: 布尔值,用来指定新创建的gradient
    tape是否是可持续性的。默认是False,意味着只能够调用一次GradientTape()函数,再次使用会报错

  • watch_accessed_variables:布尔值,表明GradientTape()函数是否会自动追踪任何能被训练的变量。默认是True。要是为False的话,意味着你需要手动去指定你想追踪的那些变量。

tape.watch()函数

tape.watch()用于跟踪指定类型的tensor变量。

  • 由于GradientTape()默认只对tf.Variable类型的变量进行监控。如果需要监控的变量是tensor类型,则需要tape.watch()来监控,否则输出结果将是None

tape.gradient()函数

tape.gradient(target, source)

  • target:求导的因变量

  • source:求导的自变量

import tensorflow as tfw = tf.constant(1.)
x = tf.constant(2.)
y = x * wwith tf.GradientTape() as tape:tape.watch([w])y = x * wgrads = tape.gradient(y, [w])
print(grads)

写在最后

👍🏻点赞,你的认可是我创作的动力!
⭐收藏,你的青睐是我努力的方向!
✏️评论,你的意见是我进步的财富!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/75744.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

eclipse was unable to locate its companion shared library

当转移或者Copy工程时, eclipse was unable to locate its companion shared library eclipse.ini 里面的路径配置错误导致 --launcher.library C:/Users/**/.p2/pool/plugins/org.eclipse.equinox. launcher.win32.win32.x86_64_1.2.700.v20221108-1024 -product …

ubuntu下,在vscode中使用platformio出现 Can not find working Python 3.6+ Interpreter的问题

有一段时间没有使用platformio了,今天突然使用的时候,发现用不了,报错: Ubuntu PlatformIO: Can not find working Python 3.6 Interpreter. Please install the latest Python 3 and restart VSCode。 上网一查,发现…

Java版工程行业管理系统源码-专业的工程管理软件-em提供一站式服务

​ Java版工程项目管理系统 Spring CloudSpring BootMybatisVueElementUI前后端分离 功能清单如下: 首页 工作台:待办工作、消息通知、预警信息,点击可进入相应的列表 项目进度图表:选择(总体或单个)项目…

Docker 安装

1、下载 去官网下载Docker: Accelerated, Containerized Application Development 下载完成就ok了,下载之后估计是打不开的,我下了一个 去微软商店找wsl然后下载,下载之后Docker就好了。 这个就能用了

72. 编辑距离

题目链接:力扣 解题思路:

Modbus tcp转ETHERCAT网关modbus tcp/ip协议

捷米JM-ECT-TCP网关能够连接到Modbus tcp总线和ETHERCAT总线中,实现两种不同协议设备之间的通讯。这个网关能够大大提高工业生产的效率和生产效益,让生产变得更加智能化。捷米JM-ECT-TCP 是自主研发的一款 ETHERCAT 从站功能的通讯网关。该产品主要功能是…

Jetson Docker 编译 FFmpeg 支持硬解nvmpi和cuvid

0 设备和docker信息 设备为NVIDIA Jetson Xavier NX,jetpack版本为 5.1.1 [L4T 35.3.1] 使用的docker镜像为nvcr.io/nvidia/l4t-ml:r35.2.1-py3,详见https://catalog.ngc.nvidia.com/orgs/nvidia/containers/l4t-ml 使用下列命令拉取镜像: sudo docker pull nvcr…

前端Vue入门-day06-路由进阶

(创作不易,感谢有你,你的支持,就是我前行的最大动力,如果看完对你有帮助,请留下您的足迹) 目录 路由的封装抽离 声明式导航 导航链接 两个类名 自定义高亮类名 跳转传参 1. 查询参数传参 2. 动态…

Python接口自动化-requests模块之post请求

一、源码解析 def post(url, dataNone, jsonNone, **kwargs):r"""Sends a POST request.:param url: URL for the new :class:Request object.:param data: (optional) Dictionary, list of tuples, bytes, or file-likeobject to send in the body of the :cla…

Unity通过代码切换材质

效果展示 代码 using System.Collections; using System.Collections.Generic; using UnityEngine;public class MaterialSwitcher : MonoBehaviour {public Material newMaterial; // 新材质private Material oldMaterial; // 旧材质private Renderer renderer; // 渲染器组件…

秋招算法备战第37天 | 738.单调递增的数字、968.监控二叉树、贪心算法总结

738. 单调递增的数字 - 力扣(LeetCode) 这个问题是关于找到一个小于或等于给定数字n的最大单调递增数字。 我们可以将数字n转换为字符数组,然后从左到右扫描,寻找第一个违反单调递增条件的位置。一旦找到这样的位置,…

线上java程序CPU及内存占用过高问题排查总结

背景 最近发现线上的一个JAVA程序总是过段时间慢慢卡死,最后导致无法提供服务,外部请求接口超时。 经排查发现,该程序CPU及内存占用都很高,导致整个系统负载很高。 到这里,就想到了对程序内存进行分析。排查过程 查询…

直播预告|还在说做不出、改不好地图贴图?一次直播包教包会!

在EasyV中,地图组件通常会作为可视化大屏中的「主视觉」部分,用户通过地图组件的使用,可以极大程度上提高搭建的效率以及视觉效果。正因如此,我们的素材广场中大多模板也将「地图」作为核心部分,以此来方便用户快速套用…

golang函数传参——值传递理解

做了五年的go开发,却并没有什么成长,都停留在了业务层面了。一直以为golang中函数传参,如果传的是引用类型,则是以引用传递,造成这样的误解,实在也不能怪我。我们来看一个例子,众所周知&#xf…

#rust taur运行报错#

场景:在window11系统上运行 tauri桌面莹应用,提示错误。 Visual Studio 2022 生成工具 安装的sdk11 , rust运行模式是stable-x86_64-pc-window-gnu, 运行npm run tauir dev 一致失败,失败信息如下 原因:1:在window11系…

数字图像处理(番外)图像增强

图像增强 图像增强的方法是通过一定手段对原图像附加一些信息或变换数据,有选择地突出图像中感兴趣的特征或者抑制(掩盖)图像中某些不需要的特征,使图像与视觉响应特性相匹配。 图像对比度 图像对比度计算方式如下: C ∑ δ δ ( i , j …

【方法】PDF可以转换成Word文档吗?如何操作?

很多人喜欢在工作中使用PDF,因为PDF格式可以准确地保留文档的原始格式,比如字体、图像、布局和颜色等。 但如果编辑文档的话,PDF还是没有Word文档方便。那可以将PDF转换成Word格式,再来编辑吗?如何操作呢?…

Vue2 第二十节 vue-router (一)

1.相关概念理解 2.基本路由 3.嵌套路由(多级路由) 一.相关概念理解 1.1 vue-router的理解 路由:就是一组key-value的对应关系, key为路径,value可能是function或者component多个路由,需要经过路由器的管理编程中的…

从0到1开发go-tcp框架【2-实现Message模块、解决TCP粘包问题、实现多路由机制】

从0到1开发go-tcp框架【2-实现Message模块、解决TCP粘包问题、实现多路由机制】 1 实现\封装Message模块 zinx/ziface/imessage.go package zifacetype IMessage interface {GetMsdId() uint32GetMsgLen() uint32GetMsgData() []byteSetMsgId(uint32)SetData([]byte)SetData…

docker容器互联详解

目录 docker容器互联详解 一、容器互联概述: 二、案例实验: 1、用户自定义的网络: 2、查看当前的IP信息: 3、启动第三个容器: 4、查看三个容器内部的网络: 5、Ping测试: Ps备注&#x…