transformer架构解析{前馈全连接层,规范化层,子层(残差)连接结构}(含代码)-4

目录

前言

前馈全连接层

学习目标

 什么是前馈全连接层

前馈全连接层的作用

 前馈全连接层代码实现

规范化层

学习目标

规范化层的作用

 规范化层的代码实现

子层(残差)连接结构

学习目标

什么是子层(残差)连接结构

子层连接结构代码实现 


前言

        我们之前学习了输入层(词嵌入层(经过词向量编码),位置编码(通过词位置信息向量和词特征矩阵得到))。注意力机制(注意力计算规则,自注意力和注意力区别,注意力机制,多头注意力机制)

前馈全连接层

学习目标

        了解什么是前馈全连接层及其作用

        掌握前馈全连接层的实现过程

 什么是前馈全连接层

        在transformer中前馈全连接层就是具有两层线性层的全连接网络

前馈全连接层的作用

        考虑到注意力机制可能对复杂过程的拟合程度不够,通过增加两层网络来增强模型能力

 前馈全连接层代码实现

#前馈全连接层的代码分析
class PositionwiseFeedForward(nn.Module):def __init__(self, d_model,d_ff, dropout = 0.1):#输入的参数分别是:d_model词嵌入的维度,d_ff全连接网络的中间层维度,dropout=0.1置零比率super(PositionwiseFeedForward,self).__init__()#首先先使用nn实例化两个线性层对象self.w1,self.w2#参数是d_mdoel,d_ff,d_modelself.w1 = nn.Linear(d_model,d_ff)self.w2 = nn.Linear(d_ff,d_model)#然后使用nn实例化dropout对象self.dropout = nn.Dropout(p=dropout)def forward(self,x):#输入参数x表示上一层的输出#首先经过第一个线性层,然后使用Funtional中的relu函数激活#之后在使用dropout进行随即处置0,最后经过第二个线性层w2,返回结果return self.w2(self.dropout(F.relu(self.w1(x))))
#实例化参数
d_model = 512
d_ff = 128
dropout = 0.2
x = mha_result
pwf = PositionwiseFeedForward(d_model,d_ff,dropout=dropout)
pwf_result = pwf(x)
print(pwf_result)
print(pwf_result.shape)

规范化层

学习目标

        了解规范化层的作用

        掌握规范化层的实现过程

规范化层的作用

        它是所有深层网络模型都需要的标准网络层,因为随着网络层数的增加,通过多层计算后的参数可能开始出现过大的或者过小的情况,这样会导致模型收敛的非常慢,因此都会在一定层数后接入规范化层进行数值规范化,使其数值特征在合理范围内。

 规范化层的代码实现

#规范化层的代码实现
#通过layerNorm实现规范化操作的类
class LayerNorm(nn.Module):def __init__(self, features,eps=1e-6):#初始化参数:features表示词嵌入的维度,eps表示一个足够小的数,在规范化公式分母出现,防止分母为0super(LayerNorm,self).__init__()#根据features初始化两个张量a2,b2,第一个初始化张量为1张量。就是说里面所有元素都是1,第二个初始化为0张量#里面所有元素都为0,这两张张量就是规范化层的参数。#因为直接对上一层得到的结果做规范化计算,将改变的结果正常表征,因此就需要有参数作为调节因子,#使其既能满足规范化要求,又不能改变针对目标的表征。最后使用nn.parameter封装,代表他们是模型的参数self.a2 = nn.Parameter(torch.ones(features))self.b2 = nn.Parameter(torch.zeros(features))#把eps传入类中self.eps = epsdef forward(self,x):#输入的参数来自上一层的输出#在函数中,首先对输入变量x求其最后一个维度的均值,并保持如输入维度一致#接着再求最后一个维度的标准差,然后就是根据规范化公式,用x减去均值除以标准差获得规范化的结果#最后对结果乘以缩放系数,即a2,加上位移参数b2mean = x.mean(-1,keepdim=True) #-1表示最后一个维度std = x.std(-1,keepdim = True)return self.a2 *(x - mean) / (std + self.eps) + self.b2
features = d_model = 512
eps = 1e-6#x是前馈全连接层的输出
x = pwf_result
ln = LayerNorm(features,eps=eps)
ln_result = ln(x)
print(ln_result)
print(ln_result.shape)

子层(残差)连接结构

