深度学习基础知识 register_buffer 与 register_parameter用法分析

深度学习基础知识 register_buffer 与 register_parameter用法分析

  • 1、问题引入
  • 2、register_parameter()
    • 2.1 作用
    • 2.2 用法
  • 3、register_buffer()
    • 3.1 作用
    • 3.2 用法

1、问题引入

思考问题:定义的weight与bias是否会被保存到网络的参数中,可否在优化器的作用下进行学习

验证方案:定义网络模型,设置weigut与bias,遍历网络结构参数net.named_parameters(),如果定义的weight与bias在里面,则说明是可学习参数;否则,是不可学习参数

import torch
import torch.nn as nn# 思考两个问题,定义的weight与bias是否会被保存到网络的参数中,可否在优化器的作用下进行学习class MyModule(nn.Module):def __init__(self):super(MyModule,self).__init__()self.conv1=nn.Conv2d(in_channels= 3,out_channels= 6,kernel_size=3,stride = 1,padding=1,bias=False)self.conv2=nn.Conv2d(in_channels= 6,out_channels= 9,kernel_size=3,stride = 1,padding=1,bias=False)self.waight=torch.ones(10,10)self.bias=torch.zeros(10)def forward(self,x):x=self.conv1(x)x=self.conv2(x)x = x * self.weight + self.biasreturn xnet=MyModule()for name,param in net.named_parameters():  # 如果weight与bias在里面,说明其是可学习参数;否则,是不可学习参数print(name,param.shape)print("\n","-"*40,"\n")for key,val in net.state_dict().items():  # 说明weight与bias是不会被state_dict转化为字典中的元素的print(key,val.shape)

打印分析结果:
在这里插入图片描述
可以看到,weight与bias不在其中,所以此种定义方式不会是的weight与bias成为可训练参数

2、register_parameter()

register_parameter()是 torch.nn.Module 类中的一个方法

2.1 作用

1、可将 self.weight 和 self.bias 定义为可学习的参数,保存到网络对象的参数中,被优化器作用进行学习
2、self.weight 和 self.bias 可被保存到 state_dict 中,进而可以 保存到网络文件 / 网络参数文件中

2.2 用法

register_parameter(name,param)

  • name:参数名称
  • param:参数张量, 须是 torch.nn.Parameter() 对象 或 None ,

否则报错如下
在这里插入图片描述

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)self.register_parameter('weight', torch.nn.Parameter(torch.ones(10, 10)))self.register_parameter('bias', torch.nn.Parameter(torch.zeros(10)))def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x * self.weight + self.biasreturn xnet = MyModule()for name, param in net.named_parameters():print(name, param.shape)print('\n', '*'*40, '\n')for key, val in net.state_dict().items():print(key, val.shape)

结果显示:
在这里插入图片描述

3、register_buffer()

register_buffer()是 torch.nn.Module() 类中的一个方法

3.1 作用

  • 将 self.weight 和 self.bias 定义为不可学习的参数,不会被保存到网络对象的参数中,不会被优化器作用进行学习

  • self.weight 和 self.bias 可被保存到 state_dict 中,进而可以 保存到网络文件 / 网络参数文件中

它用于在网络实例中 注册缓冲区,存储在缓冲区中的数据,类似于参数(但不是参数)

  • 参数:可以被优化器更新 (requires_grad=False / True)
  • buffer 中的数据 : 不会被优化器更新

3.2 用法

register_buffer(name,tensor)

  • name:参数名称
  • tensor:张量

代码:

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)self.register_buffer('weight', torch.ones(10, 10))   # 注意:定义的方式self.register_buffer('bias', torch.zeros(10))def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x * self.weight + self.biasreturn xnet = MyModule()for name, param in net.named_parameters():print(name, param.shape)print('\n', '*'*40, '\n')for key, val in net.state_dict().items():print(key, val.shape)

效果如下所示:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/155584.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

想用macbook录制视频?这几个技巧让你事半功倍!

在今天的数字时代,录制视频已经变得非常常见,无论是为了制作内容、演示、教育还是娱乐。如果您是一名macBook用户,您拥有一些强大的工具来满足您的视频录制需求。本文将向您介绍两种macbook录制视频的方法,帮助您更好地录制高质量…

SQL利用Case When Then多条件判断

CASE WHEN 条件1 THEN 结果1 WHEN 条件2 THEN 结果2 WHEN 条件3 THEN 结果3 WHEN 条件4 THEN 结果4 ......... WHEN 条件N THEN 结果N ELSE 结果X END Case具有两种格式。简单Case函数和Case搜索函数。 --简单Case函数 CASE sex WHEN 1 THEN…

东芝Z750的画质真实吗?适合看纪录片吗?

东芝Z750显示的画面更加真实、细腻、有层次感,就感觉电视中的画面像真实世界一般呈现在眼前,东芝电视拥有70余年的原色调校技术,每款产品都有专属的日本调校工程师匠心打造,可以真实还原画面色彩,若是观看类似《航拍中国》这样的旅游记录片时,东芝电视将会营造出万里山河的壮阔…

从0开始学go第五天

