Pytorch学习:神经网络模块torch.nn.Module和torch.nn.Sequential

文章目录

    • 1. torch.nn.Module
      • 1.1 add_module(name,module)
      • 1.2 apply(fn)
      • 1.3 cpu()
      • 1.4 cuda(device=None)
      • 1.5 train()
      • 1.6 eval()
      • 1.7 state_dict()
    • 2. torch.nn.Sequential
      • 2.1 append
    • 3. torch.nn.functional.conv2d

1. torch.nn.Module

官方文档:torch.nn.Module
CLASS torch.nn.Module(*args, **kwargs)

  • 所有神经网络模块的基类。
  • 您的模型也应该对此类进行子类化。
  • 模块还可以包含其他模块,允许将它们嵌套在树结构中。您可以将子模块分配为常规属性:
  • training(bool)-布尔值表示此模块是处于训练模式还是评估模式。

定义一个模型

import torch.nn as nn
import torch.nn.functional as Fclass Model(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 20, 5)def forward(self, x):x = F.relu(self.conv1(x))return F.relu(self.conv2(x))
  • 以这种方式分配的子模块将被注册,并且当您调用 to() 等时也将转换其参数。
    • to(device=None,dtype=None,non_blocking=False)
      device ( torch.device) – 该模块中参数和缓冲区所需的设备
    • to(dtype ,non_blocking=False)
      dtype ( torch.dtype) – 该模块中参数和缓冲区所需的浮点或复杂数据类型
    • to(tensor,non_blocking=False)
      张量( torch.Tensor ) – 张量,其数据类型和设备是该模块中所有参数和缓冲区所需的数据类型和设备

引用上面定义的模型,将模型转移到GPU上

# 创建模型
model = Model()# 定义设备 gpu1
gpu1 = torch.device("cuda:1")
model = model.to(gpu1)

1.1 add_module(name,module)

将子模块添加到当前模块。
可以使用给定的名称作为属性访问模块。

add_module(name,module)
主要参数:

  • name(str)-子模块的名称。可以使用给定的名称从此模块访问子模块。
  • module(Module)-要添加到模块的子模块。

在这里插入图片描述
添加一个卷积层

model.add_module("conv3", nn.Conv2d(20, 20, 5))

在这里插入图片描述

1.2 apply(fn)

将 fn 递归地应用于每个子模块(由 .children() 返回)以及self。
典型的用法包括初始化模型的参数(另请参见torch.nn.init)。

apply(fn)
主要参数:

  • fn( Module -> None)-应用于每个子模块的函数

将所有线性层的权重置为1

import torch
from torch import nn@torch.no_grad()
def init_weights(m):print(m)if type(m) == nn.Linear:m.weight.fill_(1.0)print(m.weight)net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2,2))
net.apply(init_weights)

在这里插入图片描述

1.3 cpu()

将所有模型参数和缓冲区移动到CPU。

device = torch.device("cpu")
model = model.to(device)

1.4 cuda(device=None)

将所有模型参数和缓冲区移动到GPU。

这也使关联的参数和缓冲区成为不同的对象。因此,如果模块在优化时将驻留在GPU上,则应在构造优化器之前调用该函数。

cuda(device=None)
主要参数:

  • device(int,可选)-如果指定,所有参数将被复制到该设备

转移到GPU包括以下参数:

  1. 模型
  2. 损失函数
  3. 输入输出
# 创建模型
model = Model()# 将模型转移到GPU上
model = model.cuda()# 将损失函数转移到GPU上
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.cuda()# 将输入输出转移到GPU上
imgs, targets = data
imgs = imgs.cuda()
targets = targets.cuda()

另一种表示形式(通过 to(device) 来表示)

# 创建模型
model = Model()# 定义设备:如果有GPU,则在GPU上训练, 否则在CPU上训练
device = torch.device("cuda" if torch.cuda.is_available else "cpu")# 将模型转移到GPU上
model = model.to(device)# 将损失函数转移到GPU上
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)# 将输入输出转移到GPU上
imgs, targets = data
imgs = imgs.to(device)
targets = targets.to(device)

1.5 train()

将模块设置为训练模式。

这只对某些模块有任何影响。如受影响,请参阅特定模块在培训/评估模式下的行为详情,例如: Dropout 、 BatchNorm 等。

train(mode=True)
主要参数:

  • mode(bool)-是否设置训练模式( True )或评估模式( False )。默认值: True 。

1.6 eval()

将模块设置为评估模式。

这只对某些模块有任何影响。如受影响,请参阅特定模块在培训/评估模式下的行为详情,例如: Dropout 、 BatchNorm 等。

在进行模型测试的时候会用到。

1.7 state_dict()

返回一个字典,其中包含对模块整个状态的引用。

返回模型的关键字典。

model = Model()
print(model.state_dict().keys())

在这里插入图片描述
在保存模型的时候我们也可以直接保存模型的 state_dict()

