7、深入剖析PyTorch nn.Module源码

文章目录

  • 1. 重要类
  • 2. add_modules
  • 3. Apply(fn)
  • 4. register_buffer
  • 5. nn.Parameters®ister_parameters
  • 6. 后续测试

1. 重要类

  • nn.module --> 所有神经网络的父类,自定义神经网络需要继承此类,并且自定义__init__,forward函数即可:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :MyModelNet.py
# @Time      :2024/11/20 13:38
# @Author    :Jason Zhang
import torch
from torch import nnclass NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork,self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsif __name__ == "__main__":run_code = 0x_row = 28x_column = 28x_total = x_row * x_columnx = torch.arange(x_total, dtype=torch.float).reshape((1, x_row, x_column))my_net = NeuralNetwork()y = my_net(x)print(f"y.shape={y.shape}")print(my_net)
  • 结果:
y.shape=torch.Size([1, 10])
NeuralNetwork((flatten): Flatten(start_dim=1, end_dim=-1)(linear_relu_stack): Sequential((0): Linear(in_features=784, out_features=512, bias=True)(1): ReLU()(2): Linear(in_features=512, out_features=512, bias=True)(3): ReLU()(4): Linear(in_features=512, out_features=10, bias=True))
)

2. add_modules

通过add_modules在旧的网络里面添加新的网络

  • 重点: 用nn.ModuleList自带的insert,新的网络继承自老网络中,直接用按位置插入
  • python
import torch
from torch import nn
from pytorch_model_summary import summarytorch.manual_seed(2323)class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.flatten = nn.Flatten()self.block = nn.ModuleList([nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 10)])def forward(self, x):x = self.flatten(x)for layer in self.block:x = layer(x)return xclass MyNewNet(MyModel):def __init__(self):super(MyNewNet, self).__init__()self.block.insert(2, nn.Linear(512, 256))  # 插入新层self.block.insert(3, nn.ReLU())  # 插入新的激活函数self.block.insert(4, nn.Linear(256, 512))  # 插入另一层self.block.insert(5, nn.ReLU())  # 插入激活函数if __name__ == "__main__":# 测试原始模型my_model = MyModel()print("Original Model:")print(summary(my_model, torch.ones((1, 28, 28))))# 测试新模型my_new_model = MyNewNet()print("\nNew Model:")print(summary(my_new_model, torch.ones((1, 28, 28))))
  • 结果:
Original Model:
-----------------------------------------------------------------------Layer (type)        Output Shape         Param #     Tr. Param #
=======================================================================Flatten-1            [1, 784]               0               0Linear-2            [1, 512]         401,920         401,920ReLU-3            [1, 512]               0               0Linear-4             [1, 10]           5,130           5,130
=======================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
-----------------------------------------------------------------------New Model:
-----------------------------------------------------------------------Layer (type)        Output Shape         Param #     Tr. Param #
=======================================================================Flatten-1            [1, 784]               0               0Linear-2            [1, 512]         401,920         401,920ReLU-3            [1, 512]               0               0Linear-4            [1, 256]         131,328         131,328ReLU-5            [1, 256]               0               0Linear-6            [1, 512]         131,584         131,584ReLU-7            [1, 512]               0               0Linear-8             [1, 10]           5,130           5,130
=======================================================================
Total params: 669,962
Trainable params: 669,962
Non-trainable params: 0
-----------------------------------------------------------------------

3. Apply(fn)

模型权重weight,bias 的初始化

  • python
import torch.nn as nn
import torchclass MyAwesomeModel(nn.Module):def __init__(self):super(MyAwesomeModel, self).__init__()self.fc1 = nn.Linear(3, 4)self.fc2 = nn.Linear(4, 5)self.fc3 = nn.Linear(5, 6)# 定义初始化函数
@torch.no_grad()
def init_weights(m):print(m)if type(m) == nn.Linear:m.weight.fill_(1.0)print(m.weight)# 创建神经网络实例
model = MyAwesomeModel()# 应用初始化权值函数到神经网络上
model.apply(init_weights)
  • 结果:
