Transformer详解encoder

目录

1. Input Embedding

2. Positional Encoding

3. Multi-Head Attention

4. Add & Norm

5. Feedforward + Add & Norm

6.代码展示

(1)layer_norm

(2)encoder_layer=1


最近刚好梳理了下transformer,今天就来讲讲它~

        Transformer是谷歌大脑2017年在论文attention is all you need中提出来的seq2seq模型,它的本质就是由编码器和解码器组成,今天的主角则是其中的编码器(在BERT预训练模型中也只用到了编码器部分)如下图所示,这个模块的输入为 𝑋 (每一行代表一个句子,batchsize有多大就有多少行),我们将从输入到隐藏层按照从1到4的顺序逐层来看一下各个维度的变化。

1. Input Embedding

        所谓的Embedding其实就是查字典或者叫查表,也就是将一个句子里的每一个字转化为一个维度为embedding dimension的向量来表示,因此 𝑋 经过嵌入后变成 𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔 ,三个维度分别表示一个批次的句子数,每个句子的字数,每个字的嵌入维度。

2. Positional Encoding

        位置编码,按照字面意思理解就是给输入的位置做个标记,简单理解比如你就给一个字在句子中的位置编码1,2,3,4这样下去,高级点的比如作者用的正余弦函数

𝑃𝐸(𝑝𝑜𝑠,2𝑖)=𝑠𝑖𝑛(𝑝𝑜𝑠/100002𝑖/𝑑𝑚𝑜𝑑𝑒𝑙)

𝑃𝐸(𝑝𝑜𝑠,2𝑖+1)=𝑐𝑜𝑠(𝑝𝑜𝑠/100002𝑖/𝑑𝑚𝑜𝑑𝑒𝑙)

 

        其中pos表示字在句子中的位置,i指的词向量的维度。经过位置编码,相当于能够得到一个和输入维度完全一致的编码数组 𝑋𝑝𝑜𝑠 ,当它叠加到原来的词嵌入上得到新的词嵌入

𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔=𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔+𝑋𝑝𝑜𝑠

        此时的维度为:一个批次的句子数 × 一个句子的词数 × 一个词的嵌入维度

3. Multi-Head Attention

        注意力机制,其实可以理解为就是在计算相关性,很自然的想法就是去更多地关注那些相关更大的东西。这里首先要引入Query,Key和Value的概念,Query就是查询的意思,Key就是键用来和你要查询的Query做比较,比较得到一个分数(相关性或者相似度)再乘以Value这个值得到最终的结果。

        那么这个Q,K,V从哪里来呢,这里采用的是self-attention的方式,也就是从输入自己 𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔 来产生,即做线性映射产生Q,K,V:

𝑄=𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔∗𝑊𝑄𝐾=𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔∗𝑊𝐾𝑉=𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔∗𝑊𝑉

        这里三个权重矩阵均为维度为Embedding的方阵,也就是说Q,K,V的维度和 𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔 是一致的。

        接下来考虑什么叫做multi-head(多头)呢,本质上就是从embedding的维度上将矩阵切分为多份,每一份就是一个头,比如之前的Q,K,V切完后的维度就是一个批次的句子数 × 一个句子的词数 × 头数 × (词嵌入维度/头数)这个多头的切分体现在最后两个维度:词嵌入维度=数 × (词嵌入维度/头数)为了便于计算,通常会将第二第三维度进行转置,即最终的维度为一个批次的句子数 × 头数 × 一个句子的词数 × (词嵌入维度/头数)

        接下来说说注意力机制的计算,假设Q,K,V为切分完后的矩阵(其中一个头),根据两个向量的点积越大越相似,我们通过 𝑄𝐾𝑇 求出注意力矩阵,再根据注意力矩阵来给Value进行加权,即

𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛(𝑄,𝐾,𝑉)=𝑠𝑜𝑓𝑡𝑚𝑎𝑥(𝑄𝐾𝑇𝑑𝑘)𝑉

        其中 𝑑𝑘 是为了把注意力矩阵变成标准正态分布,softmax进行归一化,使每个字与其他字的注意力权重之和为1。这一操作使得每一个字的嵌入都包含当前句子内所有字的信息,注意Attention(Q,K,V)的维度和 𝑉 的维度保持一致。

4. Add & Norm

