8、深入剖析PyTorch的state_dict、parameters、modules源码

文章目录

  • 1. 重要类
  • 2. 保存模型
  • 3. 代码测试

1. 重要类

  • container.py
  • nn.sequential
  • nn.modulelist
  • save_state_dict

2. 保存模型

pytorch官网教程

3. 代码测试

比较急,后续完善

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :ToTest01.py
# @Time      :2024/11/24 10:37
# @Author    :Jason Zhang
import torch
from torch import nn
from torch.nn import Moduleclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.linear1 = nn.Linear(2, 3)self.linear2 = nn.Linear(3, 4)self.batch_norm4 = nn.BatchNorm2d(4)def forward(self, x):x = self.linear1(x)x = self.linear2(x)return xif __name__ == "__main__":run_code = 0input_x = torch.randn((1, 2))test_model = MyModel()y = test_model(input_x)model_modules = test_model._modulesprint(f"*"*50)print(f"model_modules=\n{model_modules}")print(f"*"*50)linear1 = model_modules['linear1']print(f"*"*50)print(f"linear1={linear1}")print(f"*"*50)print(f"linear1.weight=\n{linear1.weight}")print(f"*"*50)print(f"linear1.weight.dtype={linear1.weight.dtype}")print(f"*"*50)test_model.to(torch.double)print(f"linear1.weight.dtype={linear1.weight.dtype}")print(f"*"*50)test_model.to(torch.float32)print(f"linear1.weight.dtype={linear1.weight.dtype}")print(f"*"*50)model_parameters = test_model._parametersprint(f"model_parameters={model_parameters}")print(f"*"*50)model_buffers = test_model._buffersprint(f"model_buffer={model_buffers}")print(f"*"*50)model_state_dict = test_model.state_dict()print(f"model_state_dict=\n{model_state_dict}")print(f"*"*50)model_state_dict_linear2 = test_model.state_dict()['linear2.weight']print(f"model_state_dict_linear2=\n{model_state_dict_linear2}")print(f"*"*50)model_named_para =list(test_model.named_parameters())print(f"model_named_para=\n{model_named_para}")print(f"*"*50)model_named_modules =list(test_model.named_modules())print(f"model_named_modules=\n{model_named_modules}")print(f"*"*50)model_named_buffers =list(test_model.named_buffers())print(f"model_named_buffers=\n{model_named_buffers}")print(f"*"*50)model_named_children =list(test_model.named_children())print(f"model_named_children=\n{model_named_children}")
  • 结果:
**************************************************
model_modules=
OrderedDict([('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm4', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))])
**************************************************
**************************************************
linear1=Linear(in_features=2, out_features=3, bias=True)
**************************************************
linear1.weight=
Parameter containing:
tensor([[-0.5518,  0.0687],[-0.7013,  0.4869],[-0.1157, -0.1287]], requires_grad=True)
**************************************************
linear1.weight.dtype=torch.float32
**************************************************
linear1.weight.dtype=torch.float64
**************************************************
linear1.weight.dtype=torch.float32
**************************************************
model_parameters=OrderedDict()
**************************************************
model_buffer=OrderedDict()
**************************************************
model_state_dict=
OrderedDict([('linear1.weight', tensor([[-0.5518,  0.0687],[-0.7013,  0.4869],[-0.1157, -0.1287]])), ('linear1.bias', tensor([-0.2915, -0.4807,  0.0071])), ('linear2.weight', tensor([[ 0.4185,  0.1556,  0.1371],[ 0.4751,  0.2029, -0.0679],[ 0.1264, -0.0288, -0.3661],[ 0.4423, -0.5370,  0.3930]])), ('linear2.bias', tensor([ 0.2746, -0.1798,  0.0218,  0.5465])), ('batch_norm4.weight', tensor([1., 1., 1., 1.])), ('batch_norm4.bias', tensor([0., 0., 0., 0.])), ('batch_norm4.running_mean', tensor([0., 0., 0., 0.])), ('batch_norm4.running_var', tensor([1., 1., 1., 1.])), ('batch_norm4.num_batches_tracked', tensor(0))])
**************************************************
model_state_dict_linear2=
tensor([[ 0.4185,  0.1556,  0.1371],[ 0.4751,  0.2029, -0.0679],[ 0.1264, -0.0288, -0.3661],[ 0.4423, -0.5370,  0.3930]])
**************************************************
model_named_para=
[('linear1.weight', Parameter containing:
tensor([[-0.5518,  0.0687],[-0.7013,  0.4869],[-0.1157, -0.1287]], requires_grad=True)), ('linear1.bias', Parameter containing:
tensor([-0.2915, -0.4807,  0.0071], requires_grad=True)), ('linear2.weight', Parameter containing:
tensor([[ 0.4185,  0.1556,  0.1371],[ 0.4751,  0.2029, -0.0679],[ 0.1264, -0.0288, -0.3661],[ 0.4423, -0.5370,  0.3930]], requires_grad=True)), ('linear2.bias', Parameter containing:
tensor([ 0.2746, -0.1798,  0.0218,  0.5465], requires_grad=True)), ('batch_norm4.weight', Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)), ('batch_norm4.bias', Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True))]
**************************************************
model_named_modules=
[('', MyModel((linear1): Linear(in_features=2, out_features=3, bias=True)(linear2): Linear(in_features=3, out_features=4, bias=True)(batch_norm4): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)), ('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm4', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))]
**************************************************
model_named_buffers=
[('batch_norm4.running_mean', tensor([0., 0., 0., 0.])), ('batch_norm4.running_var', tensor([1., 1., 1., 1.])), ('batch_norm4.num_batches_tracked', tensor(0))]
**************************************************
model_named_children=
[('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm4', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/478062.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【计算机网络】网段划分

一、为什么有网段划分 IP地址 网络号(目标网络) 主机号(目标主机) 网络号: 保证相互连接的两个网段具有不同的标识 主机号: 同一网段内,主机之间具有相同的网络号,但是必须有不同的主机号 互联网中的每一台主机,都要隶属于某一个子网 -&…

机器学习周志华学习笔记-第5章<神经网络>

机器学习周志华学习笔记-第5章<神经网络> 卷王&#xff0c;请看目录 5模型的评估与选择5.1 神经元模型5.2 感知机与多层网络5.3 BP(误逆差)神经网络算法 5.4常见的神经网络5.4.1 RBF网络&#xff08;Radial Basis Function Network&#xff0c;径向基函数网络&#xff0…

MySQL数据库设计

数据库设计 数据库是用来存在数据的&#xff0c;需要设计合理的数据表来存放数据–能够完成数据的存储&#xff0c;同时能够方便的提取应该系统所需的数据 1. 数据库的设计流程 数据库是为应用系统服务的&#xff0c;数据库的数据存储也是由应用系统决定的 当我们进行应用系统开…

Spring Boot 3.x + OAuth 2.0:构建认证授权服务与资源服务器

Spring Boot 3.x OAuth 2.0&#xff1a;构建认证授权服务与资源服务器 前言 随着Spring Boot 3的发布&#xff0c;我们迎来了许多新特性和改进&#xff0c;其中包括对Spring Security和OAuth 2.0的更好支持。本文将详细介绍如何在Spring Boot 3.x版本中集成OAuth 2.0&#xf…

数据可视化复习2-绘制折线图+条形图(叠加条形图,并列条形图,水平条形图)+ 饼状图 + 直方图

目录 目录 一、绘制折线图 1.使用pyplot 2.使用numpy ​编辑 3.使用DataFrame ​编辑 二、绘制条形图&#xff08;柱状图&#xff09; 1.简单条形图 2.绘制叠加条形图 3.绘制并列条形图 4.水平条形图 ​编辑 三、绘制饼状图 四、绘制散点图和直方图 1.散点图 2…

logback 初探学习

logback 三大模块 记录器&#xff08;Logger&#xff09;、追加器&#xff08;Appender&#xff09;和布局&#xff08;Layout&#xff09; 配置文件外层最基本的标签如图示 xml中定义的就是这个三个东西下面进入学习 包引入参考springboot 官方文档 Logging :: Spring Boo…

Linux:自定义Shell

本文旨在通过自己完成一个简单的Shell来帮助理解命令行Shell这个程序。 目录 一、输出“提示” 二、获取输入 三、切割字符串 四、执行指令 1.子进程替换 2.内建指令 一、输出“提示” 这个项目基于虚拟机Ubuntu22.04.5实现。 打开终端界面如图所示。 其中。 之前&#x…

《图像梯度与常见算子全解析:原理、用法及效果展示》

简介:本文深入探讨图像梯度相关知识&#xff0c;详细介绍图像梯度是像素灰度值在不同方向的变化速度&#xff0c;并以 “pig.JPG” 图像为例&#xff0c;通过代码展示如何选取图像部分区域并分析其像素值以论证图像梯度与边缘信息的关联。接着全面阐述了 Sobel 算子&#xff0c…

项目进度计划表:详细的甘特图的制作步骤

甘特图&#xff08;Gantt chart&#xff09;&#xff0c;又称为横道图、条状图&#xff08;Bar chart&#xff09;&#xff0c;是一种用于管理时间和任务活动的工具。 甘特图由亨利劳伦斯甘特&#xff08;Henry Laurence Gantt&#xff09;发明&#xff0c;是一种通过条状图来…

A045-基于spring boot的个人博客系统的设计与实现

&#x1f64a;作者简介&#xff1a;在校研究生&#xff0c;拥有计算机专业的研究生开发团队&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的网站项目。 代码可以查看文章末尾⬇️联系方式获取&#xff0c;记得注明来意哦~&#x1f339; 赠送计算机毕业设计600…

QT基础 编码问题 定时器 事件 绘图事件 keyPressEvent QT5.12.3环境 C++实现

一、编码问题 在计算机编程中&#xff0c;流&#xff08;Stream&#xff09;是一种抽象的概念&#xff0c;用于表示数据的输入或输出。根据处理数据的不同方式&#xff0c;流可以分为字节流&#xff08;Byte Stream&#xff09;和字符流&#xff08;Character Stream&#xff0…

Python爬虫项目 | 二、每日天气预报

文章目录 1.文章概要1.1 实现方法1.2 实现代码1.3 最终效果1.3.1 编辑器内打印显示效果实际应用效果 2.具体讲解2.1 使用的Python库2.2 代码说明2.2.1 获取天气预报信息2.2.2 获取当天日期信息&#xff0c;格式化输出2.2.3 调用函数&#xff0c;输出结果 2.3 过程展示 3 总结 1…

百度在下一盘大棋

这两天世界互联网大会在乌镇又召开了。 我看到一条新闻&#xff0c;今年世界互联网大会乌镇峰会发布“2024 年度中国互联网企业创新发展十大典型案例”&#xff0c;百度文心智能体平台入选。 这个智能体平台我最近也有所关注&#xff0c;接下来我就来讲讲它。 百度在下一盘大棋…

UG NX二次开发(C++)-UIStyler-指定平面的对象和参数获取

文章目录 1、前言2、在UG NX中创建平面和一个长方体,3、在UI Styler中创建一个UI界面4、在VS中创建一个工程4.1 创建并添加工程文件4.2 在Update_cb方法中添加选择平面的代码4.3 编译完成并测试效果1、前言 在采用NXOpen C++进行二次开发时,采用Menu/UIStyler是一种很常见的…

【软考】数据库

1. 数据模型 1.1 概念数据模型 概念数据模型一般用 E-R 图表示&#xff0c;常用术语如下&#xff1a; 实体&#xff1a;客观存在的事物&#xff0c;如&#xff1a;一个单位、一个职工、一个部门、一个项目。属性&#xff1a;学生实体有学号、姓名、出生日期等属性。码&#…

【强化学习的数学原理】第04课-值迭代与策略迭代-笔记

学习资料&#xff1a;bilibili 西湖大学赵世钰老师的【强化学习的数学原理】课程。链接&#xff1a;强化学习的数学原理 西湖大学 赵世钰 文章目录 一、值迭代算法二、策略迭代算法三、截断策略迭代算法四、本节课内容summary 一、值迭代算法 值迭代算法主要包括两部分。 第一…

jupyter notebook的 markdown相关技巧

目录 1 先选择为markdown类型 2 开关技巧 2.1 运行markdown 2.2 退出markdown显示效果 2.3 注意点&#xff1a;一定要 先选择为markdown类型 3 一些设置技巧 3.1 数学公式 3.2 制表 3.3 目录和列表 3.4 设置各种字体效果&#xff1a;加粗&#xff0c;斜体&#x…

Spring Boot3远程调用工具RestClient

Spring Boot3.2之后web模块提供了一个新的远程调用工具RestClient&#xff0c;它的使用比RestTemplate方便&#xff0c;开箱即用&#xff0c;不需要单独注入到容器之中&#xff0c;友好的rest风格调用。下面简单的介绍一下该工具的使用。 一、写几个rest风格测试接口 RestCont…

vscode可以编译通过c++项目,但头文件有红色波浪线的问题

1、打开 VSCode 的设置&#xff0c;可以通过快捷键 Ctrl Shift P 打开命令面板&#xff0c;然后搜索并选择 “C/C: Edit Configurations (JSON)” 命令&#xff0c;这将在 .vscode 文件夹中创建或修改 c_cpp_properties.json 文件 {"configurations": [{"name…

近源渗透|HID ATTACK从0到1

前言 对于“近源渗透”这一术语&#xff0c;相信大家已经不再感到陌生。它涉及通过伪装、社会工程学等手段&#xff0c;实地侵入企业办公区域&#xff0c;利用内部潜在的攻击面——例如Wi-Fi网络、RFID门禁、暴露的有线网口、USB接口等——获取关键信息&#xff0c;并以隐蔽的…