学习目标

        了解什么是子层连接结构

        掌握子层连接结构的实现过程

什么是子层(残差)连接结构

        输入到每个子层以及规范化层的过程中,还使用了残差链接(跳跃链接),我们把这一部分结构叫做子层连接结构(代表子层及其链接结构),在每个编码器层,都有两个子层,这两个子层加上周围的链接结构形成了两个子层链接结构

        子层链接结构图:

子层连接结构代码实现 

#子层连接结构
#使用SublayerConnection来实现子层连接结构的类
class SublayerConnection(nn.Module):def __init__(self, size,dropout=0.1):#输入两个参数size表示词嵌入的维度,dropout对输出矩阵的随机置0super(SublayerConnection,self).__init__()#实例化规范化层self.norm = LayerNorm(size)#使用nn中预定义的dropout实例化一个self.dropout对象self.dropout = nn.Dropout(p=dropout)def forward(self,x,sublayer):#该逻辑函数中,接收上一个子层或者子层的输入作为第一个参数#将该子层连接中的子层函数作为第二个参数#我们首先对输入做规范化处理,然后将结果传入子层做处理,之后在再对子层进行dropout操作#随机停止一些网络中的神经元的作用,防止过拟合,最后还有一个add的操作#因为存在跳跃连接,所以将输入x与dropout后的子层输出结果相加作为最终子层的连接输出return x + self.dropout(sublayer(self.norm(x)))
#实例化参数
size = 512
dropout = 0.2
head = 8
d_model = 512
#输入参数
#令x为编码器的输出
x = pe_result
#print(x)
mask = Variable(torch.zeros(2,4,4))#假设子层中装的是多层注意力层,实例化
self_attn = MutiHeadedAttention(head,d_model)#使用lambda获得一个函数类型的子层
sublayer = lambda x: self_attn(x,x,x,mask)sc = SublayerConnection(size=size,dropout=dropout)
sc_result = sc(x,sublayer)
print(sc_result)
print(sc_result.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/27923.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Django视图与URLs路由详解

在Django Web框架中,视图(Views)和URLs路由(URL routing)是Web应用开发的核心概念。它们共同负责将用户的请求映射到相应的Python函数,并返回适当的响应。本篇博客将深入探讨Django的视图和URLs路由系统&am…

串口通讯基础

第1章 串口的发送和接收过程 1.1 串口接收过程 当上位机给串口发送(0x55)数据时,MCU的RX引脚接受到(0x55)数据,数据(0x55)首先进入移位寄存器。数据全部进入移位寄存器后,一次将(0x55)全部搬运…

kakfa-3:ISR机制、HWLEO、生产者、消费者、核心参数负载均衡

1. kafka内核原理 1.1 ISR机制 光是依靠多副本机制能保证Kafka的高可用性,但是能保证数据不丢失吗?不行,因为如果leader宕机,但是leader的数据还没同步到follower上去,此时即使选举了follower作为新的leader&#xff…

基于Linux系统的物联网智能终端

背景 产品研发和项目研发有什么区别?一个令人发指的问题,刚开始工作时项目开发居多,认为项目开发和产品开发区别不大,待后来随着自身能力的提升,逐步感到要开发一个好产品还是比较难的,我认为项目开发的目的…

STM32——DMA详解

目录 一:DMA简介 二:DMA基本结构 三:DMA实现过程 1.框图 2.DMA进行转运的条件 四:函数 一:DMA简介 DMA(Direct Memory Access)直接存储器存取 DMA可以提供外设存储器或者存储器和存储器之间的高速数据传输&…

告别卡顿,拥抱流畅!MemReduct——内存清理工具

先给安装包下载地址:MemReduct.exe下载,无脑下一步安装即可。 MemReduct 是一款出色的内存清理工具,以下是对它的详细介绍: 功能特点 高效内存清理:采用先进算法及系统底层 API,能智能清理系统缓存、应用…

告别GitHub连不上!一分钟快速访问方案

一、当GitHub抽风时,你是否也这样崩溃过? 😡 npm install卡在node-sass半小时不动😭 git clone到90%突然fatal: early EOF🤬 改了半天hosts文件,第二天又失效了... 根本原因:传统代理需要复杂…

指纹细节提取(Matlab实现)

指纹细节提取概述指纹作为人体生物特征识别领域中应用最为广泛的特征之一,具有独特性、稳定性和便利性。指纹细节特征对于指纹识别的准确性和可靠性起着关键作用。指纹细节提取,即从指纹图像中精确地提取出能够表征指纹唯一性的关键特征点,是…

【对话推荐系统综述】A Survey on Conversational Recommender Systems

文章信息: 发表于:ACM Computing Surveys 2021 原文链接:https://arxiv.org/abs/2004.00646 Abstract 推荐系统是一类软件应用程序,旨在帮助用户在信息过载的情况下找到感兴趣的项目。当前的研究通常假设一种一次性交互范式&am…

【0001】初识Java

Java是世界上最好的语言,没有之一!!! Java是世界上最好的语言,没有之一!!! Java是世界上最好的语言,没有之一!!! 重要的事情说三遍&am…

全向广播扬声器在油气田中的关键应用 全方位守护安全

油气田作为高风险作业场所,安全生产始终是重中之重。在紧急情况下,如何快速、有效地传达信息,确保人员安全撤离,是油气田安全管理的关键环节。全向广播扬声器凭借其全方位覆盖、高音质输出和强大的环境适应性,成为油气…

显式 GC 的使用:留与去,如何选择?

目录 一、什么是显式 GC? (一) 垃圾回收的基本原理 (二)显式 GC 方法和行为 1. System.gc() 方法 2. 显式 GC 的行为 (三)显式 GC 的使用场景与风险 1. JVM 如何处理显式 GC 2. 显式 GC…

基于vue框架的游戏商城系统cq070(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。

系统程序文件列表 项目功能:用户,分类,商品信息,游戏高手,游戏代练 开题报告内容 基于Vue框架的游戏商城系统开题报告 一、研究背景与意义 随着互联网技术的飞速发展和游戏产业的蓬勃兴起,游戏商城作为游戏产业链中的重要一环,迎来了前所…

【OpenCV】OpenCV指南:图像处理基础及实例演示

OpenCV 是一个功能强大且易于使用的库,广泛应用于图像处理和计算机视觉领域。从读取和显示图像,到颜色空间转换、图像缩放、翻转、边缘检测、高斯模糊、形态学操作以及图像平滑和绘制,本文详细介绍了 OpenCV 的基础使用方法,附带了…

网络安全数据富化 网络数据安全处理规范

本文件规定了网络运营者开展网络数据收集、存储、使用、加工、传输、提供、公开等数据处理的安全 技术与管理要求。 本文件适用于网络运营者规范网络数据处理,以及监管部门、第三方评估机构对网络数据处理进行 监督管理和评估。 部分术语和定义 数据(data&#x…

蓝桥杯备考:动态规划线性dp之下楼梯问题进阶版

老规矩,按照dp题的顺序 step1 定义状态表达 f[i]表示到第i个台阶的方案数 step2:推导状态方程 step3:初始化 初始化要保证 1.数组不越界 2.推导结果正确 如图这种情况就越界了,我们如果把1到k的值全初始化也不现实,会增加程序的时间复杂度…

springboot + mybatis-plus + druid

目录架构 config MyMetaObjectHandler.java package com.example.config;import com.baomidou.mybatisplus.core.handlers.MetaObjectHandler; import org.apache.ibatis.reflection.MetaObject; import org.springframework.stereotype.Component;import java.util.Date;Com…

UniApp 中封装 HTTP 请求与 Token 管理(附Demo)

目录 1. 基本知识2. Demo3. 拓展 1. 基本知识 从实战代码中学习,上述实战代码来源:芋道源码/yudao-mall-uniapp 该代码中,通过自定义 request 函数对 HTTP 请求进行了统一管理,并且结合了 Token 认证机制 请求封装原理&#xff…

【HarmonyOS Next】自定义Tabs

背景 项目中Tabs的使用可以说是特别的频繁,但是官方提供的Tabs使用起来,存在tab选项卡切换动画滞后的问题。 原始动画无法满足产品的UI需求,因此,这篇文章将实现下面页面滑动,tab选项卡实时滑动的动画效果。 实现逻…

RMSNorm模块

目录 代码代码解释1. 初始化方法 __init__2. 前向传播方法 forward3. 总结4. 使用场景 可视化 代码 class RMSNorm(torch.nn.Module):def __init__(self, dim: int, eps: float):super().__init__()self.eps epsself.weight nn.Parameter(torch.ones(dim))def forward(self,…