这里主要做了两个操作

  • 一个是残差连接(或者叫做短路连接),说得直白点就是把上一层的输入 𝑋 和上一层的输出加起来 𝑆𝑢𝑏𝐿𝑎𝑦𝑒𝑟(𝑋) ,即 𝑋+𝑆𝑢𝑏𝐿𝑎𝑦𝑒𝑟(𝑋) ,举例说明,比如在注意力机制前后的残差连接:

𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔+𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛(𝑄,𝐾,𝑉)

  • 一个是LayerNormalization(作用是把神经网络中隐藏层归一为标准正态分布,加速收敛),具体操作是将每一行每一个元素减去这行的均值, 再除以这行的标准差, 从而得到归一化后的数值。

5. Feedforward + Add & Norm

前馈网络也就是简单的两层线性映射再经过激活函数一下,比如

𝑋ℎ𝑖𝑑𝑑𝑒𝑛=𝑅𝑒𝑙𝑢(𝑋ℎ𝑖𝑑𝑑𝑒𝑛∗𝑊1∗𝑊2)

残差操作和层归一化同步骤3.


上述的1,2,3,4就构成Transformer中的一个encoder模块,经过1,2,3,4后得到的就是encode后的隐藏层表示,可以发现它的维度其实和输入是一致的!即:一个批次中句子数 × 一个句子的字数 × 字嵌入的维度

6.代码展示

(1)layer_norm

bs=2,seq=3,dim=5

import torchbatch_size = 2
seq = 3
fea_dim = 5
X = torch.rand(batch_size,seq,fea_dim)
layer_norm = torch.nn.LayerNorm(fea_dim)
out = layer_norm(X)
print(out)
print('-'*30)mean = torch.mean(X,dim=-1,keepdim=True)
std = torch.sqrt(torch.var(X,unbiased=False,dim=-1,keepdim=True) + 1e-5)
weight = layer_norm.state_dict()['weight']
bias = layer_norm.state_dict()['bias']
my_norm = ((X - mean)/std) * weight + bias
print(my_norm)

(2)encoder_layer=1

bs=1,seq=1,dim=6,head=1

