register_parameter和register_buffer 详解

在参考yolo系列代码或其他开源代码,经常看到register_buffer register_parameter的使用,接下来将详细对他们进行介绍。

1. 前沿

在搭建网络时,我们 自定义的参数,往往不会保存到模型权重文件中,或者成为模型可学习的参数。即我们通过 net.named_parameters() (模型可学习参数)或 net.state_dict().items()(保存模型权重值)方法都无法遍历输出。那如何解决呢,这就需要用到本文讲的register_parameterregister_buffer方法。

2. register_parameter

register_parameter() 是 torch.nn.Module 类中的一个方法。

2.1 主要作用

  • 用于定义可学习参数
  • 定义的参数可被保存到网络对象的参数中,可使用 net.parameters()net.named_parameters() 查看
  • 定义的参数可用 net.state_dict() 转换到字典中,进而 保存到网络文件 / 网络参数文件

2.2 函数说明

register_parameter(name,param)

参数:

  • name:参数名称

  • param:参数张量, 须是torch.nn.Parameter()对象 或 None ,否则报错如下
    TypeError: cannot assign 'torch.FloatTensor' object to parameter 'xx' (torch.nn.Parameter or None required)

2.3 举例说明

(1)自定义的参数未使用register_parameter

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)self.weight = torch.ones(10,10)self.bias = torch.zeros(10)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x * self.weight + self.biasreturn xnet = MyModule()print('\n', '*'*30+"net.named_parameters"+'*'*30, '\n')
for name, param in net.named_parameters():print(name, param.shape)print('\n', '*'*30+"net.state_dict"+'*'*30, '\n')
for key, val in net.state_dict().items():print(key, val.shape) 

输出:
在这里插入图片描述
在网络搭建的代码中,我们自定义了self.weightself.bias参数。我们思考下2个问题:1. 我们定义的self.weightself.bias参数是否会保存到网络的参数中,是否能在优化器的作用下进行学习。2. 这些参数是否能够保存到模型文件中,从而可以利用state_dict中遍历出来。通过上面的打印信息我们发现:

  • 使用net.named_parameters()迭代网络中可学习的参数,发现输出的参数只有conv1conv2的weight参数,并没有输出我们定义的self.weightself.bias
  • 接下来使用net.state_dict()方法迭代保存的参数,同样发现self.weightself.bias参数也没有被输出出来。

(2)通过register_parameter方法来定义参数

  • 接下来我们使用register_parameter来定义weight和bias参数,看看会有啥效果。代码修改如下:
self.register_parameter('weight',torch.nn.Parameter(torch.ones(10,10)))
self.register_parameter('bias',torch.nn.Parameter(torch.zeros(10)))

完整代码

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)self.register_parameter('weight',torch.nn.Parameter(torch.ones(10,10)))self.register_parameter('bias',torch.nn.Parameter(torch.zeros(10)))def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x * self.weight + self.biasreturn xnet = MyModule()print('\n', '*'*30+"net.named_parameters"+'*'*30, '\n')
for name, param in net.named_parameters():print(name, param.shape)print('\n', '*'*30+"net.state_dict"+'*'*30, '\n')
for key, val in net.state_dict().items():print(key, val.shape) 

在这里插入图片描述

  • 可以看到,使用了register_parameter定义的参数weight和bias,可以通过net.named_parameters或者net.parameters迭代出来的,这说明weight和bias已经存到了网络的参数中,他们是可学习的参数
  • 同时,通过state_dict()也能将参数和值给迭代出来,就说明如果要保存模型权重或网络参数时,这两个参数时可以被保存起来的。

3 register_buffer()

register_buffer()是 torch.nn.Module() 类中的一个方法

3.1 作用

  • 用于定义不可学习的参数
  • 定义的参数不会被保存到网络对象的参数中,使用 net.parameters() 或 net.named_parameters() 查看不到
  • 定义的参数可用 net.state_dict() 转换到字典中,进而 保存到网络文件 / 网络参数文件中

register_buffer() 用于在网络实例中 注册缓冲区,存储在缓冲区中的数据,类似于参数(但不是参数),它与参数的区别为:

  • 参数:可以被优化器更新 (requires_grad=False / True)

  • buffer 中的数据 (不可学习): 不会被优化器更新

3.2、举例说明

将定义的weight和bias,通过register_buffer来定义。

self.register_buffer('weight',torch.ones(10,10))
self.register_buffer('bias',torch.zeros(10))

