pytorch求导

pytorch求导的初步认识

requires_grad

tensor(data, dtype=None, device=None, requires_grad=False)

requires_grad是torch.tensor类的一个属性。如果设置为True,它会告诉PyTorch跟踪对该张量的操作,允许在反向传播期间计算梯度。

x.requires_grad    判断一个tensor是否可以求导,返回布尔值

叶子变量-leaf variable

  • 对于requires_grad=False 的张量,我们约定俗成地把它们归为叶子张量。
  • 对于requires_grad为True的张量,如果他们是由用户创建的,则它们是叶张量。

 如果某一个叶子变量,开始时不可导的,后面想设置它可导,该怎么办?

x.requires_grad_(True/False)   设置tensor的可导与不可导

注意:这种方法只适用于设置叶子变量,否则会出现如下错误

x = torch.tensor(2.0, requires_grad=True)
y = torch.pow(x, 2)
z = torch.add(y, 3)
z.backward()
print(x.grad)
print(y.grad)
tensor(4.)
None
  1. 创建一个浮点型张量x,其值为2.0,并设置requires_grad=True,使PyTorch可以跟踪x的计算历史并允许计算它的梯度。

  2. 创建一个新张量y,y是x的平方。

  3. 创建一个新张量z,z是y和3的和。

  4. 调用z.backward()进行反向传播,计算z关于x的梯度。

  5. 打印x的梯度,应该是2*x=4.0。

  6. 试图打印y的梯度。但是,PyTorch默认只计算并保留叶子节点的梯度非叶子节点的梯度在计算过程中会被释放掉,因此y的梯度应该为None。

保留中间变量的梯度

tensor.retain_grad()

 retain_grad()retain_graph是用来处理两个不同的情况

  1. retain_grad(): 用于保留非叶子节点的梯度。如果你想在反向传播结束后查看或使用非叶子节点的梯度,你应该在非叶子节点上调用.retain_grad()

  2. retain_graph: 当你调用.backward()时,PyTorch会自动清除计算图以释放内存。这意味着你不能在同一个计算图上多次调用.backward()。但是,如果你需要多次调用.backward()(例如在某些特定的优化算法中),你可以在调用.backward()时设置retain_graph=True保留计算图

.grad

通过tensor的grad属性查看所求得的梯度值。

.grad_fn

在PyTorch中,.grad_fn属性是一个引用到创建该Tensor的Function对象。也就是说,这个属性可以告诉你这个张量是如何生成的。对于由用户直接创建的张量,它的.grad_fnNone。对于由某个操作创建的张量,.grad_fn将引用到一个与这个操作相关的对象

import torchx = torch.tensor([1.0, 2.0], requires_grad=True)
y = x * 2
z = y.mean()print(x.grad_fn)
print(y.grad_fn)
print(z.grad_fn)

这里,x是由用户直接创建的,所以x.grad_fnNoney是通过乘法操作创建的,所以y.grad_fn是一个MulBackward0对象,这表明y是通过乘法操作创建的。z是通过求平均数操作创建的,所以z.grad_fn是一个MeanBackward0对象。

 pytorch自动求导实现神经网络

numpy手动实现

import numpy as np
import matplotlib.pyplot as pltN, D_in, H, D_out = 64, 1000, 100, 10  # 64个训练数据(只是一个batch),输入是1000维,hidden是100维,输出是10维'''随机创建一些训练数据'''
X = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)W1 = np.random.randn(D_in, H)  # 1000维转成100维
W2 = np.random.randn(H, D_out)  # 100维转成10维learning_rate = 1e-6all_loss = []epoch = 500for t in range(500):  # 做500次迭代'''前向传播(forward pass)'''h = X.dot(W1)  # N * Hh_relu = np.maximum(h, 0)  # 激活函数,N * Hy_hat = h_relu.dot(W2)  # N * D_out'''计算损失函数(compute loss)'''loss = np.square(y_hat - y).sum()  # 均方误差,忽略了÷Nprint("Epoch:{}   Loss:{}".format(t, loss))  # 打印每个迭代的损失all_loss.append(loss)'''后向传播(backward pass)'''# 计算梯度(此处没用torch,用最普通的链式求导,最终要得到 d{loss}/dX)grad_y_hat = 2.0 * (y_hat - y)  # d{loss}/d{y_hat},N * D_outgrad_W2 = h_relu.T.dot(grad_y_hat)  # 看前向传播中的第三个式子,d{loss}/d{W2},H * D_outgrad_h_relu = grad_y_hat.dot(W2.T)  # 看前向传播中的第三个式子,d{loss}/d{h_relu},N * Hgrad_h = grad_h_relu.copy()  # 这是h>0时的情况,d{h_relu}/d{h}=1grad_h[h < 0] = 0  # d{loss}/d{h}grad_W1 = X.T.dot(grad_h)  # 看前向传播中的第一个式子,d{loss}/d{W1}'''参数更新(update weights of W1 and W2)'''W1 -= learning_rate * grad_W1W2 -= learning_rate * grad_W2plt.plot(all_loss)
plt.xlabel("epoch")
plt.ylabel("Loss")
plt.show()

