pytorch中nn.Sequential详解

1 nn.Sequential概述

1.1 nn.Sequential介绍

nn.Sequential是一个序列容器,用于搭建神经网络的模块被按照被传入构造器的顺序添加到容器中。除此之外,一个包含神经网络模块的OrderedDict也可以被传入nn.Sequential()容器中。利用nn.Sequential()搭建好模型架构,模型前向传播时调用forward()方法,模型接收的输入首先被传入nn.Sequential()包含的第一个网络模块中。然后,第一个网络模块的输出传入第二个网络模块作为输入,按照顺序依次计算并传播,直到nn.Sequential()里的最后一个模块输出结果。

因此,Sequential可以看成是有多个函数运算对象,串联成的神经网络,其返回的是Module类型的神经网络对象。

1.2 nn.Sequential的本质作用

与一层一层的单独调用模块组成序列相比,nn.Sequential() 可以允许将整个容器视为单个模块(即相当于把多个模块封装成一个模块),forward()方法接收输入之后,nn.Sequential()按照内部模块的顺序自动依次计算并输出结果。这就意味着我们可以利用nn.Sequential() 自定义自己的网络层。

示例代码:

from torch import nnclass net(nn.Module):def __init__(self, in_channel, out_channel):super(net, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(in_channel, in_channel / 4, kernel_size=1),nn.BatchNorm2d(in_channel / 4),nn.ReLU())self.layer2 = nn.Sequential(nn.Conv2d(in_channel / 4, in_channel / 4),nn.BatchNorm2d(in_channel / 4),nn.ReLU())self.layer3 = nn.Sequential(nn.Conv2d(in_channel / 4, out_channel, kernel_size=1),nn.BatchNorm2d(out_channel),nn.ReLU())def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)return x

上边的代码,我们通过nn.Sequential()将卷积层,BN层和激活函数层封装在一个层中,输入x经过卷积、BN和ReLU后直接输出激活函数作用之后的结果。

1.3 nn.Sequential源码

def __init__(self, *args):super(Sequential, self).__init__()if len(args) == 1 and isinstance(args[0], OrderedDict):for key, module in args[0].items():self.add_module(key, module)else:for idx, module in enumerate(args):self.add_module(str(idx), module)

nn.Sequential()首先判断接收的参数是否为OrderedDict类型,如果是的话,分别取出OrderedDict内每个元素的key(自定义的网络模块名)和value(网络模块),然后将其通过add_module方法添加到nn.Sequrntial()中。

    # NB: We can't really type check this function as the type of input# may change dynamically (as is tested in# TestScript.test_sequential_intermediary_types).  Cannot annotate# with Any as TorchScript expects a more precise typedef forward(self, input):for module in self:input = module(input)return input

 调用forward()方法进行前向传播时,for循环按照顺序遍历nn.Sequential()中存储的网络模块,并以此计算输出结果,并返回最终的计算结果。

 

1.3 nn.Sequential与其它容器的区别

2 使用nn.Sequential定义网络

2.1 顺序添加网络模块到容器中

import torch
import torch.nn as nnmodel = nn.Sequential(nn.Linear(28 * 28, 32),nn.ReLU(),nn.Linear(32, 10),nn.Softmax(dim=1)
)
print("model:", model)
print("model.parameters:", model.parameters)x_input = torch.randn(2, 28, 28, 1)
print("x_input:", x_input)
print("x_input.shape:", x_input.shape)y_pred = model.forward(x_input.view(x_input.size()[0], -1))
print("y_pred:", y_pred)

运行代码显示:

model: Sequential((0): Linear(in_features=784, out_features=32, bias=True)(1): ReLU()(2): Linear(in_features=32, out_features=10, bias=True)(3): Softmax(dim=1)
)
model.parameters: <bound method Module.parameters of Sequential((0): Linear(in_features=784, out_features=32, bias=True)(1): ReLU()(2): Linear(in_features=32, out_features=10, bias=True)(3): Softmax(dim=1)
)>
x_input.shape: torch.Size([2, 28, 28, 1])
y_pred: tensor([[0.1127, 0.0652, 0.1399, 0.0973, 0.1085, 0.0859, 0.1193, 0.1048, 0.0865,0.0800],[0.0986, 0.0955, 0.0927, 0.0765, 0.0782, 0.1004, 0.1171, 0.1605, 0.0883,0.0922]], grad_fn=<SoftmaxBackward0>)