Linear(in_features=3, out_features=4, bias=True)
Parameter containing:
tensor([[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]], requires_grad=True)
Linear(in_features=4, out_features=5, bias=True)
Parameter containing:
tensor([[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.]], requires_grad=True)
Linear(in_features=5, out_features=6, bias=True)
Parameter containing:
tensor([[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.]], requires_grad=True)
MyAwesomeModel((fc1): Linear(in_features=3, out_features=4, bias=True)(fc2): Linear(in_features=4, out_features=5, bias=True)(fc3): Linear(in_features=5, out_features=6, bias=True)
)Process finished with exit code 0

4. register_buffer

将模型中添加常数项。比如加1

  • python:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :RegisterBuffer.py
# @Time      :2024/11/23 19:21
# @Author    :Jason Zhang
import torch
from torch import nnclass MyNet(nn.Module):def __init__(self):super(MyNet, self).__init__()self.register_buffer("my_buffer_a", torch.ones(2, 3))def forward(self, x):x = x + self.my_buffer_areturn xif __name__ == "__main__":run_code = 0my_test = MyNet()in_x = torch.arange(6).reshape((2, 3))y = my_test(in_x)print(f"x=\n{in_x}")print(f"y=\n{y}")
  • 结果:
x=
tensor([[0, 1, 2],[3, 4, 5]])
y=
tensor([[1., 2., 3.],[4., 5., 6.]])

5. nn.Parameters&register_parameters

  • python
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :ParameterTest.py
# @Time      :2024/11/23 19:37
# @Author    :Jason Zhang
import torch
from torch import nnclass MyModule(nn.Module):def __init__(self, in_size, out_size):self.in_size = in_sizeself.out_size = out_sizesuper(MyModule, self).__init__()self.test = torch.rand(self.in_size, self.out_size)self.linear = nn.Linear(self.in_size, self.out_size)def forward(self, x):x = self.linear(x)return xclass MyModuleRegister(nn.Module):def __init__(self, in_size, out_size):self.in_size = in_sizeself.out_size = out_sizesuper(MyModuleRegister, self).__init__()self.test = torch.rand(self.in_size, self.out_size)self.linear = nn.Linear(self.in_size, self.out_size)def forward(self, x):x = self.linear(x)return xclass MyModulePara(nn.Module):def __init__(self, in_size, out_size):self.in_size = in_sizeself.out_size = out_sizesuper(MyModulePara, self).__init__()self.test = nn.Parameter(torch.rand(self.in_size, self.out_size))self.linear = nn.Linear(self.in_size, self.out_size)def forward(self, x):x = self.linear(x)return xif __name__ == "__main__":run_code = 0test_in = 4test_out = 6my_test = MyModule(test_in, test_out)my_test_para = MyModulePara(test_in, test_out)test_list = list(my_test.named_parameters())test_list_para = list(my_test_para.named_parameters())my_test_register = MyModuleRegister(test_in, test_out)para_register = nn.Parameter(torch.rand(test_in, test_out))my_test_register.register_parameter('para_add_register', para_register)test_list_para_register = list(my_test_register.named_parameters())print(f"*" * 50)print(f"test_list=\n{test_list}")print(f"*" * 50)print(f"*" * 50)print(f"test_list_para=\n{test_list_para}")print(f"*" * 50)print(f"*" * 50)print(f"test_list_para_register=\n{test_list_para_register}")print(f"*" * 50)
  • 结果:
