深度学习02-pytorch-08-自动微分模块

​​​​​​​

其实自动微分模块,就是求相当于机器学习中的线性回归损失函数的导数。就是求梯度。

反向传播的目的: 更新参数, 所以会使用到自动微分模块。

神经网络传输的数据都是 float32 类型。 

案例1:

代码功能概述:

该代码展示了如何在 PyTorch 中使用 自动微分(Autograd) 计算损失函数相对于权重 w 和偏置 b 的梯度。这是机器学习模型训练中非常重要的步骤,因为这些梯度将用于更新模型的参数,从而最小化损失函数

import torch# 1. 当x为标量时,梯度的计算
def test01():x = torch.tensor(5)  # 输入变量x为标量5# 目标值y = torch.tensor(0.)  # 目标输出y设置为0# 设置要更新的权重 和 偏置的初始值w = torch.tensor(1., requires_grad=True, dtype=torch.float32)  # 权重w初始化为1,并启用梯度计算b = torch.tensor(3., requires_grad=True, dtype=torch.float32)  # 偏置b初始化为3,并启用梯度计算# 设置网络的输出值z = x * w + b  # 计算线性模型的输出 z = x*w + b (等同于线性回归的公式)# 设置损失函数,并进行损失的计算loss = torch.nn.MSELoss()  # 使用均方误差(MSE)作为损失函数loss1 = loss(z, y)  # 计算损失,z 是模型的预测值,y 是目标值# 自动微分,计算损失函数相对于w和b的梯度loss1.backward()  # 反向传播计算梯度# backward 函数计算的梯度值会存储在张量的grad 变量中print("w的梯度", w.grad)  # 打印出损失函数对 w 的梯度print("b的梯度", b.grad)  # 打印出损失函数对 b 的梯度test01() 

w的梯度 tensor(80.)
b的梯度 tensor(16.)

代码讲解:

    1.    输入与目标值:
    •    x = torch.tensor(5):输入为 x = 5,表示输入的特征值。
    •    y = torch.tensor(0.):目标输出 y 设置为 0,这是我们希望模型最终预测得到的值。
    2.    参数的初始化:
    •    w = torch.tensor(1., requires_grad=True):初始化权重 w 为 1,requires_grad=True 启用对 w 的梯度计算。
    •    b = torch.tensor(3., requires_grad=True):初始化偏置 b 为 3,同样启用对 b 的梯度计算。
requires_grad=True 的作用是让 PyTorch 知道我们想对这些参数进行梯度计算。
    3.    模型计算:
    •    z = x * w + b:计算模型的输出,类似于线性回归的公式。z 是模型的预测输出。
    4.    损失函数:
    •    loss = torch.nn.MSELoss():选择均方误差(MSE)作为损失函数,用于衡量预测值 z 与目标值 y 之间的误差。
    •    loss1 = loss(z, y):计算损失值,z 是模型预测输出,y 是目标值。

MSE 的公式为:

\text{MSE} = \frac{1}{N} \sum_{i=1}^{N} (z_i - y_i)^2

在这个例子中,由于我们只使用了一个数据点,损失计算为:

\text{Loss} = (z - y)^2 = (x \cdot w + b - 0)^2

    5.    反向传播:
    •   loss1.backward():通过调用 backward(),PyTorch 会自动计算损失函数对 w 和 b 的梯度。这个过程称为反向传播(Backpropagation)。梯度的计算基于链式法则,PyTorch 会自动追踪所有的计算操作,计算各个参数对损失的导数。


    6.    梯度输出:
    •    w.grad:存储了损失函数对 w 的梯度。
    •    b.grad:存储了损失函数对 b 的梯度。

案例2:

import torchdef test02():# 输入张量 2x5,表示 2 个样本,每个样本有 5 个特征x = torch.ones(2, 5)  # 输入数据,全部初始化为 1# 目标输出张量 2x3,表示我们希望模型预测的输出有 3 个类别y = torch.zeros(2, 3)  # 目标输出,初始化为 0# 设置可更新的权重和偏置的初始值# 权重 w 的形状是 5x3,表示输入特征为 5,输出类别为 3w = torch.randn(5, 3, requires_grad=True)  # 随机初始化权重,启用梯度计算# 偏置 b 的形状是 3,表示每个输出类别有一个偏置b = torch.randn(3, requires_grad=True)  # 随机初始化偏置,启用梯度计算# 计算网络的输出,z = x * w + b# x 的形状是 2x5,w 的形状是 5x3,矩阵乘法后的结果 z 的形状是 2x3z = torch.matmul(x, w) + b  # 矩阵乘法和偏置加法# 设置损失函数,并计算损失# 这里使用均方误差(MSE),z 是预测值,y 是目标值loss_fn = torch.nn.MSELoss()  # 损失函数为均方误差loss = loss_fn(z, y)  # 计算损失,输出一个标量值# 自动微分,计算损失函数相对于 w 和 b 的梯度loss.backward()  # 反向传播,计算梯度# 打印权重和偏置的梯度,梯度值存储在 grad 属性中print("w 的梯度:\n", w.grad)  # 打印权重 w 的梯度print("b 的梯度:\n", b.grad)  # 打印偏置 b 的梯度# 调用函数进行计算
test02()

自动微分 (Autograd) 的工作原理:

    •    PyTorch 中的 Autograd 是自动微分引擎,它会记录所有张量的计算历史,并根据这些计算图自动执行反向传播,计算参数的梯度。
    •    在向前计算过程中,PyTorch 构建了一个动态计算图(计算图是有向无环图 DAG)。当你调用 .backward(),计算图会根据链式法则从损失开始计算每个变量的梯度。
    •    计算的梯度会存储在对应张量的 .grad 属性中,然后可以使用这些梯度来更新模型的参数。

总结:

    •    w.grad 和 b.grad 的值告诉我们,若我们改变 w 或 b,损失函数会如何变化。
    •    梯度的计算对于优化模型非常重要,因为我们会使用这些梯度来更新权重和偏置,使得损失函数最小化。

PyTorch 中的 自动微分模块 是通过 autograd 实现的,这是 PyTorch 中的核心功能之一,它可以帮助用户在神经网络的训练过程中自动计算梯度。autograd 模块使得实现反向传播和梯度计算变得非常简单和高效。

核心概念

  1. Tensor: PyTorch 的张量 (Tensor) 是自动微分系统的基本单位。如果将 Tensorrequires_grad 属性设置为 True,则 PyTorch 会开始跟踪所有与该张量相关的操作,并在反向传播时自动计算该张量的梯度。

  2. Computational Graph (计算图): PyTorch 会构建一个动态图,记录张量的所有操作。这个图是有向无环图(DAG),图中的每个节点代表一个变量,边代表该变量上发生的操作。当你调用 .backward() 时,PyTorch 会根据计算图自动计算每个张量的梯度。

  3. 梯度 (Gradient): 如果一个张量参与了计算并且 requires_grad=True,在反向传播时可以通过 .grad 属性获取其梯度值。

  4. 反向传播: 通过 tensor.backward() 来执行反向传播计算张量的梯度,默认情况下会对标量进行求导。

使用案例

  1. 创建一个张量并启用梯度跟踪:

    import torch
    ​
    # 创建一个张量,并启用梯度跟踪
    x = torch.tensor([[2.0, 3.0]], requires_grad=True)

  2. 执行一些操作:

    y = x * 3
    z = y.sum()
    print(z)

  3. 反向传播:

    z.backward()  # 对 z 求导
    print(x.grad)  # 查看 x 的梯度

    输出:

    tensor([[3., 3.]])

    在这个例子中,z = x * 3z.backward() 计算了 zx 的梯度,结果为 3

PyTorch 自动微分的几个重要点:

  1. requires_grad=True: 如果需要对某个张量求导,必须将其 requires_grad 属性设置为 True,否则在反向传播时 PyTorch 不会计算该张量的梯度。

  2. grad_fn: 每个跟踪计算的张量都有一个 grad_fn 属性,代表该张量的创建方式和跟踪的操作。例如,如果你对一个张量做了加法操作,它的 grad_fn 就会显示 AddBackward0

    print(y.grad_fn)  # <MulBackward0 object at 0x...>

  3. .backward(): backward() 方法会根据计算图反向传播,自动计算梯度。

  4. 梯度累加: 每次调用 backward() 时,梯度会被累加到 .grad 中,因此在多次反向传播之前,最好手动将 .grad 清零,使用 x.grad.zero_()

