【现代深度学习技术】深度学习计算 | 延后初始化自定义层

在这里插入图片描述

【作者主页】Francek Chen
【专栏介绍】 ⌈ ⌈ PyTorch深度学习 ⌋ ⌋ 深度学习 (DL, Deep Learning) 特指基于深层神经网络模型和方法的机器学习。它是在统计机器学习、人工神经网络等算法模型基础上,结合当代大数据和大算力的发展而发展出来的。深度学习最重要的技术特征是具有自动提取特征的能力。神经网络算法、算力和数据是开展深度学习的三要素。深度学习在计算机视觉、自然语言处理、多模态数据分析、科学探索等领域都取得了很多成果。本专栏介绍基于PyTorch的深度学习算法实现。
【GitCode】专栏资源保存在我的GitCode仓库:https://gitcode.com/Morse_Chen/PyTorch_deep_learning。

文章目录

    • 一、延后初始化
      • 实例化网络
    • 二、自定义层
      • (一)不带参数的层
      • (二)带参数的层
    • 小结


一、延后初始化

  到目前为止,我们忽略了建立网络时需要做的以下这些事情:

  • 我们定义了网络架构,但没有指定输入维度。
  • 我们添加层时没有指定前一层的输出维度。
  • 我们在初始化参数时,甚至没有足够的信息来确定模型应该包含多少参数。

  有些读者可能会对我们的代码能运行感到惊讶。 毕竟,深度学习框架无法判断网络的输入维度是什么。 这里的诀窍是框架的延后初始化(defers initialization), 即直到数据第一次通过模型传递时,框架才会动态地推断出每个层的大小。

  在以后,当使用卷积神经网络时, 由于输入维度(即图像的分辨率)将影响每个后续层的维数, 有了该技术将更加方便。 现在我们在编写代码时无须知道维度是什么就可以设置参数, 这种能力可以大大简化定义和修改模型的任务。 接下来,我们将更深入地研究初始化机制。

实例化网络

  首先,让我们实例化一个多层感知机。此时,因为输入维数是未知的,所以网络不可能知道输入层权重的维数。 因此,框架尚未初始化任何参数,我们通过尝试访问以下参数进行确认。

  接下来让我们将数据通过网络,最终使框架初始化参数。

  一旦我们知道输入维数是20,框架可以通过代入值20来识别第一层权重矩阵的形状。 识别出第一层的形状后,框架处理第二层,依此类推,直到所有形状都已知为止。 注意,在这种情况下,只有第一层需要延迟初始化,但是框架仍是按顺序初始化的。 等到知道了所有的参数形状,框架就可以初始化参数。

二、自定义层

  深度学习成功背后的一个因素是神经网络的灵活性:我们可以用创造性的方式组合不同的层,从而设计出适用于各种任务的架构。例如,研究人员发明了专门用于处理图像、文本、序列数据和执行动态规划的层。有时我们会遇到或要自己发明一个现在在深度学习框架中还不存在的层。在这些情况下,必须构建自定义层。本节将展示如何构建自定义层。

(一)不带参数的层

  首先,我们构造一个没有任何参数的自定义层。回忆一下在【现代深度学习技术】深度学习计算 | 层和块 对块的介绍,这应该看起来很眼熟。下面的CenteredLayer类要从其输入中减去均值。要构建它,我们只需继承基础层类并实现前向传播功能。

import torch
import torch.nn.functional as F
from torch import nnclass CenteredLayer(nn.Module):def __init__(self):super().__init__()def forward(self, X):return X - X.mean()

  让我们向该层提供一些数据,验证它是否能按预期工作。

layer = CenteredLayer()
layer(torch.FloatTensor([1, 2, 3, 4, 5]))

在这里插入图片描述

  现在,我们可以将层作为组件合并到更复杂的模型中。

net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())

  作为额外的健全性检查,我们可以在向该网络发送随机数据后,检查均值是否为0。由于我们处理的是浮点数,因为存储精度的原因,我们仍然可能会看到一个非常小的非零数。

Y = net(torch.rand(4, 8))
Y.mean()

在这里插入图片描述