运行完整代码看看效果:

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)self.register_buffer('weight',torch.ones(10,10))self.register_buffer('bias',torch.zeros(10))def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x * self.weight + self.biasreturn xnet = MyModule()zprint('\n', '*'*30+"net.named_parameters"+'*'*30, '\n')
for name, param in net.named_parameters():print(name, param.shape)print('\n', '*'*30+"net.state_dict"+'*'*30, '\n')
for key, val in net.state_dict().items():print(key, val.shape) 

在这里插入图片描述
我们可以看到:

  • 通过register_buffer定义的参数weight和bias,它是没有被named_parameter给迭代出来的,也就是说weight和bias不是网络的可学习参数,无法通过优化器来迭代更新,我们把它叫做buffer,而不是参数
  • 然而我们使用net.state_dict去迭代的话,weight和bias事可以被迭代出来的,这就说明使用register_buffer定义的数据,可以保持到模型或者权重文件中。

注意:

  • 在使用register_parameter定义参数时,必须定义为可学习的参数,因此需要通过torch.nn.Parameter去定义为一个可学习的参数
  • 而我们使用register_buffer定义参数时,是不需要通过torch.nn.Parameter去定义为可学习的参数的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/181087.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

操作系统复习(2)进程管理

一、概述 1.1程序的顺序执行 一个具有独立功能的程序独占CPU运行,直至得到最终结果的过程称为程序的顺序执行。 程序的并发执行所表现出的特性说明两个问题 ⑴ 程序和计算机执行程序的活动不再一一对应 ⑵ 并发程序间存在相互制约关系(要求共享信息&…

docker-compose 简单部署MySQL Database

docker-compose 简单部署MySQL Database 本博文部署MySQL 并与上篇部署的 Flask进行关联 主博客目录:《从零开始学习搭建量化平台笔记》 文章目录 docker-compose 简单部署MySQL Database部署 MySQLMySQL 开放端口与权限 主项目计划需要搭建一个MySQL 数据库为其他部…

python 深度学习 解决遇到的报错问题8

本篇继python 深度学习 解决遇到的报错问题7-CSDN博客 目录 一、OSError: [WinError 127] 找不到指定的程序。 Error loading "D:\my_ruanjian\conda-myenvs\deeplearning\lib\site-packages\torch\lib\caffe2_detectron_ops.dll" or one of its dependencies. 二、…

COCOS2DX3.17.2 Android升级targetSDK30问题解决方案

一、luajit不兼容问题 不兼容版本:【2.1.0-bate2、2.1.0-bate3都存在异常】 出问题系统:Android11;Android10的系统部分机型有问题,部分机型正常 异常点1:c调用lua接口,pushObjiect的时候crash 异常点2…

从Spring说起

一. Spring是什么 在前面的博文中,我们学会了SpringMVC的使用,可以完成一些基本功能的开发了,但是铁子们肯定有很多问题,下面来从Spring开始介绍,第一个问题,什么是Spring? Spring是包含了众多工具方法的IOC容器. Spring有两个核心思想--IOC和AOP,本章先来讲解IOC...... 1.1…

前端框架Vue学习 ——(一)快速入门

文章目录 Vue 介绍Vue快速入门 Vue 介绍 Vue 是一套前端框架,免除原生 JavaScript 中的 DOM 操作,简化书写。基于 MVVM (Model-View-ViewModel)思想,实现数据的双向绑定,将编程的关注点放在数据上。官网: https://v2.cn.vuejs.or…

区块链与教育:颠覆传统,引领未来

区块链与教育:颠覆传统,引领未来 摘要:本文将探讨区块链技术在教育领域的应用及其潜在影响。通过介绍区块链技术的基本原理、教育领域的现状,以及区块链技术在教育中的实际应用案例,我们将展望一个去中心化、安全可信…

想学计算机编程从什么学起?零基础如何自学计算机编程?中文编程开发语言工具箱之渐变标签组构件

想学计算机编程从什么学起?零基础如何自学计算机编程? 给大家分享一款中文编程工具,零基础轻松学编程,不需英语基础,编程工具可下载。 这款工具不但可以连接部分硬件,而且可以开发大型的软件,…

Prometheus+Grafana

一、Prometheus 获取配置文件 docker run -d -p 9090:9090 --name prometheus prom/prometheusmkdir -p /app/prometheusdocker cp prometheus:/etc/prometheus/prometheus.yml /app/prometheus/prometheus.yml停止并删除旧的容器,重新启动 docker run -d --name…

【漏洞复现】weblogic-10.3.6-‘wls-wsat‘-XMLDecoder反序列化(CVE-2017-10271)

感谢互联网提供分享知识与智慧&#xff0c;在法治的社会里&#xff0c;请遵守有关法律法规 文章目录 1.1、漏洞描述1.2、漏洞等级1.3、影响版本1.4、漏洞复现1、基础环境2、漏洞扫描nacsweblogicScanner3、漏洞验证 说明内容漏洞编号CVE-2017-10271漏洞名称Weblogic < 10.3.…

【驱动开发】注册字符设备使用gpio设备树节点控制led三盏灯的亮灭

注册字符设备使用gpio设备树节点控制led三盏灯的亮灭 设备树&#xff1a; 头文件&#xff1a; #ifndef __HEAD_H__ #define __HEAD_H__ typedef struct {unsigned int MODER;unsigned int OTYPER;unsigned int OSPEEDR;unsigned int PUPDR;unsigned int IDR;unsigned int OD…

Redis7.x 高级篇

Redis7.x 高级篇 Redis版本发行时间Redis单线程说的是什么东西 Redis版本发行时间 Redis单线程说的是什么东西

强化学习的动态规划二

一、典型示例 考虑如下所示的44网格。 图1 非终端状态为S {1, 2, . . . , 14}。在每个状态下有四种可能的行为&#xff0c;A {up, down, right, left}&#xff0c;这些行为除了会将代理从网格上移走外&#xff0c;其他都会确定性地引起相应的状态转换。因此&#xff0c;例如&…

MYSQL 8.0 配置CDC(binlog)

CDC&#xff08;Change Data Capture&#xff09;即数据变更抓取&#xff0c;通过源端数据源开启CDC&#xff0c;ROMA Connect 可实现数据源的实时数据同步以及物理表的物理删除同步。这里介绍通过开启Binlog模式CDC功能。 注意&#xff1a;1、使用MYSQL8.0及以上版本。 2、不…

ok-解决qt5发布版本,直接运行exe缺少各种库的问题

已实验第二种方法可用。 工具&#xff1a;电脑必备、QT下的windeployqt Qt 官方开发环境使用的动态链接库方式&#xff0c;在发布生成的exe程序时&#xff0c;需要复制一大堆 dll&#xff0c;如果自己去复制dll&#xff0c;很可能丢三落四&#xff0c;导致exe在别的电脑里无法…

抵押贷款巨头 Mr. Cooper 遭受网络攻击,影响 IT 系统

导语 近日&#xff0c;美国抵押贷款巨头 Mr. Cooper 遭受了一次网络攻击&#xff0c;导致该公司的 IT 系统受到影响。这一事件引起了广泛的关注&#xff0c;使得 Mr. Cooper 的在线支付平台无法正常运行。本文将为大家详细介绍这次网络攻击事件的具体情况及其对用户和公司造成的…

题解:轮转数组及复杂度分析

文章目录 &#x1f349;前言&#x1f349;题目&#x1f34c;解法一&#x1f34c;解法二&#xff1a;以空间换时间&#x1f95d;补充&#xff1a;memmove &#x1f34c;解法三&#xff08;选看&#xff09; &#x1f349;前言 本文侧重对于复杂度的分析&#xff0c;题解为辅。 …

在CARLA中获取CARLA自动生成的全局路径规划

CARLA生成全局路径规划的代码在 carla/PythonAPI/carla/agents/navigation 在自己的carla客户端py文件中 from agents.navigation.basic_agent import BasicAgent # pylint: disableimport-error 如果是pycharm开发&#xff0c;要在pycharm的Settings - Project Structure…

0003Java安卓程序设计-springboot基于Android的学习生活交流APP

文章目录 **摘** **要**目 录系统设计开发环境 编程技术交流、源码分享、模板分享、网课教程 &#x1f427;裙&#xff1a;776871563 摘 要 网络的广泛应用给生活带来了十分的便利。所以把学习生活交流管理与现在网络相结合&#xff0c;利用java技术建设学习生活交流APP&…

1825_ChibiOS的OSLIB中的存储分配器

全部学习汇总&#xff1a; GreyZhang/g_ChibiOS: I found a new RTOS called ChibiOS and it seems interesting! (github.com) 1. 之前有点不是很理解什么是静态OS&#xff0c;从这里基本上得到了这个答案。所谓的静态&#xff0c;其实就是内核之中不会使用Free以及Malloc这样…