Transformer 模型介绍(四)——编码器 Encoder 和解码器 Decoder

上篇中讲完了自注意力机制 Self-Attention 和多头注意力机制 Multi-Head Attention,这是 Transformer 核心组成部分之一,在此基础上,进一步展开讲一下编码器-解码器结构(Encoder-Decoder Architecture)

Transformer 模型由以下两个主要部分组成:

  • 编码器(Encoder):负责处理输入句子,将其转化为一个上下文丰富的表示
  • 解码器(Decoder):根据编码器生成的上下文向量来生成目标语言句子

目录

1 编码器 Encoder

1.1 整体结构

1.2 多头注意力机制 Multi-Head Attention

1.3 残差连接与归一化 Add & Norm

1.4 前馈神经网络 FFN

1.5 重复堆叠

2 解码器 Decoder

2.1 整体结构

2.2 掩码多头注意力 Masked Multi-Head Attention

2.3 编码器-解码器注意力 Encoder-Decoder Attention

2.4 前馈神经网络 FFN

2.5 堆叠与生成

3 小结


1 编码器 Encoder

1.1 整体结构

在Transformer模型中,编码器(Encoder)部分是至关重要的组件,它负责接收输入序列并对其进行逐层处理,生成高质量的序列表示。与传统的序列到序列模型(如RNN、LSTM)不同,Transformer的编码器不依赖于递归结构,而是通过自注意力机制实现对序列中各个单词的动态关联。具体而言:

  • 编码器由多个相同的层堆叠而成,每个层由两个主要部分组成:多头自注意力(Multi-Head Self-Attention)机制和前馈神经网络(Feed-Forward Neural Network, FFNN)
  • 每个子层后都会接一个归一化操作,旨在稳定训练过程,确保模型在深度训练时不会出现梯度爆炸或消失的情况
  • 多头注意力机制由多头注意力层和归一化处理相连接,接着是一个全连接的前馈网络,共同构成了编码器的核心结构
  • 多头注意力层对输入句子中的特定单词向量计算注意力分数,将所有单词向量的注意力分数编码为一个新的隐藏状态向量,发送到前馈神经网络,进行线性映射
  • 多头注意力层根据输入句子中不同的单词得出不同的注意力分数,因此其权重参数不同,而前馈神经网络对输入句子中不同的单词应用完全相同的权重参数

1.2 多头注意力机制 Multi-Head Attention

上一篇中已详细解释了多头注意力机制,此处不再赘述

1.3 残差连接与归一化 Add & Norm

为了避免深度网络中的梯度消失问题,残差连接被引入到每个子层中。具体来说,每个子层(包括多头自注意力和前馈神经网络)之后都有一个残差连接,将该子层的输入与输出相加,从而保留输入的信息。这种设计有助于缓解深度神经网络训练时梯度消失或爆炸的风险,使得模型能够稳定地训练

在残差连接之后,应用层归一化(Layer Normalization)对数据进行规范化,确保每一层的输出数据分布保持稳定。层归一化的作用是将每一层的输出数据重新调整,使得其均值为0,方差为1,从而加快训练过程,并提高训练的稳定性。归一化操作有助于减少不同训练阶段可能引起的数据分布偏差,提升模型的泛化能力

1.4 前馈神经网络 FFN

每个编码器层的第二个子层是一个全连接的前馈神经网络(FFN)。前馈神经网络的作用是引入非线性转换,以提升模型的表达能力,学习更加复杂的特征

在标准的 FFN 中,通常包含两个线性变换层,中间夹着一个非线性激活函数(如 ReLU)。FFN 的结构通过对输入信息进行两次线性变换和一次非线性映射,能够进一步丰富输入特征的表示

前馈神经网络的具体计算过程如下:

\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2

其中,W1​ W2 是学习得到的权重矩阵,b1​ 和 b2​ 是偏置项。通过这种非线性映射,FFN 能够学习到更复杂的模式和特征,使得模型具有更强的拟合能力