(二)带参数的层

  以上我们知道了如何定义简单的层,下面我们继续定义具有参数的层,这些参数可以通过训练进行调整。我们可以使用内置函数来创建参数,这些函数提供一些基本的管理功能。比如管理访问、初始化、共享、保存和加载模型参数。这样做的好处之一是:我们不需要为每个自定义层编写自定义的序列化程序。

  现在,让我们实现自定义版本的全连接层。回想一下,该层需要两个参数,一个用于表示权重,另一个用于表示偏置项。在此实现中,我们使用修正线性单元作为激活函数。该层需要输入参数:in_unitsunits,分别表示输入数和输出数。

class MyLinear(nn.Module):def __init__(self, in_units, units):super().__init__()self.weight = nn.Parameter(torch.randn(in_units, units))self.bias = nn.Parameter(torch.randn(units,))def forward(self, X):linear = torch.matmul(X, self.weight.data) + self.bias.datareturn F.relu(linear)

  接下来,我们实例化MyLinear类并访问其模型参数。

linear = MyLinear(5, 3)
linear.weight

在这里插入图片描述

  我们可以使用自定义层直接执行前向传播计算。

linear(torch.rand(2, 5))

在这里插入图片描述

  我们还可以使用自定义层构建模型,就像使用内置的全连接层一样使用自定义层。

net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
net(torch.rand(2, 64))

在这里插入图片描述

小结

  • 延后初始化使框架能够自动推断参数形状,使修改模型架构变得容易,避免了一些常见的错误。
  • 我们可以通过模型传递数据,使框架最终初始化参数。
  • 我们可以通过基本层类设计自定义层。这允许我们定义灵活的新层,其行为与深度学习框架中的任何现有层不同。
  • 在自定义层定义完成后,我们就可以在任意环境和网络架构中调用该自定义层。
  • 层可以有局部参数,这些参数可以通过内置函数创建。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/12992.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Immutable设计 SimpleDateFormat DateTimeFormatter

专栏系列文章地址:https://blog.csdn.net/qq_26437925/article/details/145290162 本文目标: 理解不可变设计模式,时间format有线程安全要求的注意使用DateTimeFormatter 目录 ImmutableSimpleDateFormat 非线程安全可以synchronized解决&a…

Got socket exception during request. It might be caused by SSL misconfiguration

引入xutils3依赖,结果包找不到 maven里面添加阿里云镜像 核心 maven { url uri("https://maven.aliyun.com/nexus/content/groups/public/") }repositories {google()maven { url uri("https://maven.aliyun.com/nexus/content/groups/public/"…

RocketMQ实战—4.消息零丢失的方案

大纲 1.全链路分析为什么用户支付完成后却没有收到红包 2.RocketMQ的事务消息机制实现发送消息零丢失 3.RocketMQ事务消息机制的底层实现原理 4.是否可以通过同步重试方案来代替事务消息方案来实现发送消息零丢失 5.使用RocketMQ事务消息的代码案例细节 6.同步刷盘Raft协…

【Elasticsearch】date range聚合

好的,继续之前的示例: json ] } } } } 4.3 自定义键(key) 通过为每个范围指定一个唯一的键(key),可以在结果中更方便地引用每个范围。这在使用keyed参数将结果以键值对形式返回时尤其有用。 j…

【R语言】获取数据

R语言自带2种数据存储格式:*.RData和*.rds。 这两者的区别是:前者既可以存储数据,也可以存储当前工作空间中的所有变量,属于非标准化存储;后者仅用于存储单个R对象,且存储时可以创建标准化档案&#xff0c…

[leetcode]双指针算法的使用

零.参考文章 双指针技术在数组和链表问题中的应用解析-CSDN博客 一.使用情况 双指针即是在有序数组的情况下,我们通过两个指针在遍历的过程中进行标记,对满足条件的进行处理,直至遍历完整个数组。 二.举个例子 2.1小人过河问题&#xff…

自指学习:AGI的元认知突破

文章目录 引言:从模式识别到认知革命一、自指学习的理论框架1.1 自指系统的数学定义1.2 认知架构的三重反射1.3 与传统元学习的本质区别二、元认知突破的技术路径2.1 自指神经网络架构2.2 认知效能评价体系2.3 知识表示的革命三、实现突破的关键挑战3.1 认知闭环的稳定性3.2 计…

C++ 入门速通-第5章【黑马】

内容来源于:黑马 集成开发环境:CLion 先前学习完了C第1章的内容: C 入门速通-第1章【黑马】-CSDN博客 C 入门速通-第2章【黑马】-CSDN博客 C 入门速通-第3章【黑马】-CSDN博客 C 入门速通-第4章【黑马】-CSDN博客 下面继续学习第5章&…