pytorch自动实现

import torchN, D_in, H, D_out = 64, 1000, 100, 10  # 64个训练数据(只是一个batch),输入是1000维,hidden是100维,输出是10维'''随机创建一些训练数据'''
X = torch.randn(N, D_in)
y = torch.randn(N, D_out)W1 = torch.randn(D_in, H, requires_grad=True)  # 1000维转成100维
W2 = torch.randn(H, D_out, requires_grad=True)  # 100维转成10维learning_rate = 1e-6for t in range(500):  # 做500次迭代'''前向传播(forward pass)'''y_hat = X.mm(W1).clamp(min=0).mm(W2)  # N * D_out'''计算损失函数(compute loss)'''loss = (y_hat - y).pow(2).sum()  # 均方误差,忽略了÷N,loss就是一个计算图(computation graph)print("Epoch:{}   Loss:{}".format(t, loss.item()))  # 打印每个迭代的损失'''后向传播(backward pass)'''loss.backward()'''参数更新(update weights of W1 and W2)'''with torch.no_grad():W1 -= learning_rate * W1.gradW2 -= learning_rate * W2.gradW1.grad.zero_()W2.grad.zero_()

pytorch手动实现

import torch
import matplotlib.pyplot as pltN, D_in, H, D_out = 64, 1000, 100, 10  # 64个训练数据(只是一个batch),输入是1000维,hidden是100维,输出是10维'''随机创建一些训练数据'''
X = torch.randn(N, D_in)
y = torch.randn(N, D_out)W1 = torch.randn(D_in, H)  # 1000维转成100维
W2 = torch.randn(H, D_out)  # 100维转成10维learning_rate = 1e-6all_loss = []for t in range(500):  # 做500次迭代'''前向传播(forward pass)'''h = X.mm(W1)  # N * Hh_relu = h.clamp(min=0)  # 激活函数,N * Hy_hat = h_relu.mm(W2)  # N * D_out'''计算损失函数(compute loss)'''loss = (y_hat - y).pow(2).sum().item()  # 均方误差,忽略了÷Nprint("Epoch:{}   Loss:{}".format(t, loss))  # 打印每个迭代的损失all_loss.append(loss)'''后向传播(backward pass)'''# 计算梯度(此处没用torch,用最普通的链式求导,最终要得到 d{loss}/dX)grad_y_hat = 2.0 * (y_hat - y)  # d{loss}/d{y_hat},N * D_outgrad_W2 = h_relu.t().mm(grad_y_hat)  # 看前向传播中的第三个式子,d{loss}/d{W2},H * D_outgrad_h_relu = grad_y_hat.mm(W2.t())  # 看前向传播中的第三个式子,d{loss}/d{h_relu},N * Hgrad_h = grad_h_relu.clone()  # 这是h>0时的情况,d{h_relu}/d{h}=1grad_h[h < 0] = 0  # d{loss}/d{h}grad_W1 = X.t().mm(grad_h)  # 看前向传播中的第一个式子,d{loss}/d{W1}'''参数更新(update weights of W1 and W2)'''W1 -= learning_rate * grad_W1W2 -= learning_rate * grad_W2plt.plot(all_loss)
plt.xlabel("epoch")
plt.ylabel("Loss")
plt.show()

torch.nn实现

import torch
import torch.nn as nn  # 各种定义 neural network 的方法N, D_in, H, D_out = 64, 1000, 100, 10  # 64个训练数据(只是一个batch),输入是1000维,hidden是100维,输出是10维'''随机创建一些训练数据'''
X = torch.randn(N, D_in)
y = torch.randn(N, D_out)model = torch.nn.Sequential(torch.nn.Linear(D_in, H, bias=True),  # W1 * X + b,默认Truetorch.nn.ReLU(),torch.nn.Linear(H, D_out)
)# model = model.cuda()  #这是使用GPU的情况loss_fn = nn.MSELoss(reduction='sum')learning_rate = 1e-4for t in range(500):  # 做500次迭代'''前向传播(forward pass)'''y_hat = model(X)  # model(X) = model.forward(X), N * D_out'''计算损失函数(compute loss)'''loss = loss_fn(y_hat, y)  # 均方误差,忽略了÷N,loss就是一个计算图(computation graph)print("Epoch:{}   Loss:{}".format(t, loss.item()))  # 打印每个迭代的损失'''后向传播(backward pass)'''loss.backward()'''参数更新(update weights of W1 and W2)'''with torch.no_grad():for param in model.parameters():param -= learning_rate * param.grad  # 模型中所有的参数更新model.zero_grad()