model = Model()# 保存模型
# 另一种方式:torch.save(model, "model.pth")
torch.save(model.state_dict(), "model.pth")# 加载模型
model.load_state_dict(torch.load("model.pth"))

2. torch.nn.Sequential

顺序容器。模块将按照它们在构造函数中传递的顺序添加到它。

Sequential 的 forward() 方法接受任何输入并将其转发到它包含的第一个模块。然后,它将输出“链接”到每个后续模块的输入,最后返回最后一个模块的输出。

官方文档:torch.nn.Sequential
CLASS torch.nn.Sequential(*args: Module)

import torch
from torch import nn# 使用 Sequential 创建一个小型模型。运行 `model` 时、
# 输入将首先传递给 `Conv2d(1,20,5)`。输出
# `Conv2d(1,20,5)`的输出将作为第一个
# 第一个 `ReLU` 的输出将成为 `Conv2d(1,20,5)` 的输入。
# `Conv2d(20,64,5)` 的输入。最后
# `Conv2d(20,64,5)` 的输出将作为第二个 `ReLU` 的输入
model = nn.Sequential(nn.Conv2d(1, 20, 5),nn.ReLU(),nn.Conv2d(20, 64, 5),nn.ReLU())

在这里插入图片描述

2.1 append

append 在末尾追加给定块。

  • append(module)
    在末尾追加给定模块。
    在这里插入图片描述
def append(self, module):self.add_module(str(len(self)), module)return selfappend(model, nn.Conv2d(64, 64, 5))
append(model, nn.ReLU())
print(model)

在这里插入图片描述

3. torch.nn.functional.conv2d

对由多个输入平面组成的输入图像应用2D卷积。
卷积神经网络详解:csdn链接

官方文档:torch.nn.functional.conv2d
torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)
主要参数:

  • input:形状的输入张量,(minibatch, inchannels, iH, iW)。
  • weigh:卷积核权重,形状为 (out_channels, inchannels / groups, kH, kW)

默认参数:

  • bias:偏置,默认值: None。
  • stride:步幅,默认值:1。
  • padding:填充,默认值:0。
  • dilation :内核元素之间的间距。默认值:1。
  • groups:将输入拆分为组,in_channels 应被组数整除。默认值:1。

在这里插入图片描述
对上图卷积操作进行代码实现

import torch.nn.functional as Finput = torch.tensor([[0, 1, 2],[3, 4, 5],[6, 7, 8]], dtype=float32)
kernel = torch.tensor([[0, 1],[2, 3]], dtype=float32)# F.conv2d 输入维数为4维
# torch.reshape(input, shape)
# reshape(样本数,通道数,高度,宽度)
input = torch.reshape(input, (1, 1, 3, 3))
kernel = torch.reshape(kernel, (1, 1, 2, 2))output = F.conv2d(input, kernel, stride=1)
print(input.shape)
print(kernel.shape)
print(input)
print(kernel)
print(output)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/113256.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

沉浸式VR虚拟实景样板间降低了看房购房的难度

720 全景是一种以全景视角为特点的虚拟现实展示方式,它通过全景图像和虚拟现实技术,将用户带入一个仿佛置身其中的沉浸式体验中。720 全景可以应用于旅游、房地产、展览等多个领域,为用户提供更为直观、真实的体验。 在房地产领域&#xff0c…

node-red - 读写操作redis

node-red - 读写操作redis 一、前期准备二、node-red安装redis节点三、node-red操作使用redis节点3.1 redis-out节点 - 存储数据到redis3.2 redis-cmd节点 - 存储redis数据3.3 redis-in节点 - 查询redis数据 附录附录1:redis -out节点示例代码附录2:redi…

Java eight 解读流(Stream)、文件(File)、IO和异常处理的使用方法

目录 Java 流(Stream)、文件(File)和IO读取控制台输入读写文件FileInputStreamFileOutputStream Java目录 Java 异常处理 Java 流(Stream)、文件(File)和IO java.io 包几乎包含了所有操作输入、输出需要的类。所有这些流类代表了输入源和输出目标。 Java.io 包中的流支持很多种…

matlab-对数据集加噪声并实现tsne可视化

matlab-对数据集加噪声并实现tsne可视化 最近才知道,原来可以不用模型,也能实现对数据集数据的可视化。 **一、**以COIL-100数据集为例子。 问题: 前提:首先对COIL-100数据集根据角度0-175和180-255,分别划分成C1,C…

c++学习之vector的实现

在学习实现vector之前我们会看到对于库中的vector的实现,这里并非使用在学习string那样的定义方式,而是利用迭代器,也就是指针来实现的,这在功能的实现时极大的方便了我们。 那么我们就模仿库这样的方式实现我们呢经常会用到的一些…

PowerBuilder连接SQLITE3

