动手学深度学习(三)深度学习计算

一、模型构造

1、继承Module类来构造模型来构造模型

class MLP(nn.Module):# 声明带有模型参数的层,这里声明了两个全连接层def __init__(self, **kwargs):# 调用MLP父类Block的构造函数来进行必要的初始化。这样在构造实例时还可以指定其他函数# 参数,如“模型参数的访问、初始化和共享”一节将介绍的模型参数paramssuper(MLP, self).__init__(**kwargs)self.hidden = nn.Linear(784, 256) # 隐藏层self.act = nn.ReLU()self.output = nn.Linear(256, 10)  # 输出层# 定义模型的前向计算,即如何根据输入x计算返回所需要的模型输出def forward(self, x):a = self.act(self.hidden(x))return self.output(a)

2、Sequential类继承自Block

Sequential类它提供add函数来逐一添加串联的Module子类实例,而模型的前向计算就是将这些实例按添加的顺序逐一计算

net = MySequential(nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 10), )
print(net)
net(X)

3、ModuleList

①定义

ModuleList 是 PyTorch 中的一种容器类,位于 torch.nn 模块下,专门用于存储多个子模块(即网络层)

net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net.append(nn.Linear(256, 10)) # # 类似List的append操作
print(net[-1])  # 类似List的索引访问
print(net)

②ModuleList 和 Python 普通列表的区别

  • 注册模块ModuleList 中的所有子模块都会被注册为模型的一部分。PyTorch 会自动识别并将它们的参数纳入模型的训练和保存中。而普通的 Python 列表并不会注册其中的模块。
  • 参数追踪:使用 ModuleList 后,model.parameters() 可以追踪到列表中的所有模块参数。如果使用普通列表,模型中的这些层的参数将不会被自动管理。
(1)ModuleList
class Module_ModuleList(nn.Module):def __init__(self):super(Module_ModuleList, self).__init__()self.linears = nn.ModuleList([nn.Linear(10, 10)])
(2)Python列表
class Module_List(nn.Module):def __init__(self):super(Module_List, self).__init__()self.linears = [nn.Linear(10, 10)]

由结果可以看出,使用了nn.ModuleList([nn.Linear(10, 10)]),自动注册了模块并进行参数追踪,而使用列表 [nn.Linear(10, 10)]定义的参数将不会被自动管理。

 

4、ModuleDict类

ModuleDictPyTorch 中 torch.nn 模块下的一个容器类专门用于存储多个子模块,并以字典的形式组织这些子模块。与 Python 的普通字典不同,ModuleDict 中的子模块会被自动注册为模型的一部分,这使得 PyTorch 可以自动追踪、保存和加载这些模块及其参数

net = nn.ModuleDict({'linear': nn.Linear(784, 256),'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10) # 添加
print(net['linear']) # 访问

二、模型参数的访问初始化和共享

init模块,它包含了多种模型初始化方法。

1、访问模型参数

net.named_parameters()

net.named_parameters() : PyTorch 中的一个方法,用于返回模型中所有可训练参数的名称和参数本身(权重和偏置)

print(type(net.named_parameters()))
for name, param in net.named_parameters():print(name, param.size())

② nn.Parameter

nn.Parameter:用于定义可以被优化(即可以通过梯度下降等算法进行训练)的参数。当你创建一个 nn.Parameter 对象时,它会自动注册到模型的参数列表中,这意味着它将被包含在模型的参数优化过程中。

class MyModel(nn.Module):def __init__(self, **kwargs):super(MyModel, self).__init__(**kwargs)self.weight1 = nn.Parameter(torch.rand(20, 20))self.weight2 = torch.rand(20, 20)def forward(self, x):pass

初始化权重的梯度是None,训练过程中回代才改变。

③参数的数值和梯度访问

param.data和param.grad访问和修改相关属性。

for name, param in net.named_parameters():if 'weight' in name:init.normal_(param, mean=0, std=0.01)print(name, param.data)print(name, param.grad)

2、初始化模型参数

①使用init中的方法初始化

下面代码分别是正态分布初始化和常数初始化。

init.normal_(param, mean=0, std=0.01)
init.constant_(param, val=0)

②自定义初始化

参数初始化时使用with torch.no_grad()来暂时禁用梯度计算,这对于初始化权重是有用的,因为我们不希望在初始化时计算梯度。

def init_weight_(tensor):with torch.no_grad():tensor.uniform_(-10, 10)tensor *= (tensor.abs() >= 5).float()

3、共享模型参数

当不同层指向的是同一个实例时,它们共享同样的权重。如果你初始化或更新其中一个层的参数,实际上这几个层都会受到映像。

linear = nn.Linear(1, 1, bias=False)
net = nn.Sequential(linear, linear) 
print(net)
for name, param in net.named_parameters():init.constant_(param, val=3)print(name, param.data)

三、自定义层

1、不含模型参数的自定义层

class CenteredLayer(nn.Module):def __init__(self, **kwargs):super(CenteredLayer, self).__init__(**kwargs)def forward(self, x):return x - x.mean()
layer = CenteredLayer()
layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float))