torch.nn的继承类

import torch
import torch.nn as nn  # 各种定义 neural network 的方法
from torchsummary import summary
# pip install torchsummary
N, D_in, H, D_out = 64, 1000, 100, 10  # 64个训练数据(只是一个batch),输入是1000维,hidden是100维,输出是10维'''随机创建一些训练数据'''
X = torch.randn(N, D_in)
y = torch.randn(N, D_out)'''定义两层网络'''class TwoLayerNet(torch.nn.Module):def __init__(self, D_in, H, D_out):super(TwoLayerNet, self).__init__()# 定义模型结构self.linear1 = torch.nn.Linear(D_in, H, bias=False)self.linear2 = torch.nn.Linear(H, D_out, bias=False)def forward(self, x):y_hat = self.linear2(self.linear1(X).clamp(min=0))return y_hatmodel = TwoLayerNet(D_in, H, D_out)loss_fn = nn.MSELoss(reduction='sum')
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)for t in range(500):  # 做500次迭代'''前向传播(forward pass)'''y_hat = model(X)  # model.forward(), N * D_out'''计算损失函数(compute loss)'''loss = loss_fn(y_hat, y)  # 均方误差,忽略了÷N,loss就是一个计算图(computation graph)print("Epoch:{}   Loss:{}".format(t, loss.item()))  # 打印每个迭代的损失optimizer.zero_grad()  # 求导之前把 gradient 清空'''后向传播(backward pass)'''loss.backward()'''参数更新(update weights of W1 and W2)'''optimizer.step()  # 一步把所有参数全更新print(summary(model, (64, 1000)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/81483.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MATLAB的设置路径

在主页下的 或者在命令行输入path&#xff0c;命令行会出现所有路径 必须要将某些函数.m文件以及一些类文件包含在路径当中&#xff0c;否则在脚本代码中输入代码时&#xff0c;不会有代码提示

根据渲染数据长度动态渲染后缀图标

在动态获取数据时&#xff0c;想要渲染后面的图标是根据数据的长度渲染图标位置&#xff0c;效果如下&#xff1a; 代码如下&#xff1a; <el-row :gutter"60"><el-col :span"24"><el-form-item><el-input v-model.trim"form…

图的适配器

什么是图 图是一个由点的集合和边的集合所构成的数据结构。 图分为有向图和无向图。其中无向图也可以理解为有向图&#xff0c;所以可以认为所有的图都是有向图。 比方说&#xff0c;有这么一张图。其中a指向bc&#xff0c;b指向c&#xff0c;c指向p。边是带方向的&#xff0c;…

使用 RediSearch 在 Redis 中进行全文检索

原文链接&#xff1a; 使用 RediSearch 在 Redis 中进行全文检索 Redis 大家肯定都不陌生了&#xff0c;作为一种快速、高性能的键值存储数据库&#xff0c;广泛应用于缓存、队列、会话存储等方面。 然而&#xff0c;Redis 在原生状态下并不支持全文检索功能&#xff0c;这使…

FFmpeg常见命令行(三):FFmpeg转码

前言 在Android音视频开发中&#xff0c;网上知识点过于零碎&#xff0c;自学起来难度非常大&#xff0c;不过音视频大牛Jhuster提出了《Android 音视频从入门到提高 - 任务列表》。本文是Android音视频任务列表的其中一个&#xff0c; 对应的要学习的内容是&#xff1a;如何使…

数学建模—多元线性回归分析(+lasso回归的操作)

第一部分&#xff1a;回归分析的介绍 定义&#xff1a;回归分析是数据分析中最基础也是最重要的分析工具&#xff0c;绝大多数的数据分析问题&#xff0c;都可以使用回归的思想来解决。回归分析的人数就是&#xff0c;通过研究自变量X和因变量Y的相关关系&#xff0c;尝试去解释…

基于ffmpeg与SDL的视频播放库

由于工作需要&#xff0c;自己封装的基于ffmpeg的视频编解码库&#xff0c;显示采用了SDL库。可以播放本地文件或网络流&#xff0c;支持多端口播放&#xff0c;支持文字叠加&#xff0c;截图、视频录制等等。 头文件代码&#xff1a; #pragma once #ifdef __DLLEXPORT #defin…

如何用正确的姿势监听Android屏幕旋转

作者&#xff1a;37手游移动客户端团队 背景 关于个人&#xff0c;前段时间由于业务太忙&#xff0c;所以一直没有来得及思考并且沉淀点东西&#xff1b;同时组内一个个都在业务上能有自己的思考和总结&#xff0c;在这样的氛围下&#xff0c;不由自主的驱使周末开始写点东西&…

antv/l7地图,鼠标滚动,页面正常滑动-- 我们忽略的deltaY