hot100(7)

61.31. 下一个排列 - 力扣(LeetCode) 数组问题,下一个更大的排列 题解:31. 下一个排列题解 - 力扣(LeetCode) (1)从后向前找到一个相邻的升序对(i,j),此时…

图像分类与目标检测算法

在计算机视觉领域,图像分类与目标检测是两项至关重要的技术。它们通过对图像进行深入解析和理解,为各种应用场景提供了强大的支持。本文将详细介绍这两项技术的算法原理、技术进展以及当前的落地应用。 一、图像分类算法 图像分类是指将输入的图像划分为…

记录一次-Rancher通过UI-Create Custom- RKE2的BUG

一、下游集群 当你的下游集群使用Mysql外部数据库时,会报错: **他会检查ETCD。 但因为用的是Mysql外部数据库,这个就太奇怪了,而且这个检测不过,集群是咩办法被管理的。 二、如果不选择etcd,就选择控制面。 在rke2-…

SpringUI Web高端动态交互元件库

Axure Web高端动态交互元件库是一个专为Web设计与开发领域设计的高质量资源集合,旨在加速原型设计和开发流程。以下是关于这个元件库的详细介绍: 一、概述 Axure Web高端动态交互元件库是一个集成了多种预制、高质量交互组件的工具集合。这些组件经过精…

MySQL表的CURD

目录 一、Create 1.1单行数据全列插入 1.2多行数据指定列插入 1.3插入否则更新 1.4替换 2.Retrieve 2.1 select列 2.1.1全列查询 2.1.2指定列查询 2.1.3查询字段为表达式 2.1.4为查询结果指定别名 2.1.5结果去重 2.2where条件 2.3结果排序 2.4筛选分页结果 三…

文字加持:让 OpenCV 轻松在图像中插上文字

前言 在很多图像处理任务中,我们不仅需要提取图像信息,还希望在图像上加上一些文字,或是标注,或是动态展示。正如在一幅画上添加一个标语,或者在一个视频上加上动态字幕,cv2.putText 就是这个“文字魔术师”,它能让我们的图像从“沉默寡言”变得生动有趣。 今天,我们…

(9)gdb 笔记(2):查看断点 info b,删除断点 delete 3,回溯 bt,

(11) 查看断点 info b: # info b举例: (12)删除断点 delete 2 或者删除所有断点: # 1. 删除指定的断点 delete 3 # 2. 删除所有断点 delete 回车,之后输入 y 确认删除所有断点 举…

游戏引擎学习第88天

仓库:https://gitee.com/mrxiao_com/2d_game_2 调查碰撞检测器中的可能错误 在今天的目标是解决一个可能存在的碰撞检测器中的错误。之前有人提到在检测器中可能有一个拼写错误,具体来说是在测试某个变量时,由于引入了一个新的变量而没有正确地使用它&…

【2025】camunda API接口介绍以及REST接口使用(3)

前言 在前面的两篇文章我们介绍了Camunda的web端和camunda-modeler的使用。这篇文章主要介绍camunda结合springboot进行使用,以及相关api介绍。 该专栏主要为介绍camunda的学习和使用 🍅【2024】Camunda常用功能基本详细介绍和使用-下(1&…

Java高频面试之SE-17

hello啊,各位观众姥爷们!!!本牛马baby今天又来了!哈哈哈哈哈嗝🐶 Java缓冲区溢出,如何解决? 在 Java 中,缓冲区溢出 (Buffer Overflow) 虽然不是像 C/C 中那样直接可见…

用 Python 绘制爱心形状的简单教程

1. 引言 在本教程中,我们将学习如何使用 Python 和 Matplotlib 库来绘制一个简单的爱心形状。这是一个有趣且简单的项目,适合初学者练习图形绘制和数据可视化。 2. 环境准备 首先,确保您的系统上安装了 Python 和 Matplotlib 库。如果还未…

107,【7】buuctf web [CISCN2019 华北赛区 Day2 Web1]Hack World

这次先不进入靶场 看到红框里面的话就想先看看uuid是啥 定义与概念 UUID 是 Universally Unique Identifier 的缩写,即通用唯一识别码。它是一种由数字和字母组成的 128 位标识符,在理论上可以保证在全球范围内的唯一性。UUID 的设计目的是让分布式系…