2.2 包含神经网络模块的OrderedDict传入容器中

import torch
import torch.nn as nn
from collections import OrderedDictmodel = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),('relu1', nn.ReLU()),('out', nn.Linear(32, 10)),('softmax', nn.Softmax(dim=1))]))
print("model:", model)
print("model.parameters:", model.parameters)x_input = torch.randn(2, 28, 28, 1)
print("x_input.shape:", x_input.shape)y_pred = model.forward(x_input.view(x_input.size()[0], -1))
print("y_pred:", y_pred)

运行代码显示:

model: Sequential((h1): Linear(in_features=784, out_features=32, bias=True)(relu1): ReLU()(out): Linear(in_features=32, out_features=10, bias=True)(softmax): Softmax(dim=1)
)
model.parameters: <bound method Module.parameters of Sequential((h1): Linear(in_features=784, out_features=32, bias=True)(relu1): ReLU()(out): Linear(in_features=32, out_features=10, bias=True)(softmax): Softmax(dim=1)
)>
x_input.shape: torch.Size([2, 28, 28, 1])
y_pred: tensor([[0.0836, 0.1185, 0.1422, 0.0801, 0.0817, 0.0870, 0.0948, 0.1099, 0.1131,0.0892],[0.0772, 0.0933, 0.1312, 0.1135, 0.1214, 0.0736, 0.1461, 0.0711, 0.0908,0.0818]], grad_fn=<SoftmaxBackward0>)

3 nn.Sequential网络操作

3.1 索引查看子模块

import torch.nn as nn
from collections import OrderedDictmodel = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),('relu1', nn.ReLU()),('out', nn.Linear(32, 10)),('softmax', nn.Softmax(dim=1))]))
print("index0:", model[0])
print("index1:", model[1])
print("index2:", model[2])

运行代码显示:

index0: Linear(in_features=784, out_features=32, bias=True)
index1: ReLU()
index2: Linear(in_features=32, out_features=10, bias=True)

3.2 修改子模块

import torch.nn as nn
from collections import OrderedDictmodel = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),('relu1', nn.ReLU()),('out', nn.Linear(32, 10)),('softmax', nn.Softmax(dim=1))]))
model[1] = nn.Sigmoid()
print(model)

运行代码显示:

Sequential((h1): Linear(in_features=784, out_features=32, bias=True)(relu1): Sigmoid()(out): Linear(in_features=32, out_features=10, bias=True)(softmax): Softmax(dim=1)
)

3.3 添加子模块

import torch.nn as nn
from collections import OrderedDictmodel = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),('relu1', nn.ReLU()),('out', nn.Linear(32, 10)),('softmax', nn.Softmax(dim=1))]))
model.append(nn.Linear(10, 2))
print(model)

运行代码显示:

Sequential((h1): Linear(in_features=784, out_features=32, bias=True)(relu1): ReLU()(out): Linear(in_features=32, out_features=10, bias=True)(softmax): Softmax(dim=1)(4): Linear(in_features=10, out_features=2, bias=True)
)

3.4 删除子模块

import torch.nn as nn
from collections import OrderedDictmodel = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),('relu1', nn.ReLU()),('out', nn.Linear(32, 10)),('softmax', nn.Softmax(dim=1))]))
del model[2]
print(model)

运行代码显示:

Sequential((h1): Linear(in_features=784, out_features=32, bias=True)(relu1): ReLU()(softmax): Softmax(dim=1)
)

3.5 嵌套子模块

import torch.nn as nnseq_1 = nn.Sequential(nn.Linear(15, 10), nn.ReLU(), nn.Linear(10, 5))
seq_2 = nn.Sequential(nn.Linear(25, 15), nn.Sigmoid(), nn.Linear(15, 10))
seq_3 = nn.Sequential(seq_1, seq_2)
print(seq_3)

运行代码显示:

