Transformer的PyTorch实现之若干问题探讨(二)

在《Transformer的PyTorch实现之若干问题探讨(一)》中探讨了Transformer的训练整体流程,本文进一步探讨Transformer训练过程中teacher forcing的实现原理。

1.Transformer中decoder的流程

在论文《Attention is all you need》中,关于encoder及self attention有较为详细的论述,这也是网上很多教程在谈及transformer时候会重点讨论的部分。但是关于transformer的decoder部分,他的结构上与encoder实际非常像,但其中有一些巧妙的设计。本文会详细谈谈。首先给出一个完整transformer的结构图:
在这里插入图片描述

上图左侧为encoder部分,右侧为decoder部分。对于decoder部分,将enc_input经过multi head attention后得到的张量,以K,V送入decoder中。而decoder阶段的masked multi head attention需要解决如何将dec_input编码成Q。最终输出的logits实际是与Q的维度一致。对于Scaled Dot-Product Attention,其公式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q, K, V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
在《Transformer的PyTorch实现之若干问题探讨(一)》中,decoder阶段,Q的维度为[2,8,6,64](2为batch size,8为head数,6为句子长度,64为向量长度),K的维度为[2,8,5,64],V的维度为[2,8,5,64]。其中, Q K T QK^T QKT的维度为[2,8,6,5] 的,可以理解每个查询张量Q对每个键值张K的注意力权重。之后乘以V,维度为[2,8,6,64]。可以看到最终的维度是根据查询张量Q来加权值向量V。Q就是dec_input经过masked multi head attention得来。那么,dec_input中实际是包含了所有的标签的。那么dec_input是如何mask掉不需要的token的呢?

2.Decoder中的self attention mask

class Decoder(nn.Module):def __init__(self):super(Decoder, self).__init__()self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)self.pos_emb = PositionalEncoding(d_model)self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])def forward(self, dec_inputs, enc_inputs, enc_outputs):'''这三个参数对应的不是Q、K、V,dec_inputs是Q,enc_outputs是K和V,enc_inputs是用来计算padding mask的dec_inputs: [batch_size, tgt_len]enc_inpus: [batch_size, src_len]enc_outputs: [batch_size, src_len, d_model]'''dec_outputs = self.tgt_emb(dec_inputs)#词序号编码成向量dec_outputs = self.pos_emb(dec_outputs).cuda()#位置编码dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda() #[2, 6, 6]dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda() #[2, 6, 6],上三角矩阵# 将两个mask叠加,布尔值可以视为0和1,和大于0的位置是需要被mask掉的,赋为True,和为0的位置是有意义的为Falsedec_self_attn_mask = torch.gt((dec_self_attn_pad_mask +dec_self_attn_subsequence_mask), 0).cuda()# 这是co-attention部分,为啥传入的是enc_inputs而不是enc_outputs:enc_outputs是向量,这儿是需要通过词编码来判断是否需要mask掉dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) #[2, 6, 5]for layer in self.layers:dec_outputs = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)return dec_outputs # dec_outputs: [batch_size, tgt_len, d_model]

上述代码为Decoder部分。可以看到有两个mask:dec_self_attn_pad_mask(用于将dec_inputs中的P mask掉)与dec_self_attn_subsequence_mask(用于实现decoder的self attention)。这两个mask在后面会相加合并。这儿可以分别展示二者的值,其中:

dec_self_attn_pad_mask:
tensor([[[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False]],[[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False]]], device='cuda:0')#[2, 6, 6]
dec_self_attn_subsequence_mask:
tensor([[[0, 1, 1, 1, 1, 1],[0, 0, 1, 1, 1, 1],[0, 0, 0, 1, 1, 1],[0, 0, 0, 0, 1, 1],[0, 0, 0, 0, 0, 1],[0, 0, 0, 0, 0, 0]],[[0, 1, 1, 1, 1, 1],[0, 0, 1, 1, 1, 1],[0, 0, 0, 1, 1, 1],[0, 0, 0, 0, 1, 1],[0, 0, 0, 0, 0, 1],[0, 0, 0, 0, 0, 0]]], device='cuda:0', dtype=torch.uint8)#[2, 6, 6]