1.5 重复堆叠

上述结构(多头自注意力、前馈神经网络、归一化和残差连接)会在整个Transformer编码器中重复多次(通常为6次),每次迭代都会对输入的序列表示进行更加深入的处理。通过这种逐层堆叠的方式,模型能够不断地提取更加高层次的特征表示,从而构建出对输入序列的深刻理解

每一层的输出会作为下一层的输入,逐层传递,通过多次迭代,模型逐步提升对输入序列的表示能力

2 解码器 Decoder

2.1 整体结构

Transformer 模型的解码器(Decoder)负责生成输出序列,其设计具有独特性,旨在确保输出的顺序性并有效利用编码器产生的上下文信息

解码器与编码器类似,也由多个相同的层堆叠而成,每一层包含三个关键模块

编码器与解码器的最大不同之处是,解码器使用了红色箭头所指的掩码多头注意力机制

2.2 掩码多头注意力 Masked Multi-Head Attention

解码器的第一部分是一个特殊的多头自注意力(Masked Multi-Head Attention)层,旨在引入“未来遮蔽”(Future Masking)机制。这意味着在计算当前单词(或Token)的注意力分数时,模型不会访问未来的词,从而保证了解码过程中的时序性

具体来说,掩码多头注意力机制的工作原理是:在解码时,模型只能使用当前和过去的单词信息来预测下一个单词,而不能提前看到未来的单词。通过将当前单词的查询向量(Query)与未来单词的键向量(Key)遮蔽(即将其置为负无穷大),模型只能基于已生成的单词来计算注意力分数

这种设计确保了生成过程是顺序的,避免了信息泄露,确保了每一步的预测只能基于先前生成的内容。这种“未来遮蔽”机制是Transformer解码器与编码器的一个关键区别

具体来说,使用掩码矩阵(Mask Matrix)用于阻止模型在计算注意力分数时访问未来的位置。掩码矩阵的维度与输入矩阵相同,在该矩阵中,未来位置的值被设置为负无穷(-inf)。通过这种方式,模型在进行 softmax 操作时,未来位置的注意力分数会趋近于零,确保解码器不会利用未来的信息。对于计算得到的注意力分数矩阵 Q \cdot K^T,我们将其与掩码矩阵按位相乘,得到掩码后的注意力分数矩阵 \text{Mask}(Q \cdot K^T)

\text{Mask}(Q \cdot K^T) = (Q \cdot K^T) \cdot M

其中,M 是掩码矩阵,在掩码矩阵中,未来位置的值为0,已生成位置的值为1。通过按位相乘,未来的位置的注意力分数会被置为负无穷

2.3 编码器-解码器注意力 Encoder-Decoder Attention

解码器的第二部分是编码器-解码器注意力(Encoder-Decoder Attention),这是 Transformer 解码器的重要组成部分。在这个层中,查询(Query)矩阵来自解码器的前一层输出,而键(Key)和值(Value)矩阵则直接来自编码器的最终输出矩阵 C

  • 这种设置使得解码器能够根据当前的解码状态,从编码器生成的全局上下文中提取相关信息
  • 解码器能够在生成每个新的单词时,有选择地从编码器的输出中提取上下文信息,以帮助生成更准具有全局一致性的输出序列

2.4 前馈神经网络 FFN

解码器的最后一个模块与编码器相同,是一个全连接的前馈网络

和编码器一样,解码器的每层之间也通过跨层方法,如残差连接(Residual Connections)和层归一化(Layer Normalization)相连,以促进梯度流动并保持输出的稳定性

2.5 堆叠与生成

Transformer解码器与编码器一样,由多个相同的解码器层堆叠而成(通常为6层)。每一层的输出作为下一层的输入,在每一层中,解码器会逐步地生成更丰富的序列表示