2、含模型参数的自定义层

class MyListDense(nn.Module):def __init__(self):super(MyListDense, self).__init__()self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])self.params.append(nn.Parameter(torch.randn(4, 1)))def forward(self, x):for i in range(len(self.params)):x = torch.mm(x, self.params[i])return x
net = MyListDense()
print(net)

四、读取和存储

1、读写Tensor

torch.save():将张量存到指定文件中。

torch.load():载入指定文件中的张量。

y = torch.zeros(4)
torch.save([x, y], 'xy.pt')
xy_list = torch.load('xy.pt')
xy_list

2、读写模型

state_dict()方法:

  • 保存模型的参数:通过 state_dict(),你可以将模型的参数提取出来并保存为一个字典,以便稍后加载或分享。
  • 加载模型的参数:可以通过 load_state_dict() 方法将保存的参数字典加载到模型中。
  • 检查模型的当前参数状态state_dict() 方便调试时检查模型的权重和偏置。
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.hidden = nn.Linear(3, 2)self.act = nn.ReLU()self.output = nn.Linear(2, 1)def forward(self, x):a = self.act(self.hidden(x))return self.output(a)net = MLP()
net.state_dict()

PATH = "./net.pt"
torch.save(net.state_dict(), PATH)net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)
Y2 == Y

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/422331.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[数据集][目标检测]汽车头部尾部检测数据集VOC+YOLO格式5319张3类别

数据集制作单位:未来自主研究中心(FIRC) 版权单位:未来自主研究中心(FIRC) 版权声明:数据集仅仅供个人使用,不得在未授权情况下挂淘宝、咸鱼等交易网站公开售卖,由此引发的法律责任需自行承担 数据集格式:Pascal VOC格…

需求分析概述

为什么要进行需求分析呢? 笑话:富翁娶妻 某富翁想要娶老婆,有三个人选,富翁给了三个女孩各一千元,请 她们把房间装满。第一个女孩买了很多棉花,装满房间的1/2。第 二个女孩买了很多气球,装满…

Java多线程(一)

目录 Java多线程(一) 线程与进程基本介绍 并发和并行基本介绍 CPU调度基本介绍 主线程基本介绍 创建线程对象与相关方法 继承Thread类创建线程对象 多线程在内存中运行的原理 Thread类中常用的方法 Thread类中关于线程优先级的方法 守护线程与Thread类中…

Kafka【十三】消费者消费消息的偏移量

偏移量offset是消费者消费数据的一个非常重要的属性。默认情况下,消费者如果不指定消费主题数据的偏移量,那么消费者启动消费时,无论当前主题之前存储了多少历史数据,消费者只能从连接成功后当前主题最新的数据偏移位置读取&#…

信息安全数学基础(8)整数分解

前言 在信息安全数学基础中,整数分解是一个核心概念,它指的是将一个正整数表示为几个正整数的乘积的形式。虽然对于任何正整数,理论上都可以进行分解(除了1只能分解为1本身),但整数分解在密码学和信息安全中…

实战千问2大模型第三天——Qwen2-VL-7B(多模态)视频检测和批处理代码测试

画面描述:这个视频中,一位穿着蓝色西装的女性站在室内,背景中可以看到一些装饰品和植物。她双手交叉放在身前,面带微笑,似乎在进行一场演讲或主持活动。她的服装整洁,显得非常专业和自信。 一、简介 阿里通义千问开源新一代视觉语言模型Qwen2-VL。其中,Qwen2-VL-72B在大…

使用虚拟信用卡WildCard轻松订阅POE:全面解析平台功能与订阅方式