Sequential((0): Sequential((0): Linear(in_features=15, out_features=10, bias=True)(1): ReLU()(2): Linear(in_features=10, out_features=5, bias=True))(1): Sequential((0): Linear(in_features=25, out_features=15, bias=True)(1): Sigmoid()(2): Linear(in_features=15, out_features=10, bias=True))
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/220974.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

关于“Python”的核心知识点整理大全26

目录 10.3.9 决定报告哪些错误 10.4 存储数据 10.4.1 使用 json.dump()和 json.load() number_writer.py number_reader.py 10.4.2 保存和读取用户生成的数据 对于用户生成的数据&#xff0c;使用json保存它们大有裨益&#xff0c;因为如果不以某种方式进行存储&#xf…

介绍Silabs一款zigbee芯片:MG21

1.1 EFR32MG21系列&#xff0c;这款芯片旨在高性能、低功耗、安全解决方案&#xff0c;符合IEEE802.15.4规范和ZigBee3.0协议标准&#xff0c;采用2.4G SoC&#xff0c;适用于ZigBee、Thread等应用 &#xff0c;96K的RAM数据存储器及高达1024k的闪存程序存储器. 1.2 典型应用…

HarmonyOS应用开发-手写板(二)

在前一篇手写板的文章中&#xff08;HarmonyOS应用开发-手写板-CSDN博客&#xff09;&#xff0c;我们通过使用Path实现了一个基本的手写板&#xff0c;但遗憾的是&#xff0c;无法保存所绘制的图像。在本文中&#xff0c;我们将采用canvas和Path2D来重新构建手写板应用。依然只…

新手上路:自动驾驶行业快速上手指南

文章目录 1.自动驾驶技术的发展1.1 工业革命驱动自动驾驶技术发展1.2 想象中的未来&#xff1a;科幻作品中的自动驾驶汽车1.3 自动驾驶技术萌芽与尝试1.4 百花争鸣&#xff1a;自动驾驶科技巨头与创业公司并进 2.个人开发者&#xff0c;如何玩转自动驾驶&#xff1f;2.1 灵活易…

MySQL数据库,触发器、窗口函数、公用表表达式

触发器 触发器是由事件来触发某个操作&#xff08;也包含INSERT、UPDATE、DELECT事件&#xff09;&#xff0c;如果定义了触发程序&#xff0c;当数据库执行这些语句时&#xff0c;就相当于事件发生了&#xff0c;就会自动激发触发器执行相应的操作。 当对数据表中的数据执行…

02.Git常用基本操作

一、基本配置 &#xff08;1&#xff09;打开Git Bash &#xff08;2&#xff09;配置姓名和邮箱 git config --global user.name "Your Name" git config --global user.email "Your email" 因为Git是分布式版本控制工具&#xff0c;所以每个用户都需要…

Pytorch nn.Linear()的基本用法与原理详解及全连接层简介

主要引用参考&#xff1a; https://blog.csdn.net/zhaohongfei_358/article/details/122797190 https://blog.csdn.net/weixin_43135178/article/details/118735850 nn.Linear的基本定义 nn.Linear定义一个神经网络的线性层&#xff0c;方法签名如下&#xff1a; torch.nn.Li…

Linux - 非root用户使用systemctl管理服务

文章目录 方式一 &#xff08;推荐&#xff09;1. 编辑sudoers文件&#xff1a;2. 设置服务文件权限&#xff1a;3. 启动和停止服务&#xff1a; 方式二1. 查看可用服务&#xff1a;2. 选择要配置的服务&#xff1a;3. 创建自定义服务文件&#xff1a;4. 重新加载systemd管理的…

el-date-picker限制选择7天内禁止内框选择

需求&#xff1a;elementPlus时间段选择框需要满足&#xff1a;①最多选7天时间。②不能手动输入。 <el-date-picker v-model"timeArrange" focus"timeEditable" :editable"false" type"datetimerange" range-separator"至&qu…

福德植保无人机工厂:创新科技与绿色农业的完美结合

亲爱的读者们&#xff0c;欢迎来到福德植保无人机工厂的世界。这里&#xff0c;科技与农业的完美结合为我们描绘出一幅未来农业的新篇章。福德植保无人机工厂作为行业的领军者&#xff0c;以其领先的无人机技术&#xff0c;创新的理念&#xff0c;为我们展示了一种全新的农业服…

使用Httpclient来替代客户端的jsonp跨域解决方案

最近接手一个项目&#xff0c;新项目需要调用老项目的接口&#xff0c;但是老项目和新项目不再同一个域名下&#xff0c;所以必须进行跨域调用了&#xff0c;但是老项目又不能进行任何修改&#xff0c;所以jsonp也无法解决了&#xff0c;于是想到了使用了Httpclient来进行服务端…

Vue简介

聚沙成塔每天进步一点点 ⭐ 专栏简介 Vue学习之旅的奇妙世界 欢迎大家来到 Vue 技能树参考资料专栏&#xff01;创建这个专栏的初衷是为了帮助大家更好地应对 Vue.js 技能树的学习。每篇文章都致力于提供清晰、深入的参考资料&#xff0c;让你能够更轻松、更自信地理解和掌握 …

[密码学]AES

advanced encryption standard&#xff0c;又名rijndael密码&#xff0c;为两位比利时数学家的名字组合。 分组为128bit&#xff0c;密钥为128/192/256bit可选&#xff0c;对应加密轮数10/12/14轮。 基本操作为四种&#xff1a; 字节代换&#xff08;subBytes transformatio…

PyQt6 QFontDialog字体对话框控件

锋哥原创的PyQt6视频教程&#xff1a; 2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~共计50条视频&#xff0c;包括&#xff1a;2024版 PyQt6 Python桌面开发 视频教程(无废话版…

【docker】修改docker的数据目录

背景 主节点是分配了较少内存和存储的低配机器&#xff0c;因为我们系统的rancher是用docker镜像启动的&#xff0c;而rancher和docker的默认目录都放在/var/lib下面&#xff0c;而这个/var目录目前只分配10G的存储&#xff0c;导致节点存储报警。因此想修改docker的数据目录&…

中国高分辨率土壤侵蚀因子K

中国高分辨率土壤侵蚀因子K 土壤可蚀性因子&#xff08;K&#xff09;数据&#xff0c;基于多种土壤属性数据计算&#xff0c;所用数据包括土壤黏粒含量&#xff08;%&#xff09;、粉粒含量&#xff08;%&#xff09;、砂粒含量&#xff08;%&#xff09;、土壤有机碳含量&…

鸿蒙系统(HarmonyOS)之方舟框架(ArkUI)介绍

鸿蒙开发官网&#xff1a;HarmonyOS应用开发官网 - 华为HarmonyOS打造全场景新服务 方舟开发框架&#xff08;简称&#xff1a;ArkUI&#xff09;&#xff0c;是一套构建HarmonyOS应用界面的UI开发框架&#xff0c;它提供了极简的UI语法与包括UI组件、动画机制、事件交互等在内…

音视频参数介绍

一、视频参数概念 单个视频帧&#xff1a;可以简单地理解成为一张图片 单个视频帧主要的参数概念&#xff1a; 分辨率&#xff1a; 分辨率是指图像或显示器上像素的数量&#xff0c;通常用横向像素数乘以纵向像素数表示。例如&#xff0c;1920x1080 表示宽度为1920像素&…

多维时序 | MATLAB实现WOA-CNN-LSTM-Multihead-Attention多头注意力机制多变量时间序列预测

多维时序 | MATLAB实现WOA-CNN-LSTM-Multihead-Attention多头注意力机制多变量时间序列预测 目录 多维时序 | MATLAB实现WOA-CNN-LSTM-Multihead-Attention多头注意力机制多变量时间序列预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本介绍 MATLAB实现WOA-CNN-LST…

新增工具箱管理功能、重构网站证书管理功能,1Panel开源面板v1.9.0发布

2023年12月18日&#xff0c;现代化、开源的Linux服务器运维管理面板1Panel正式发布v1.9.0版本。 在这一版本中&#xff0c;1Panel引入了新的工具箱管理功能&#xff0c;包含Swap分区管理、Fail2Ban管理等功能。此外&#xff0c;1Panel针对网站证书管理功能进行了全面重构&…