autograd 的典型使用场景

  • 神经网络训练:通过 autograd,我们可以在每次迭代时计算损失函数的梯度,然后使用这些梯度更新网络的参数。

  • 自定义梯度计算:可以通过创建复杂的操作来自动推导梯度。

Example: 简单的线性回归

import torch
​
# 生成数据
x = torch.randn(10, 1, requires_grad=True)
y = 3 * x + 2
​
# 定义损失函数
loss = (x - y).pow(2).mean()
​
# 反向传播
loss.backward()
​
# 查看 x 的梯度
print(x.grad)

在这个例子中,loss.backward() 会自动计算 xloss 的梯度。

总结

  • PyTorch 的自动微分机制通过 autograd 实现,用户只需要将张量的 requires_grad 设置为 True,在执行反向传播时,PyTorch 会自动计算张量的梯度。

  • 通过自动构建计算图,autograd 能够跟踪张量上的所有操作,动态计算梯度,极大地方便了深度学习模型的训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/429582.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙Harmony应用开发,数据驾驶舱 项目结构搭建

对于一个项目而言&#xff0c;在拿到我们的开发任务后&#xff0c;我们最重要的就是技术的选型。选型定下来了之后我们便开始脚手架的搭建&#xff0c;然后开始撸代码&#xff0c;开搞. 首先我们需要对一些常见依赖库的引入 我们需要再oh-package.json5的dependencies节点下面…

8--SpringBoot原理分析、注解-详解(面试高频提问点)

目录 SpringBootApplication 1.元注解 --->元注解 Target Retention Documented Inherited 2.SpringBootConfiguration Configuration Component Indexed 3.EnableAutoConfiguration&#xff08;自动配置核心注解&#xff09; 4.ComponentScan Conditional Co…

基于PHP的新闻管理系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【2025最新】基于phpMySQL的新闻管理系统。…

JavaWeb--纯小白笔记03:servlet入门---动态网页的创建

笔记&#xff1a;index.html在tomcat中为默认的名字&#xff0c;html里面的语法不严谨。改配置文件要小心&#xff0c;不然容易删掉其他 Servlet&#xff1a;服务器端小程序&#xff0c;写动态网页需要用Servlet&#xff0c;普通的java类通过继承HttpServlet&#xff0c;可以响…

【GUI设计】基于Matlab的图像处理GUI系统(1),用matlab实现

博主简介&#xff1a;matlab图像代码项目合作&#xff08;扣扣&#xff1a;3249726188&#xff09; ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 本次案例是基于Matlab的图像处理GUI系统&#xff0c;用matlab实现。 本次内容主要分为两部分&a…

Why is OpenAI image generation Api returning 400 bad request in Unity?

题意&#xff1a;为什么 OpenAI 图像生成 API 在 Unity 中返回 400 Bad Request 错误&#xff1f; 问题背景&#xff1a; Im testing out dynamically generating images using OpenAI API in Unity. Amusingly, I actually generated most of this code from chatGPT. 我正在…

【笔记】第二节 轧制、热处理和焊接工艺

2.2 钢轨的轧制工艺 坯料进厂按标准验收, 然后装加热炉加热, 加热好的钢坯经高压水除鳞后进行轧制。轧出的钢轨经锯切、打印到中央冷床冷却, 然后装缓冷坑进行缓冷。缓冷后的钢轨进行矫直、轨端加工和端头淬火。钢轨入库前逐根进行探伤和外观检查。 钢轨的轧制 #mermaid-svg-…

foreach,for in和for of的区别

forEach 不能使用break return 结束并退出循环 for in 和 for of 可以使用break return&#xff1b; for in 遍历的是数组的索引&#xff08;即键名&#xff09;&#xff0c;而for of遍历的是数组元素值。 for of 遍历的只是数组内的元素&#xff0c;而不包括数组的原型属性…

后端-navicat查找语句(单表与多表)

表格字段设置如图 语句&#xff1a; 1.输出 1.输出name和age列 SELECT name,age from student 1.2.全部输出 select * from student 2.where子语句 1.运算符&#xff1a; 等于 >大于 >大于等于 <小于 <小于等于 ! <>不等于 select * from stude…