通过多层堆叠和反复处理,解码器能够将前面生成的单词信息与编码器提供的上下文信息进行结合,逐步生成一个完整的输出序列。在生成过程中,每一层都能够精细调整输出序列,确保生成的内容在语法和语义上与目标序列保持一致

3 小结

最终的 Transformer 模型由6层网络结构堆叠而成

从整体上理解,Transformer的 架构设计简洁且高效,主要由以下几个模块组成:将多个 self-attention 堆成 Multi-Head Attention,加上 Add & Norm 就构成了 Encoder。经过掩码操作后的Masked Multi-Head Attention 加上 Encoder 同款结构,就构成了 Decoder

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/20225.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电脑系统损坏,备份文件

一、工具准备 1.U盘:8G以上就够用,注意会格式化U盘,提前备份U盘内容 2.电脑:下载Windows系统并进行启动盘制作 二、Windows启动盘制作 1.微软官网下载启动盘制作工具微软官网下载启动盘制作工具https://www.microsoft.com/zh-c…

Linux下Ollama下载安装速度过慢的解决方法

问题描述:在Linux下使用默认安装指令安装Ollama,下载安装速度过慢,进度条进度缓慢,一直处于Downloading Linux amd64 bundle中,具体如下图所示: 其中,默认的Ollama Linux端安装指令如下&#xf…

uniapp中@input输入事件在修改值只有第一次有效的问题解决

在uniapp中使用输入框,要求输入不超过7个字,所以需要监听输入事件,当每次输入文字的时候,就把输入的值截断,取前7个值。但是在input事件中,重新赋值的值发生了变化,但是页面上的还是没有变&…

DeepSeek 助力 Vue 开发:打造丝滑的范围选择器(Range Picker)

前言:哈喽,大家好,今天给大家分享一篇文章!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏关注哦 💕 目录 Deep…

VMware按照的MacOS升级后无法联网

背景 3年前公司使用Flutter开发了一款app,现在app有微小改动需要重新发布到AppStore 问题 问题是原来的Vmware搭建的开发环境发布App失败了 提示:App需要使用xcode15IOS 17 SDK重新构建,这样的话MacOS至少需要升级到13.5 Xcode - 支持 - Ap…

Day01 【苍穹外卖】环境搭建与前后端联调

一、环境搭建 1.JDK安装与IDEA安装: JDK安装与IDEA安装:【JAVA基础】01、JAVA环境配置----JDK与 IDEA集成开发环境的安装(2025最新版本)_配置jdk-CSDN博客 注意,这里要下载JDK1.8版本的,不然会报错&…

STM32 HAL库USART串口中断编程:环形缓冲区防止数据丢失

HAL_UART_Receive接收最容易丢数据了,可以考虑用中断来实现,但是HAL_UART_Receive_IT还不能直接用,容易数据丢失,实际工作中不会这样用,STM32 HAL库USART串口中断编程:演示数据丢失,需要在此基础优化一下. 本文介绍STM32F103 HAL库USART串口中断,利用环形缓冲区来防…

Vulnhub:DC-1靶机渗透

渗透过程 一,信息收集 1,探测目标IP地址 探测目标IP地址,探测主机的工具有很多,常见的有arp-scan、nmap还有netdiscover,这里使用arp-scan arp-scan -l确定了DC-1主机的IP地址为 192.168.126.1452,探测…

MySQL 之存储引擎(MySQL Storage Engine)

MySQL 之存储引擎 常见存储引擎及其特点 ‌InnoDB‌: ‌特点‌:支持事务处理、行级锁定、外键约束,使用聚簇索引,适合高并发读写和事务处理的场景‌。‌适用场景‌:需要高可靠性、高并发读写和事务处理的场景‌。 ‌M…

EasyX安装及使用

安装链接:EasyX Graphics Library for C 安装完成包含头文件graphics.h即可使用 RGB合成颜色(红色部分,绿色部分,蓝色部分) 每种颜色的值都是(0~255) 坐标默认的原点在窗口的左上角&#xf…