**************************************************
test_list=
[('linear.weight', Parameter containing:
tensor([[ 0.3805, -0.3368,  0.2348,  0.4525],[-0.4557, -0.3344,  0.1368, -0.3471],[-0.3961,  0.3302,  0.1904, -0.0111],[ 0.4542, -0.3325, -0.3782,  0.0376],[ 0.2083, -0.3113, -0.3447, -0.1503],[ 0.0343,  0.0410, -0.4216, -0.4793]], requires_grad=True)), ('linear.bias', Parameter containing:
tensor([-0.3465, -0.4510,  0.4919,  0.1967, -0.1366, -0.2496],requires_grad=True))]
**************************************************
**************************************************
test_list_para=
[('test', Parameter containing:
tensor([[0.1353, 0.9934, 0.0462, 0.2103, 0.3410, 0.0814],[0.7509, 0.2573, 0.8030, 0.0952, 0.1381, 0.5360],[0.1972, 0.1241, 0.5597, 0.2691, 0.3226, 0.0660],[0.3333, 0.8031, 0.9226, 0.4290, 0.3660, 0.6159]], requires_grad=True)), ('linear.weight', Parameter containing:
tensor([[-0.0633, -0.4030, -0.4962,  0.1928],[-0.1707,  0.2259,  0.0373, -0.0317],[ 0.4523,  0.2439, -0.1376, -0.3323],[ 0.3215,  0.1283,  0.0729,  0.3912],[ 0.0262, -0.1087,  0.4721, -0.1661],[-0.1055, -0.2199, -0.4974, -0.3444]], requires_grad=True)), ('linear.bias', Parameter containing:
tensor([ 0.3702, -0.0142, -0.2098, -0.0910, -0.2323, -0.0546],requires_grad=True))]
**************************************************
**************************************************
test_list_para_register=
[('para_add_register', Parameter containing:
tensor([[0.2428, 0.1388, 0.6612, 0.4215, 0.0215, 0.2618],[0.4234, 0.0160, 0.8947, 0.4784, 0.4403, 0.4800],[0.8845, 0.1469, 0.6894, 0.7050, 0.5911, 0.7702],[0.7694, 0.0491, 0.3583, 0.4451, 0.2282, 0.4293]], requires_grad=True)), ('linear.weight', Parameter containing:
tensor([[ 0.1358, -0.4704, -0.4181, -0.4504],[ 0.0903,  0.3235, -0.3164, -0.4163],[ 0.1342,  0.3108,  0.0612, -0.2910],[ 0.3527,  0.3397, -0.0414, -0.0408],[-0.4877,  0.1925, -0.2912, -0.2239],[-0.0081, -0.1730,  0.0921, -0.4210]], requires_grad=True)), ('linear.bias', Parameter containing:
tensor([-0.2194,  0.2233, -0.4950, -0.3260, -0.0206, -0.0197],requires_grad=True))]
**************************************************

6. 后续测试

  • register_module
  • get_submodule
  • get_parameter

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/477519.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微信小程序包之加农炮游戏

微信小程序 - 气球射击游戏 项目简介 这是一个简单有趣的微信小程序射击游戏。玩家通过控制屏幕底部的加农炮,射击从上方降落的蓝色气球。游戏考验玩家的反应能力和瞄准技巧。 游戏规则 点击屏幕任意位置发射炮弹大炮会自动对准点击位置击中气球获得10分如果气球触…

autogen+ollama+litellm实现本地部署多代理智能体

autogen 是一个专门为大语言模型 (LLMs) 驱动的自治代理 (autonomous agents) 设计的 Python 库,由 Microsoft 开发和维护。它通过高度模块化和可扩展的架构,支持用户快速构建和运行多代理系统,这些代理可以在没有明确人类干预的情况下协作完成复杂任务。AutoGen 支持以最少…

分公司如何纳税

分公司不进行纳税由总公司汇总纳税“子公司具有法人资格,依法独立承担民事责任;分公司不具有法人资格,其民事责任由公司承担。”企业设立分支机构,使其不具有法人资格,且不实行独立核算,则可由总公司汇总缴纳企业所得税…

HashMap源码详解

提醒你一下,用电脑或者平板看,手机屏幕小,图片会看不清楚,源码不方便看 基础前置 看该篇文章之前先看看这篇基础文章 HashMap底层原理-CSDN博客 先来看HashMap里面的一些参数。 1.DEFAULT_INITIAL_CAPACITY 默认的数组初…

09 —— Webpack搭建开发环境

搭建开发环境 —— 使用webpack-dev-server 启动Web服务,自动检测代码变化,有变化后会自动重新打包,热更新到网页(代码变化后,直接替换变化的代码,自动更新网页,不用手动刷新网页) …

动态反馈控制器(DFC)和 服务率控制器(SRC);服务率和到达率简单理解

目录 服务率和到达率简单理解 服务率 到达率 排队论中的应用 论文解析:队列等待成本动态感知控制模型 动态反馈和队列等待成本意识: 服务速率调整算法: 动态反馈控制器(DFC)和 服务率控制器(SRC) SRC公式4的原理 算力资源分配系统中的调整消耗 举例说明 服务…

九、FOC原理详解

1、FOC简介 FOC(field-oriented control)为磁场定向控制,又称为矢量控制(vectorcontrol),是目前无刷直流电机(BLDC)和永磁同步电机(PMSM)高效控制的最佳选择…

web-03