背景 在官网项目中&#xff0c;需要使用一个地图&#xff0c;展示产品的分布区域及数量。希望的交互是&#xff0c;鼠标放上标点&#xff0c;tooltip展示地点和数量等信息。鼠标滚动&#xff0c;则页面随着滚动。但是鼠标事件是被地图代理了的&#xff0c;鼠标滚动意味着地图的…

【vue3】基础知识点-computed和watch

学习vue3&#xff0c;都会从基础知识点学起。了解setup函数&#xff0c;ref&#xff0c;recative&#xff0c;watch、computed、pinia等如何使用 今天说vue3组合式api&#xff0c;computed和watch 在vue3中&#xff0c;computed和watch仍然是非常有用的特性&#xff0c;帮助处…

JavaWeb学习|JSP相关内容

1.什么是JSP Java Server Pages: Java服务器端页面&#xff0c;也和Servlet一样&#xff0c;用于动态Web技术! 最大的特点: 。写JSP就像在写HTML 。区别: 。HTML只给用户提供静态的数据 。JSP页面中可以嵌入JAVA代码&#xff0c;为用户提供动态数据 JSP最终也会被转换成为一…

Android数据存储选项:SQLite、Room等

Android数据存储选项&#xff1a;SQLite、Room等 1. 引言 在移动应用的开发过程中&#xff0c;数据存储是至关重要的一环。无论是用户的个人信息、设置配置还是应用产生的临时数据&#xff0c;都需要在设备上进行存储以便随时访问。随着移动应用的日益发展&#xff0c;数据存…

springboot-mybatis的增删改查

目录 一、准备工作 二、常用配置 三、尝试 四、增删改查 1、增加 2、删除 3、修改 4、查询 五、XML的映射方法 一、准备工作 实施前的准备工作&#xff1a; 准备数据库表 创建一个新的springboot工程&#xff0c;选择引入对应的起步依赖&#xff08;mybatis、mysql驱动…

Flowable-网关-事件网关

目录 定义图形标记XML内容使用示例视频教程 定义 通常网关根据连线条件来决定后继路径&#xff0c;但事件网关不同&#xff0c;它提供了根据事件做选择的方式。 事件网关的每个外出顺序流都需要连接至一个捕获中间事件。当流程执行到达事件网关时&#xff0c;网关类 似处于等待…

【阻止IE强制跳转到Edge浏览器】

由于微软开始限制用户使用Internet Explorer浏览网站&#xff0c;IE浏览器打开一些网页时会自动跳转到新版Edge浏览器&#xff0c;那应该怎么禁止跳转呢&#xff1f; 1、点击电脑左下角的“搜索框”或者按一下windows键。 2、输入“internet”&#xff0c;点击【Internet选项…

Python Opencv实践 - 在图像上绘制图形

import cv2 as cv import numpy as np import matplotlib.pyplot as pltimg cv.imread("../SampleImages/pomeranian.png") print(img.shape)plt.imshow(img[:,:,::-1])#画直线 #cv.line(img,start,end,color,thickness) #参考资料&#xff1a;https://blog.csdn.ne…

百度智能云“千帆大模型平台”最新升级:接入Llama 2等33个模型!

今年3月&#xff0c;百度智能云推出“千帆大模型平台”。作为全球首个一站式的企业级大模型平台&#xff0c;千帆不但提供包括文心一言在内的大模型服务及第三方大模型服务&#xff0c;还提供大模型开发和应用的整套工具链&#xff0c;能够帮助企业解决大模型开发和应用过程中的…

04-2_Qt 5.9 C++开发指南_SpinBox使用

文章目录 1. SpinBox简介2. SpinBox使用2.1 可视化UI设计2.2 widget.h2.3 widget.cpp 1. SpinBox简介 QSpinBox 用于整数的显示和输入&#xff0c;一般显示十进制数&#xff0c;也可以显示二进制、十六进制的数&#xff0c;而且可以在显示框中增加前缀或后缀。 QDoubleSpinBox…

stl_list类(使用+实现)(C++)

list 一、list-简单介绍二、list的常用接口1.常见构造2.iterator的使用3.Capacity和Element access4.Modifiers5.list的迭代器失效 三、list实现四、vector 和 list 对比五、迭代器1.迭代器的实现2.迭代器的分类&#xff08;按照功能分类&#xff09;3.反向迭代器(1)、包装逻辑…

嵌入式开发学习(STC51-9-led点阵)

内容 点亮一个点&#xff1b; 显示数字&#xff1b; 显示图像&#xff1b; LED点阵简介 LED 点阵是由发光二极管排列组成的显示器件 通常应用较多的是8 * 8点阵&#xff0c;然后使用多个8 * 8点阵可组成不同分辨率的LED点阵显示屏&#xff0c;比如16 * 16点阵可以使用4个8 *…