import torchseq = 1
dim = 6
heads = 1
batch_size = 1
value = torch.rand(batch_size,seq,dim)encoder_layer = torch.nn.TransformerEncoderLayer(dim,heads,dropout=0.0,batch_first=True)
out = encoder_layer(value)
print(out)# 多头自注意力
def my_scaled_dot_product(query,key,value):qk_T = torch.mm(query,key.T)qk_T_scale = qk_T / torch.sqrt(torch.tensor(value.shape[1]))qk_exp = torch.exp(qk_T_scale)qk_exp_sum = torch.sum(qk_exp,dim=1,keepdim=True)qk_softmax = qk_exp / qk_exp_sumv_attn = torch.mm(qk_softmax,value)return v_attn,qk_softmaxin_proj_weight = encoder_layer.state_dict()['self_attn.in_proj_weight']
in_proj_bias = encoder_layer.state_dict()['self_attn.in_proj_bias']out_proj_weight = encoder_layer.state_dict()['self_attn.out_proj.weight']
out_proj_bias = encoder_layer.state_dict()['self_attn.out_proj.bias']batch_V_output = torch.empty(batch_size,seq,dim)
for i in range(batch_size):in_proj = torch.mm(value[i],in_proj_weight.T) + in_proj_biasQs,Ks,Vs = torch.split(in_proj,dim,dim=-1)head_Vs = []attn_weight = torch.zeros(seq,seq)for Q,K,V in zip(torch.split(Qs,dim//heads,dim=-1),torch.split(Ks,dim//heads,dim=-1),torch.split(Vs,dim//heads,dim=-1)):head_v,_ = my_scaled_dot_product(Q,K,V)head_Vs.append(head_v)V_cat = torch.cat(head_Vs,dim=-1)V_ouput = torch.mm(V_cat,out_proj_weight.T) + out_proj_biasbatch_V_output[i] = V_ouput# 第一次加
first_Add = value + batch_V_output# 第一次layer_norm
norm1_mean = torch.mean(first_Add,dim=-1,keepdim=True)
norm1_std = torch.sqrt(torch.var(first_Add,unbiased=False,dim=-1,keepdim=True) + 1e-5)
norm1_weight = encoder_layer.state_dict()['norm1.weight']
norm1_bias = encoder_layer.state_dict()['norm1.bias']
norm1 = ((first_Add - norm1_mean)/norm1_std) * norm1_weight + norm1_bias# feed forward
linear1_weight = encoder_layer.state_dict()['linear1.weight']
linear1_bias = encoder_layer.state_dict()['linear1.bias']
linear2_weight = encoder_layer.state_dict()['linear2.weight']
linear2_bias = encoder_layer.state_dict()['linear2.bias']
linear1 = torch.matmul(norm1,linear1_weight.T) + linear1_bias
linear1_relu = torch.nn.functional.relu(linear1)
linear2 = torch.matmul(linear1_relu,linear2_weight.T) + linear2_bias# 第二次加
second_Add = norm1 + linear2# 第二次layer_norm
norm2_mean = torch.mean(second_Add,dim=-1,keepdim=True)
norm2_std = torch.sqrt(torch.var(second_Add,unbiased=False,dim=-1,keepdim=True) + 1e-5)
norm2_weight = encoder_layer.state_dict()['norm2.weight']
norm2_bias = encoder_layer.state_dict()['norm2.bias']
norm2 = ((second_Add - norm2_mean)/norm2_std) * norm2_weight + norm2_bias
print(norm2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/365360.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

文件系统(操作系统实验)

实验内容 (1)在内存中开辟一个虚拟磁盘空间作为文件存储器, 在其上实现一个简单单用户文件系统。 在退出这个文件系统时,应将改虚拟文件系统保存到磁盘上, 以便下次可以将其恢复到内存的虚拟空间中。 (2&…

MySQL-java连接MySQL数据库+JDBC的使用

目录 1.准备所需要资源 2.导入驱动包 3.连接数据库步骤 首先在MySQL中创建好数据库和表 代码实现连接数据库 1.准备所需要资源 1.mysql和驱动包 我用的是5.7的mysql和5.1.49的驱动包,链接放在网盘里,需要的自取 链接:https://pan.bai…

二轴机器人装箱机:重塑物流效率,精准灵活,引领未来装箱新潮流

在现代化物流领域,高效、精准与灵活性无疑是各大企业追求的核心目标。而在这个日益追求自动化的时代,二轴机器人装箱机凭借其较佳的性能和出色的表现,正逐渐成为装箱作业的得力助手,引领着未来装箱新潮流。 一、高效:重…

【自动化测试】Selenium自动化测试框架 | 相关介绍 | Selenium + Java环境搭建 | 常用API的使用

文章目录 自动化测试一、selenium1.相关介绍1.Selenium IDE2.Webdriverwebdriver的工作原理: 3.selenium Grid 2.Selenium Java环境搭建3.常用API的使用1.定位元素2.操作测试对象3.添加等待4.打印信息5.浏览器的操作6.键盘事件7.鼠标事件8.定位一组元素9.多层框架定…

springcloud-config 客户端启用服务发现client的情况下使用metadata中的username和password

为了让spring admin 能正确获取到 spring config的actuator的信息,在eureka的metadata中添加了metadata.user.user metadata.user.password eureka.instance.metadata-map.user.name${spring.security.user.name} eureka.instance.metadata-map.user.password${spr…

HTTP协议和Nginx

一、HTTP协议和Nginx 1.套接字Socket 套接字Socket是进程间通信IPC的一种实现,允许位于不同主机(或同一主机)上不同进程之间进行通信和数据交换,SocketAPI出现于1983年BSD4.2实现在建立通信连接的每一端,进程间的传输…

【单元测试】Controller、Service、Repository 层的单元测试

Controller、Service、Repository 层的单元测试 1.Controller 层的单元测试1.1 创建一个用于测试的控制器1.2 编写测试 2.Service 层的单元测试2.1 创建一个实体类2.2 创建服务类2.3 编写测试 3.Repository 1.Controller 层的单元测试 下面通过实例演示如何在控制器中使用 Moc…

Uniapp 默认demo安装到手机里启动只能看得到底tab无法看到加载内容解决方案

Uniapp 默认demo安装到手机里以后,启动APP只能看到底tab栏,无法看到每个tab页对应的内容,HBuilder会有一些这样的报错信息: Waiting to navigate to: /pages/tabBar/API/API, do not operate continuously: 解决方案:…

OpenCV 调用自定义训练的 YOLO-V8 Onnx 模型

一、YOLO-V8 转 Onnx 在本专栏的前面几篇文章中,我们使用 ultralytics 公司开源发布的 YOLO-V8 模型,分别 Fine-Tuning 实验了 目标检测、关键点检测、分类 任务,实验后发现效果都非常的不错,但是前面的演示都是基于 ultralytics…

SpringBoot + mkcert ,解决本地及局域网(内网)HTTPS访问

本文主要解决访问SpringBoot开发的Web程序,本地及内网系统,需要HTTPS证书的问题。 我测试的版本是,其他版本不确定是否也正常,测试过没问题的小伙伴,可以在评论区将测试过的版本号留下,方便他人参考: <spring-boot.version>2.3.12.RELEASE</spring-boot.vers…

快速将网页封装成APP:小猪APP分发助您一臂之力

你是否曾经有一个绝妙的网页&#xff0c;但苦于无法将其变成手机APP&#xff1f;其实&#xff0c;你并不孤单。越来越多的企业和开发者希望将自己的网站封装成APP&#xff0c;以便更好地接触到移动端用户。我们就来聊聊如何快速将网页封装成APP&#xff0c;并探讨小猪APP分发在…

「C++系列」C++ 数据类型

文章目录 一、C 数据类型二、C 数据类型占位与范围三、类型转换1. 隐式类型转换&#xff08;Automatic Type Conversion&#xff09;2. 显式类型转换&#xff08;Explicit Type Conversion&#xff09;3. 示例代码 四、数据类型案例1. 整型2. 浮点型3. 字符型4. 布尔型5. 枚举类…

《Programming from the Ground Up》阅读笔记:p1-p18

《Programming from the Ground Up》学习第1天&#xff0c;p1-18总结&#xff0c;总计18页。 一、技术总结 1.fetch-execute cycle p9, The CPU reads in instructions from memory one at a time and executes them. This is known as the fetch-execute cycle。 2.genera…

九浅一深Jemalloc5.3.0 -- ①浅*编译调试

目前市面上有不少分析Jemalloc老版本的博文&#xff0c;但5.3.0却少之又少。而且5.3.0的架构与之前的版本也有较大不同&#xff0c;本着“与时俱进”、“由浅入深”的宗旨&#xff0c;我将逐步分析Jemalloc5.3.0的实现。5.3.0的特性请见Releases jemalloc/jemalloc GitHub 另…

fastapi+vue3前后端分离开发第一个案例整理

开发思路 1、使用fastapi开发第一个后端接口 2、使用fastapi解决cors跨域的问题。cors跨域是浏览器的问题&#xff0c;只要使用浏览器&#xff0c;不同IP或者不同端口之间通信&#xff0c;就会存在这个问题。前后端分离是两个服务&#xff0c;端口不一样&#xff0c;所以必须要…

Java单体架构项目_云霄外卖-特殊点

项目介绍&#xff1a; 定位&#xff1a; 专门为餐饮企业&#xff08;餐厅、饭店&#xff09;定制的一款软件商品 分为&#xff1a; 管理端&#xff1a;外卖商家使用 用户端&#xff08;微信小程序&#xff09;&#xff1a;点餐用户使用。 功能架构&#xff1a; &#xff08…

ESP32实现UDP连接——micropython版本

代码&#xff1a; import network import socket import timedef wifiInit(name, port):ap network.WLAN(network.AP_IF) # 创建一个热点ap.config(essidname, authmodenetwork.AUTH_OPEN) # 无需密码ap.active(True) # 激活热点ip ap.ifconfig()[0] # 获取ip地址print(…

【想起就补】整理了一些SSH中常用的命令

希望文章能给到你启发和灵感&#xff5e; 如果觉得文章对你有帮助的话&#xff0c;点赞 关注 收藏 支持一下博主吧&#xff5e; 阅读指南 开篇说明一、基础环境说明1.1 硬件环境1.2 软件环境 二、常用命令类型2.1 远程登录相关2.2 文件操作命令2.3 权限和所有权操作命令2.4 文…

MT6989(天玑9300)芯片性能参数_MTK联发科5G处理器

MT6989是联发科Dimensity旗舰系列的成员&#xff0c;旨在为旗舰5G智能手机供应商提供最先进的技术和性能。MT6989也是联发科目前最具创新和强大的5G智能手机芯片&#xff0c;具有领先的功耗效率&#xff0c;无与伦比的计算架构&#xff0c;有史以来最快和最稳定的5G调制解调器&…

【操作系统期末速成】 EP04 | 学习笔记(基于五道口一只鸭)

文章目录 一、前言&#x1f680;&#x1f680;&#x1f680;二、正文&#xff1a;☀️☀️☀️2.1 考点七&#xff1a;进程通信2.2 考点八&#xff1a;线程的概念2.3 考点九&#xff1a;处理机调度的概念及原则2.4 考点十&#xff1a;调度方式与调度算法 一、前言&#x1f680;…