CSS回顾 选择器 标签选择器 标签{}ID选择器 标签中定义ID属性。 #ID值{}类选择器 标签中使用class属性 .类名{}关于DIV/span div任意的大小的长方形,大小css: width, height控制。—换行 span-- 一行内 CSS常用属性 width/height 宽度/高度 定义&…

Excel求和如何过滤错误值

一、问题的提出 平时,我们在使用Excel时,最常用的功能就是求和了,一说到求和你可能想到用sum函数,但是如果sum的求和区域有#value #Div等错误值怎么办?如下图,记算C列中工资的总和。 直接用肯定会报错&…

蓝桥杯每日真题 - 第22天

题目:(卡片) 题目描述(12届 C&C B组B题) 解题思路: 该问题要求用数字卡片从 1 开始拼出整数,直到某一时刻不能拼出时停止。要确定拼到哪个最大整数,需要统计 每个数字“1”被用…

Ubuntu20.04 Rk3588 交叉编译ffmpeg7.0

firefly 公司出的rk3588的设备,其中已经安装了gcc 交叉编译工具,系统版本是Ubuntu20.04。 使用Ubuntu20.04 交叉编译ffmpeg_ubuntu下配置ffmpeg交叉编译器为arm-linux-gnueabihf-gcc-CSDN博客文章浏览阅读541次。ubuntu20.04 交叉编译ffmpeg_ubuntu下配…

python代码制作数据集的测试和数据质量检测思路

前言 本文指的数据集为通用数据集,并不单是给机器学习领域使用。包含科研和工业领域需要自己制作数据集的。 首先,在制作大型数据集时,代码错误和数据问题可能会非常复杂。 前期逻辑总是简单的,库库一顿写,等排查的时…

Windows系统编程 - 进程遍历

文章目录 前言进程的遍历CreateToolhelp32SnapshotProcess32FirstProcess32Next进程遍历 总结 前言 各位师傅好,我是qmx_07,今天给大家讲解进程遍历的相关知识点 进程的遍历 快照:使用vmware虚拟机的时候,经常需要配置环境服务…

【GD32】(三) ISP基本使用

0 前言 有一块GD32的板子不知道为啥用着用着就下载不了程序了,没办法,只能另寻他法。作为STM32的平替,GD32的功能和STM32基本是一致的,所以也可以使用ISP来下载程序。于是就开始复活这块板子。 1 BOOT模式 对于熟悉STM32开发的人…

【SKFramework框架核心模块】3-2、音频管理模块

推荐阅读 CSDN主页GitHub开源地址Unity3D插件分享QQ群:398291828小红书小破站 大家好,我是佛系工程师☆恬静的小魔龙☆,不定时更新Unity开发技巧,觉得有用记得一键三连哦。 一、前言 【Unity3D框架】SKFramework框架完全教程《全…

开源项目-如何更好的参与开源项目开发

开源之谜-提升自我核心竞争力 一、寻找适合自己的开源项目二、像坐牢一样闭关修炼三、最后的实践 开源代码对所有人开放,开发者可以基于现有代码进行扩展和创新,而不是从零开始,参与开源项目可以提升自我的技术能力,丰富个人的经历…

利用c语言详细介绍下插入排序

插入排序,被称为直接插入排序。它的基本思想是将一个记录插入到已经排好序的有序表中,从而一个新的、记录数增 1 的有序表。 一、图文介绍 我们还是使用数组【10,5,3,20,1],排序使用升序的方式&…

STL——string类常用接口说明

目录 一、string类的介绍 二、string类常用接口的使用说明 1.成员函数 ​编辑 2.迭代器 3.容量 一、string类的介绍 下面是string类的文档对string类的介绍 1.string类是表示字符序列的对象 2.标准字符串类通过类似于标准字符容器的接口为此类对象提供支持&#xff0c…

Excel的图表使用和导出准备

目的 导出Excel图表是很多软件要求的功能之一,那如何导出Excel图表呢?或者说如何使用Excel图表。 一种方法是软件生成图片,然后把图片写到Excel上,这种方式,因为格式种种原因,导出的图片不漂亮&#xff0c…

2024年亚太地区数学建模大赛A题-复杂场景下水下图像增强技术的研究

复杂场景下水下图像增强技术的研究 对于海洋勘探来说,清晰、高质量的水下图像是深海地形测量和海底资源调查的关键。然而,在复杂的水下环境中,由于光在水中传播过程中的吸收、散射等现象,导致图像质量下降,导致模糊、…