PowerBuilder,一个古老的IDE,打算陆续发些相关的,也许还有人需要,内容可能涉及其他作者,但基本都是基于本人实践整理,如涉及归属,请联系. SQLite,轻型数据库,相对与PowerBuilder来说是个新事务,故发数来,以供参考. PB中使用OLE Microsoft OLE DB方式进行连接,如下 // Profile…

邮件群发的功能特性

自动切换IP登录多账户发送 保证第三方发件邮箱系统发送成功率 由于第三方免费邮箱如同个IP登录多个163账号会造成被屏蔽的问题,我们采用自动拨号vps的方式可全国多个地区自动拨号切换IP,自动保证每个账号同时只登录一个账号发送,更可以多种类型小号混合…

kafka+Kraft模式集群+安全认证

Kraft模式安全认证 前章内容聊到了Kafka的Kraft集群的配置及使用。本篇再来说说kafka的安全认证方面的配置,。 Kafka提供了多种方式来进行安全认证,包括身份认证、授权和加密传输。一些常用的Kafka安全认证方式: SSL/TLS:使用S…

数据库的基本概念

数据库 数据库由表集合组成,它是以一定的组织方式存储的相互有关的数据集合。 表:记录:行,字段(属性):列,以行列的形式就组成了表(数据存储在表中)。 关系数…

学术加油站|基于端到端性能的学习型基数估计器综合测评

编者按 本文系东北大学李俊虎所著,也是「 OceanBase 学术加油站」系列第 11 篇内容。 「李俊虎:东北大学计算机科学与工程学院在读硕士生,课题方向为数据库查询优化,致力于应用 AI 技术改进传统基数估计器,令数据库选…

【LeetCode】3. 无重复字符的最长子串

3. 无重复字符的最长子串(中等) 方法:滑动窗口 哈希表 思路 这道题主要用到思路是:滑动窗口 什么是滑动窗口? 其实就是一个队列,比如例题中的 abcabcbb,进入这个队列(窗口)为 ab…

版本控制 Git工具的使用

版本控制的概念: 版本控制(Revision control)是一种在开发的过程中用于管理我们对文件、目录或工程等内容的修改历史,方便查看更改历史记录,备份以便恢复以前的版本的软件工程技术。简单来说就是用于管理多人协同开发…

LeetCode--HOT100题(46)

目录 题目描述:114. 二叉树展开为链表(中等)题目接口解题思路代码 PS: 题目描述:114. 二叉树展开为链表(中等) 给你二叉树的根结点 root ,请你将它展开为一个单链表: 展开后的单链…

【Flutter】Flutter 使用 collection 优化集合操作

【Flutter】Flutter 使用 collection 优化集合操作 文章目录 一、前言二、安装和基本使用三、算法介绍四、如何定义相等性五、Iterable Zip 的使用六、优先队列的实现和应用七、包装器的使用八、完整示例九、总结 一、前言 大家好!我是小雨青年,今天我要…

windows 中pycharm中venv无法激活

1.用管理员身份打开Windows PowerShell 2.进入项目的:venv\Scripts 如:D: (1): cd .\project\venv\Scripts\ (2): 执行命令: Set-ExecutionPolicy RemoteSigned (3): 选择:Y (4): .\activate

树多选搜索查询,搜索后选中状态仍保留

<template><div class"half-transfer"><div class"el-transfer-panel"><div><el-checkbox v-model"selectAll" change"handleSelectAll">全部</el-checkbox></div><el-input v-model&qu…

Stable Diffusion WebUI 整合包

现在网络上出现的各种整合包只是整合了运行 Stable Diffusion WebUI&#xff08;以下简称为 SD-WebUI&#xff09;必需的 Python 和 Git 环境&#xff0c;并且预置好模型&#xff0c;有些整合包还添加了一些常用的插件&#xff0c;其实际与手动进行本地部署并没有区别。 不过&a…

【HMS Core】运动健康之睡眠问题小结

【关键词】 运动健康、睡眠 【问题1】睡眠状态的数据来源只能是手表和手环吗&#xff0c;是否可以从手机获取&#xff1f; 答&#xff1a;可以获取用手机记录睡眠的睡眠记录&#xff0c;如果睡眠时&#xff0c;手机有采集睡眠状态&#xff0c;那么也是可以获取。 【问题2】获…

设计模式 - 工厂模式

前言 假设你开了一家奶茶店&#xff0c;顾客可以点普通奶茶&#xff0c;珍珠奶茶&#xff0c;香芋奶茶和红枣奶茶 一.传统模式 传统模式下&#xff0c;顾客根据名字点单&#xff0c;你获取名字然后做出奶茶。 class MilkTea{string name; public:MilkTea* create(string TeaN…

2023年最新 Github Pages 使用手册

参考&#xff1a;GitHub Pages 快速入门 1、什么是 Github Pages GitHub Pages 是一项静态站点托管服务&#xff0c;它直接从 GitHub 上的仓库获取 HTML、CSS 和 JavaScript 文件&#xff0c;&#xff08;可选&#xff09;通过构建过程运行文件&#xff0c;然后发布网站。 可…