20.【线性代数】——坐标系中,平行四边形面积=矩阵的行列式

三 坐标系中,平行四边形面积矩阵的行列式 定理验证 定理 在坐标系中,由向量(a,b)和向量(c,d)组成平行四边形的面积 矩阵 [ a b c d ] \begin{bmatrix} a&b\\ c&d \end{bmatrix} [ac​bd​]的行列式,即&#x…

Grafana——Rocky9安装Grafana相关步骤记录

安装Grafana 安装 直接进下面这个页面,可以看到这边可以选择版本以及操作系统 并且如果是Linux平台的,下面会给出不同平台的命令,直接复制粘贴执行一下就可以了! 验证 运行命令 ## 运行service systemctl start grafana-server## 自启…

Mathtype安装入门指南

Mathtype安装入门指南 1 mathtype安装及补丁2 mathtype在word中加载3 常见的mathtype快捷命令4 实列测试 1 mathtype安装及补丁 下载相应的Mathtype7.4软件安装包,百度网盘链接为: 百度网盘链接下载完成后,有三个软件,如下图所示…

ConcurrentHashMap 在Jdk 17 不同版本中的优化和改进

ConcurrentHashMap 是 Java 中的一个高性能线程安全的哈希表实现,随着 JDK 版本的迭代,其内部实现也经历了多次优化和改进。每个版本的改动针对不同的场景和需求进行了性能提升和问题修复。以下分别描述了 JDK 7、JDK 8 和 JDK 17 的主要设计和区别&…

普通报表入门

1. 概述 报表设计主要可以分为新建报表、数据准备、报表主体设计、报表预览几大部分。其中报表主体可以分为大标题、小标题、表格数据、结尾几大部分,本文主要以普通报表为例,讲述如何按照报表设计流程快速设计一张报表。FineReport 版本为11.0 1.1 预期…

用deepseek学大模型08-cnn残差网络

残差网络 参考:https://blog.csdn.net/2301_80750681/article/details/142882802 以下是使用PyTorch实现的三层残差网络示例,包含三个残差块和完整的网络结构: import torch import torch.nn as nnclass BasicBlock(nn.Module):expansion…

AIGC(生成式AI)试用 21 -- Python调用deepseek API

1. 安装openai pip3 install openai########################## Collecting openaiUsing cached openai-1.61.1-py3-none-any.whl.metadata (27 kB) Collecting anyio<5,>3.5.0 (from openai)Using cached anyio-4.8.0-py3-none-any.whl.metadata (4.6 kB) Collecting d…

分享一款AI绘画图片展示和分享的小程序

&#x1f3a8;奇绘图册 【开源】一款帮AI绘画爱好者维护绘图作品的小程序 查看Demo 反馈 github 文章目录 前言一、奇绘图册是什么&#xff1f;二、项目全景三、预览体验3.1 截图示例3.2 在线体验 四、功能介绍4.1 小程序4.2 服务端 五、安装部署5.1 快速开始~~5.2 手动部…

node.js + html调用ChatGPTApi实现Ai网站demo(带源码)

文章目录 前言一、demo演示二、node.js 使用步骤1.引入库2.引入包 前端HTML调用接口和UI所有文件总结 前言 关注博主&#xff0c;学习每天一个小demo 今天是Ai对话网站 又到了每天一个小demo的时候咯&#xff0c;前面我写了多人实时对话demo、和视频转换demo&#xff0c;今天…

Java基础(其一)

1.八个基础数据类型&#xff1a; 整数型&#xff1a;int long short byte 浮点型&#xff1a;float double 字符型&#xff1a;char 布尔型&#xff1a;bool 1.1. byte 范围&#xff1a;-128 到 127&#xff08;8位&#xff0c;有符号&#xff09; 用途&#xff1a; 小范围…