JdbcTemplate常用方法一览AG网页参数绑定与数据寻址实操

JdbcTemplate是Spring框架中的一个重要组件&#xff0c;主要用于简化JDBC数据库操作。它提供了许多常用的方法&#xff0c;如查询、插入、更新、删除等。本文将介绍JdbcTemplate的常用方法及其使用方式&#xff0c;以及参数绑定和删除数据的方法。 一、JdbcTemplate常用方法 查…

钉钉与MySQL对接集成获取部门列表2.0打通EXECUTE语句

钉钉与MySQL对接集成获取部门列表2.0打通EXECUTE语句 接入系统&#xff1a;钉钉 钉钉是阿里巴巴集团打造的企业级智能移动办公平台&#xff0c;是数字经济时代的企业组织协同办公和应用开发平台。钉钉将IM即时沟通、钉钉文档、钉闪会、钉盘、Teambition、OA审批、智能人事、钉工…

828华为云征文|华为Flexus云服务器搭建Cloudreve私人网盘

一、华为云 Flexus X 实例&#xff1a;开启高效云服务新篇&#x1f31f; 在云计算的广阔领域中&#xff0c;资源的灵活配置与卓越性能犹如璀璨星辰般闪耀。华为云 Flexus X 实例恰似一颗最为耀眼的新星&#xff0c;将云服务器技术推向了崭新的高度。 华为云 Flexus X 实例基于…

基于SpringBoot+Vue的商城积分系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、SSM项目源码 精品专栏&#xff1a;Java精选实战项目源码、Python精…

我的AI工具箱Tauri版-MicrosoftTTS文本转语音

本教程基于自研的AI工具箱Tauri版进行MicrosoftTTS文本转语音服务。 MicrosoftTTS文本转语音服务 是自研的AI工具箱Tauri版中的一款功能模块&#xff0c;专为实现高效的文本转语音操作而设计。通过集成微软TTS服务&#xff0c;用户可以将大量文本自动转换为自然流畅的语音文件…

物理学基础精解【9】

文章目录 直线与二元一次方程两直线夹角直线方程斜率两点式方程截距式方程将不同形式的直线方程转换为截距方程直线的一般方程直线一般方程的系数有一个或两个为零的直线 参考文献 直线与二元一次方程 两直线夹角 两直线 y 1 k 1 x b 1 , y 2 k 2 x b 2 形成夹角 a 1 和 a…

关于字节 c++

字节的介绍 字节是计算机中最小的存储单位&#xff0c;通常由8个二进制位组成&#xff0c;用来存储一个字符。在C中&#xff0c;字节也是基本数据类型之一&#xff0c;用关键字"byte"来表示。字节主要用于存储一些较小的数据&#xff0c;如整数、字符等。字节的大小…

【Delphi】通过 LiveBindings Designer 链接控件示例

本教程展示了如何使用 LiveBindings Designer 可视化地创建控件之间的 LiveBindings&#xff0c;以便创建只需很少或无需源代码的应用程序。 在本教程中&#xff0c;您将创建一个高清多设备应用程序&#xff0c;该应用程序使用 LiveBindings 绑定多个对象&#xff0c;以更改圆…

python - self 调用父类方法

Python 子类继承父类构造函数说明 | 菜鸟教程如果在子类中需要父类的构造方法就需要显式地调用父类的构造方法&#xff0c;或者不重写父类的构造方法。 子类不重写 __init__&#xff0c;实例化子类时&#xff0c;会自动调用父类定义的 __init__。 实例 [mycode3 typepython] cl…

Linux基础---13三剑客及正则表达式

一.划水阶段 首先我们先来一个三剑客与正则表达式混合使用的简单示例&#xff0c;大致了解是个啥玩意儿。下面我来演示一下如何查询登录失败的ip地址及次数。 1.首先&#xff0c;进入到 /var/log目录下 cd /var/log效果如下 2.最后&#xff0c;输入如下指令即可查看&#xf…

基于协同过滤+python+django+vue的音乐推荐系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【2025最新】基于协同过滤pythondjangovue…