gin框架返回JSON package mainimport ("net/http""github.com/gin-gonic/gin" )func main() {r : gin.Default()r.GET("/json", func(c *gin.Context) {//用map序列化//方法一:用map,后面用接口类型// data : map[string…

利用人工智能做射击游戏辅助(二)AlphaPose环境配置

一、anaconda安装 官网地址:Free Download | AnacondaAnacondas open-source Distribution is the easiest way to perform Python/R data science and machine learning on a single machine.https://www.anaconda.com/download 下载之后下一步就可以&#xff0c…

谷歌云 | 零售行业的生成式 AI:如何跟上步伐并取得领先

【Cloud Ace 是 Google Cloud 全球战略合作伙伴,在亚太地区、欧洲、南北美洲和非洲拥有二十多个办公室。Cloud Ace 在谷歌专业领域认证及专业知识目前排名全球第一位,并连续多次获得 Google Cloud 各类奖项。作为谷歌云托管服务商,我们提供谷…

[Python]图片转字符画——这就是我的表情!!!!!!

背景 偶尔看到一些视频,他们把图片转字符画,平常也没有去关注,今天来捣鼓一下。 研究了一下还超级简单的,都是调用别人写好的框架。 网上也有很多教学。 代码实现 from PIL import Image # 表示字符颜色,由深到浅&am…

Redis 的过期键 | Navicat 技术干货

Redis 是一种高性能的内存数据存储,以其速度和多功能性而闻名。其中一个有用的特性是为键设置过期时间的功能。在 Redis 中,为键设置过期时间对于管理数据和确保过时或临时数据自动从数据库中删除是至关重要的。在本文中,我们将探讨在 redis-…

【AI】深度学习——前馈神经网络——全连接前馈神经网络

文章目录 1.1 全连接前馈神经网络1.1.1 符号说明超参数参数活性值 1.1.2 信息传播公式通用近似定理 1.1.3 神经网络与机器学习结合二分类问题多分类问题 1.1.4 参数学习矩阵求导链式法则更为高效的参数学习反向传播算法目标计算 ∂ z ( l ) ∂ w i j ( l ) \frac{\partial z^{…

应对广告虚假流量,app广告变现该如何风控?

移动广告市场中的虚假流量一直是困扰各移动应用厂商的难题,广告作为app商业化变现最为直接快捷的途径,也引申出了流量作弊与反作弊的纷争。 根据《2021中国异常流量报告》,2021年中国品牌广告市场因异常流量造成的损失约为326亿人民币&#…

用例图 UML从入门到放弃系列之三

1.说明 关于用例图,这篇文章我将直接照搬罗伯特.C.马丁老爷子在《敏捷开发》一书种的第17章,并配上自己的理解,因为这一章写的实在是太精彩了,希望能够分享给大家,共勉。以下是老爷子的原文中文翻译以及豆芽的个人解读…

早安问候语早安心语,别把人生想太难,人生需要鼓励

1、别把人生想的太难,人生需要几分自我的鼓励,不管在什么时候,要有几分信念和信心,生活少不了哭哭笑笑。青山绿水依然在,来来往往人不同,要学会看得惯,还要学会看得开,你内心的平坦是…

【Java学习之道】继承与多态

引言 本文将介绍面向对象编程的核心概念——继承与多态。对于初学者来说,掌握这些基本概念是迈向Java高手的第一步。接下来,让我们一起揭开继承与多态的神秘面纱,感受它们的魅力吧! 一、继承 继承是面向对象编程的一个重要特性…

Linux 文件系统

目录 磁盘文件管理 认识磁盘 抽象认识磁盘 磁盘划分 inode vs 文件名 软硬链接 磁盘文件管理 前面我们说了关于 Linux 文件系统中 “已打开的文件” ,但是在系统中可不光只有已打开的文件,实际上,系统中还存在很多没有打开的文件。 既…

1600*C. Game On Leaves(博弈游戏树)

Problem - 1363C - Codeforces 解析: 我们将目标结点 x 当作树的根,显然,到当 x 的度为 1 的时候,此时行动的人胜利。 我们假设现在的情况为,只剩余三个点,再选择任意一个点,则对方获胜。但是两…

WSL2下的Docker配置和使用

在Windows的Linux子系统(Windows Subsystem for Linux)WSL2中安装、配置和使用 Docker,可以参考官方教程:WSL上的Docker远程容器入门. 重要步骤总结如下: 先决条件 确保你的计算机运行的是 Windows 10(更…

C++11 Qt QFutureWatcher lambda

目录 Lambda 介绍 【QT】Qt之QFutureWatcher 简述 传参: 还可以使用 QProgressDialog 作为阻堵 函数,变成同步; 完成后,关闭; MyQProgressDialog 效果: Lambda 介绍 Lambda 函数也叫匿名函数&…

信创办公–基于WPS的PPT最佳实践系列 (绘制自选图形)

信创办公–基于WPS的PPT最佳实践系列 (绘制自选图形) 目录 应用背景操作步骤1、记忆复制:CTRLD2、微移:CTRL四个方向键 应用背景 如果想将文字转为简单而形象的smartart图形,但是又找不到自己想要的图形,我…

什么是大数据,大数据简介

大数据的概念通俗的说法 大数据,按照我的理解比较通俗易懂的是在数据量很多很大的情况下数据处理速度需要足够快,用我们以前传统意义上的的技术比如关系型数据库mysql没办法处理或者处理起来非常复杂,必须有一些新的处理技术也就是大数据处理…

2024第八届杭州国际智慧城市博览会:建筑与智能,智慧与未来

浙江,中国最具活力的省份之一,将再次迎来一场盛大的智慧城市行业展会。2024年第八届浙江智慧城市博览会,由浙江省土木建筑学会发起主办,以“探索未来,智能引领”为主题,于2024年4月份在美丽的杭州国际博览中…