可以看到,dec_self_attn_pad_mask全为false,这是因为dec_input中不包含P,而dec_self_attn_subsequence_mask为上三角矩阵,对于每个token,需要mask掉它之后的token(本代码中,为1或True的位置会被mask掉)。接下来进一步追问,为什么上三角矩阵就可以mask掉该token之后的token?具体是如何实现的呢?
对于前文的Scaled Dot-Product Attention公式,代码中的表述实际为:

    def forward(self, Q, K, V, attn_mask):'''Q: [batch_size, n_heads, len_q, d_k]K: [batch_size, n_heads, len_k, d_k]V: [batch_size, n_heads, len_v(=len_k), d_v] 全文两处用到注意力,一处是self attention,另一处是co attention,前者不必说,后者的k和v都是encoder的输出,所以k和v的形状总是相同的attn_mask: [batch_size, n_heads, seq_len, seq_len]'''# 1) 计算注意力分数QK^T/sqrt(d_k)scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # scores: [batch_size, n_heads, len_q, len_k]# 2)  进行 mask 和 softmax# mask为True的位置会被设为-1e9scores.masked_fill_(attn_mask, -1e9) # 把True设为-1e9attn = nn.Softmax(dim=-1)(scores)  # attn: [batch_size, n_heads, len_q, len_k]# 3) 乘V得到最终的加权和context = torch.matmul(attn, V)  # context: [batch_size, n_heads, len_q, d_v], [2, 8, 5, 64]'''得出的context是每个维度(d_1-d_v)都考虑了在当前维度(这一列)当前token对所有token的注意力后更新的新的值,换言之每个维度d是相互独立的,每个维度考虑自己的所有token的注意力,所以可以理解成1列扩展到多列返回的context: [batch_size, n_heads, len_q, d_v]本质上还是batch_size个句子,只不过每个句子中词向量维度512被分成了8个部分,分别由8个头各自看一部分,每个头算的是整个句子(一列)的512/8=64个维度,最后按列拼接起来'''return context # context: [batch_size, n_heads, len_q, d_v]

其中,Q,K,V的维度都是[2, 8, 6, 64], score的维度为[2, 8, 6, 6],即每个token之间的注意力分数。这儿取出一个batch中的一个head下的注意力分数a为例,a的维度为[6, 6],如图所示:
在这里插入图片描述

如上图所示,在得分score中,标黄的0.71和0.24分别是S与S,以及S与I的词向量相乘得到。由于I在S后面,所以需要通过mask将其置为负无穷大,而0.71需要保留,因为是S与S在同一个位置上。因此这个mask矩阵为上三角矩阵。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/254851.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

阿里云游戏服务器多少钱一个月?

阿里云游戏服务器租用价格表:4核16G服务器26元1个月、146元半年,游戏专业服务器8核32G配置90元一个月、271元3个月,阿里云服务器网aliyunfuwuqi.com分享阿里云游戏专用服务器详细配置和精准报价: 阿里云游戏服务器租用价格表 阿…

【QT+QGIS跨平台编译】之三十一:【FreeXL+Qt跨平台编译】(一套代码、一套框架,跨平台编译)

文章目录 一、FreeXL介绍二、文件下载三、文件分析四、pro文件五、编译实践一、FreeXL介绍 【FreeXL跨平台编译】:Windows环境下编译成果(支撑QGIS跨平台编译,以及二次研发) 【FreeXL跨平台编译】:Linux环境下编译成果(支撑QGIS跨平台编译,以及二次研发) 【FreeXL跨平台…

Linux基础-配置网络

Linux配置网络的方式 1.图形界面 右上角-wired-配置 点加号-新建网络配置文件2.NetworkManager工具 2.1用图形终端nmtui 1.新建网络配置文件add 1.指定网络设备的类型Ethernet 2.配置网络配置文件的名称,名称可以有空格 3.配置网络配置文件对应的物理网络设备的…

清平乐-春风丽日

今天,是2024年农历除夕日,远方家人已于昨夜风尘扑扑地倦鸟归巢,团聚过龙年,今晨酣睡未起。老龄笔者心情极佳,一夜好梦醒来,推窗仰头展望苍穹,喜上心头:啊!接连几天的小雨…

如何在 Mac 上恢复永久删除的文件:有效方法

您是否错误地从 Mac 中删除了某个文件,并且确信它已经永远消失了?好吧,你可能错了。即使您认为已永久删除计算机上的数据,仍有可能将其恢复。 在本文中,您将了解如何在 Mac 上恢复永久删除的文件,并了解增…

vim常用命令以及配置文件

layout: article title: “vim文本编译器” vim文本编辑器 有三种模式: 命令模式 文本模式, 末行模式 vim命令大全 - 知乎 (zhihu.com) 命令模式 插入 i: 切换到输入模式,在光标当前位置开始输入文本。 a: 进入插入模式,在光标下一个位置开始输入文…

C++多线程:this_thread 命名空间