POE(Platform of Engagement)是一个由Quora推出的人工智能聊天平台,汇集了多个强大的AI聊天机器人,如GPT-4、Claude、Sage等。POE提供了一个简洁、统一的界面,让用户能够便捷地与不同的AI聊天模型进行互动。这种平台的…

先攒一波硬件,过几年再给电脑升级,靠谱吗?想啥呢?

前言 最近有小伙伴发来消息:我可以今年先买电脑的部分硬件,明年再买电脑的另一部分硬件,再组装起来不就是一台电脑了吗? 这确实是一个很好的办法。 我还记得大学有个室友,从大一每个月省吃俭用,攒下的钱…

Linux学习笔记(黑马程序员,前四章节)

第一章 快照 虚拟机快照: 通俗来说,在学习阶段我们无法避免的可能损坏Linux操作系统,如果损坏的话,重新安装一个Linux操作系统就会十分麻烦。VMware虚拟机支持为虚拟机制作快照。通过快照将当前虚拟机的状态保存下来,…

力扣100题——贪心算法

概述 贪心算法(Greedy Algorithm)是一种在解决问题时,按照某种标准在每一步都选择当前最优解(局部最优解)的算法。它期望通过一系列局部最优解的选择,最终能够得到全局最优解。 贪心算法的核心思想 贪心算…

Springboot中自定义监听器

一、监听器模式图 二、监听器三要素 广播器:用来发布事件 事件:需要被传播的消息 监听器:一个对象对一个事件的发生做出反应,这个对象就是事件监听器 三、监听器的实现方式 1、实现自定义事件 自定义事件需要继承ApplicationEv…

HashMap常用方法及底层原理

目录 一、什么是HashMap二、HashMap的链表与红黑树1、数据结构2、链表转为红黑树3、红黑树退化为链表 三、存储(put)操作四、读取(get)操作五、扩容(resize)操作六、HashMap的线程安全与顺序1、线程安全2、…

整型数组按个位值排序

题目描述 给定一个非空数组(列表),其元素数据类型为整型,请按照数组元素十进制最低位从小到大进行排序,十进制最低位相同的元司 相对位置保持不变。 当数组元素为负值时,十进制最低位等同于去除符号位后对应十进制值最低位。 输…

Facebook的虚拟现实计划:未来社交的全新视角

随着科技的不断进步,虚拟现实(VR)正逐步成为我们日常生活的一部分。作为全球领先的社交平台,Facebook正在大力投入虚拟现实技术,以重新定义社交互动的方式。本文将深入探讨Facebook的虚拟现实计划,分析其如…

Mycat2原理介绍

Mycat介绍 Mycat原理 Mycat 核心配置 Scheam.xml 逻辑数据库和节点对应关系配置Server.xml mycat的连接配置Rule.xml. 分片规则 自动分片auto-sharding-long,比如0-10000节点1 ,10001-20000节点2枚举分片sahrding-bt-intfile ,比如beijing节点1…

[数据集][目标检测]血细胞检测数据集VOC+YOLO格式2757张4类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):2757 标注数量(xml文件个数):2757 标注数量(txt文件个数):2757 标注…

【数据库】MySQL-基础篇-SQL

专栏文章索引:数据库 有问题可私聊:QQ:3375119339 目录 一、SQL通用语法 二、SQL分类 三、DDL 1.数据库操作 1.1 查询所有数据库 1.2 查询当前数据库 1.3 创建数据库 1)案例: 1.4 删除数据库 1.5 切换数据库…

discuz论坛3.4 截图粘贴图片发帖后显示不正常问题

处理方法 source\function 路径下修改function_discuzcode.php function bbcodeurl($url, $tags) 函数 if(!in_array(strtolower(substr($url, 0, 6)), array(http:/, https:, ftp://, rtsp:/, mms://,data:i) 这一句里增加 data:i 即可 function bbcodeurl($url,…

JAVA基础:抽象类,接口,instanceof,类关系,克隆

1 JDK中的包 JDK JRE 开发工具集(javac.exe) JRE JVM java类库 JVM java 虚拟机 jdk中自带了许多的包(类) , 常用的有 java.lang 该包中的类,不需要引用,可以直接使用。 例如&#xff1…

Redis面试题整理

Redis 1、Redis主从集群 1.1、搭建主从集群 单节点Redis的并发能力是有上限的,要进一步提高Redis的并发能力,就需要搭建主从集群,实现读写分离 1.2、主从同步原理 当主从第一次同步连接或断开重连时,从节点都会发送psync请求&…