std::this_thread 是 C 标准库中提供的一个命名空间,它包含了与当前线程相关的功能。这个命名空间提供了许多与线程操作相关的工具,使得在多线程环境中更容易进行编程。 源码类似于如下: namespace std{namespace this_thread{//...........…

Swift Combine 管道 从入门到精通三

Combine 系列 Swift Combine 从入门到精通一Swift Combine 发布者订阅者操作者 从入门到精通二 1. 用弹珠图描述管道 函数响应式编程的管道可能难以理解。 发布者生成和发送数据,操作符对该数据做出响应并有可能更改它,订阅者请求并接收这些数据。 这…

《学成在线》微服务实战项目实操笔记系列(P1~P83)【上】

史上最详细《学成在线》项目实操笔记系列【上】,跟视频的每一P对应,全系列12万字,涵盖详细步骤与问题的解决方案。如果你操作到某一步卡壳,参考这篇,相信会带给你极大启发。 一、前期准备 1.1 项目介绍 P2 To C面向…

vue3 之 商城项目—详情页

整体认识 路由配置 准备组件模版 <script setup></script><template><div class"xtx-goods-page"><div class"container"><div class"bread-container"><el-breadcrumb separator">">&…

引用论文bib信息解析,查看年、卷、期

2022,53(01):69-74 53卷 01期 volume{29},卷 number{},

SpringBoot配置文总结

官网配置手册 官网&#xff1a;https://spring.io/ 选择SpringBoot 选择LEARN 选择 Application Properties 配置MySQL数据库连接 针对Maven而言&#xff0c;会搜索出两个MySQL的连接驱动。 com.mysql mysql-connector-j 比较新&#xff0c;是在mysql mysql-connect…

2.6:冒泡、简选、直插、快排,递归,宏

1.冒泡排序、简单选择排序、直接插入排序、快速排序(升序) 程序代码&#xff1a; 1 #include<stdio.h>2 #include<string.h>3 #include<stdlib.h>4 void Bubble(int arr[],int len);5 void simple_sort(int arr[],int len);6 void insert_sort(int arr[],in…

ArcGIS Pro 按照字段进行融合或拆分

ArcGIS Pro 按字段融合 在ArcGIS Pro中&#xff0c;通过使用“融合”工具可以轻松地合并具有相同字段的图层。 步骤一&#xff1a;打开ArcGIS Pro 启动ArcGIS Pro应用程序&#xff0c;确保您已经登录并打开您的项目。 步骤二&#xff1a;添加图层 将包含相同字段的图层添加到…

2024牛客寒假算法基础集训营1(视频讲解全部题目)

2024牛客寒假算法基础集训营1&#xff08;题目全解&#xff09; ABCDEFGHIJKLM 2024牛客寒假算法基础集训营1&#xff08;视频讲解全部题目&#xff09; A #include<bits/stdc.h> #define endl \n #define deb(x) cout << #x << " " << …

探索Xposed框架:个性定制你的Android体验

探索Xposed框架&#xff1a;个性定制你的Android体验 1. 引言 在当今移动设备市场中&#xff0c;Android系统作为最受欢迎的操作系统之一&#xff0c;其开放性和可定制性备受用户青睐。用户希望能够根据个人喜好和需求对其设备进行定制&#xff0c;以获得更符合自己习惯的使用…

大模型基础架构的变革:剖析Transformer的挑战者(下)

上一篇文章中&#xff0c;我们介绍了UniRepLKNet、StripedHyena、PanGu-π等有可能会替代Transformer的模型架构&#xff0c;这一篇文章我们将要介绍另外三个有可能会替代Transformer的模型架构&#xff0c;它们分别是StreamingLLM、SeTformer、Lightning Attention-2&#xff…

LabVIEW伺服阀性能参数测试

LabVIEW伺服阀性能参数测试 伺服阀作为电液伺服系统中的核心元件&#xff0c;其性能参数的准确测试对保证系统整体性能至关重要。开发了一种基于LabVIEW软件开发的伺服阀性能参数测试系统&#xff0c;提高测试的自动化程度和精确性&#xff0c;同时降低操作复杂度和成本。 传…

Python中使用opencv-python库进行颜色检测

Python中使用opencv-python库进行颜色检测 之前写过一篇VC中使用OpenCV进行颜色检测的博文&#xff0c;当然使用opencv-python库也可以实现。 在Python中使用opencv-python库进行颜色检测非常简单&#xff0c;首选读取一张彩色图像&#xff0c;并调用函数imgHSV cv2.cvtColor…

《低功耗方法学》翻译——附录B:UPF命令语法

附录B&#xff1a;UPF命令语法 本章介绍了文本中引用的所选UPF命令的语法。 节选自“统一电源格式&#xff08;UPF&#xff09;标准&#xff0c;1.0版”&#xff0c;经该Accellera许可复制。版权所有&#xff1a;(c)2006-2007。Accellera不声明或代表摘录材料的准确性或内容&…