用变压器实现德-英语言翻译【01/8】:嵌入层

 一、说明

        本文是“用变压器实现德-英语言翻译”系列的第一篇文章。它引入了小规模的嵌入来建立感知系统。接下来是嵌入层的变压器使用。下面简要概述了每种方法,然后是德语到英语的翻译。

二、技术背景

        嵌入层的目标是使模型能够详细了解单词、标记或其他输入之间的关系。此嵌入层可以被视为将数据从高维空间转换为低维空间,也可以视为将数据从低维空间映射到高维空间。

2.1 从单热向量到嵌入向量

        在自然语言处理中,令牌派生自可能包含章节、段落或句子的数据语料库。这些以各种方式分解成更小的部分,但最常见的标记化方法是按单词。语料库中所有独特的单词都被称为词汇表。

        词汇表中的每个单词都被分配一个整数,因为它更容易被计算机处理。有多种方法可以分配这些整数,但同样,最简单的方法是按字母顺序分配它们。

        下图演示了将较大的语料库分解为其组件并为每个组件分配整数的过程。请注意,为简单起见,标点符号被去掉,文本设置为小写。

        通过为每个单词分配索引而创建的数字顺序意味着一种关系。由于这不是意图,因此索引通常用于为每个单词创建一个独热编码向量。单热向量与词汇表的长度相同。在这种情况下,每个向量有 24 个元素。它被称为“一热”向量,因为只有一个元素被“打开”或设置为 1;所有其他令牌都处于“关闭”状态或设置为 0。1 的索引对应于分配给单词的整数值。通常,模型学习预测向量中给定索引的最高概率。

         当一个模型只有十几个标记或类可供预测时,独热编码向量通常是一种方便的表示形式。但是,大型语料库可以有数十万个代币。不是使用充满零的稀疏向量,这些向量没有传达太多意义,而是使用嵌入层将向量映射到较小的维度。可以训练这些嵌入式向量来传达有关每个单词及其与其他单词的关系的更多信息。

        本质上,每个单词都由一个d_model维向量表示,其中d_model可以是任何数字。它只是指示嵌入维度的数量。如果d_model是 2 或 3,则可以可视化每个单词之间的关系,但通常根据任务使用 256、512 和 1024 的值。

        下面可以看到一个优化嵌入的示例,其中类似类型的书籍彼此靠近嵌入:

2.2 嵌入向量

        嵌入矩阵的大小为 (vocab_size, d_model)。这允许将大小为 (seq_length, vocab_size) 的单热向量矩阵乘以它以获得新的嵌入式表示。序列长度由 seq_length 表示,即序列中的标记数。请记住,到目前为止,可视化中的“序列”是整个词汇表。在实践中,将使用词汇的子集,例如“基本段落”。该序列将被标记化、索引并转换为独热编码向量矩阵。然后,这些独热编码向量将能够与嵌入矩阵相乘。

        嵌入序列的大小为 (seq_length, vocab_size) x (vocab_size, d_model = (seq_length, d_model)。这意味着句子中的每个单词现在都由d_model维向量表示,而不是vocab_size元素的独热编码向量。下面可以看到此矩阵乘法的示例。索引序列的形状为 (3,24),嵌入矩阵的形状为 (24, 3)。一旦它们相乘,输出就是一个 (3,3) 矩阵。每个单词都由其 3 元素嵌入向量表示。

        当独热编码矩阵与嵌入层相乘时,将返回嵌入层的相应向量,而不进行任何更改。下面是独热编码向量和嵌入矩阵的整个词汇表之间的矩阵乘法。输出是嵌入矩阵。

        这表明有一种更简单的方法可以在不使用矩阵乘法的情况下获取这些相同的值,因为矩阵乘法可能会占用大量资源。分配给每个单词的整数可用于直接索引嵌入矩阵,而不是从 one-hot 编码向量转到 d_model 维嵌入(从较大维度到较小维度)。这就像从一维转到d_model维,提供有关令牌的更多信息。

        下图显示了如何在不乘法的情况下获得完全相同的结果:

2.3 从头开始嵌入

        可以在 Python 中创建上述图的简单实现。嵌入序列需要一个分词器、单词及其索引的词汇表,以及词汇表中每个单词的三维嵌入。分词器将序列拆分为其标记,在本示例中为小写单词。下面的简单函数从序列中删除标点符号,将其拆分为标记,并将它们小写。

# importing required libraries
import math
import copy
import numpy as np# torch packages
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor# visualization packages
from mpl_toolkits import mplot3d
import matplotlib.pyplot as pltexample = "Hello! This is an example of a paragraph that has been split into its basic components. I wonder what will come next! Any guesses?"def tokenize(sequence):# remove punctuationfor punc in ["!", ".", "?"]:sequence = sequence.replace(punc, "")# split the sequence on spaces and lowercase each tokenreturn [token.lower() for token in sequence.split(" ")]tokenize(example)
['hello', 'this', 'is', 'an', 'example', 'of', 'a', 'paragraph', 'that', 
'has', 'been', 'split', 'into', 'its', 'basic', 'components', 'i', 
'wonder', 'what', 'will', 'come', 'next', 'any', 'guesses']

        创建分词器后,可以为示例创建词汇表。词汇表包含构成数据的唯一单词列表。虽然示例中没有重复项,但仍应将其删除。一个简单的例子是下面的句子:“我很酷,因为我很矮。词汇将是“我,是,酷,因为,短”。然后,这些词将按字母顺序排列:“我,因为,酷,我,短”。最后,它们将被分配一个整数:“am: 0, 因为: 1, cool: 2, i: 3, short: 4”。此过程在下面的函数中实现。

def build_vocab(data):# tokenize the data and remove duplicatesvocab = list(set(tokenize(data)))# sort the vocabularyvocab.sort()# assign an integer to each wordstoi = {word:i for i, word in enumerate(vocab)}return stoi# build the vocab
stoi = build_vocab(example)stoi 
{'a': 0,'an': 1,'any': 2,'basic': 3,'been': 4,'come': 5,'components': 6,'example': 7,'guesses': 8,'has': 9,'hello': 10,'i': 11,'into': 12,'is': 13,'its': 14,'next': 15,'of': 16,'paragraph': 17,'split': 18,'that': 19,'this': 20,'what': 21,'will': 22,'wonder': 23}

此词汇现在可用于将任何标记序列转换为其整数表示形式。

sequence = [stoi[word] for word in tokenize("I wonder what will come next!")]
sequence
[11, 23, 21, 22, 5, 15]

        下一步是创建嵌入层,它只不过是一个大小为 (vocab_size, d_model) 的随机值矩阵。这些值可以使用torch.rand生成。

# vocab size
vocab_size = len(stoi)# embedding dimensions
d_model = 3# generate the embedding layer
embeddings = torch.rand(vocab_size, d_model) # matrix of size (24, 3)
embeddings
tensor([[0.7629, 0.1146, 0.1228],[0.3628, 0.5717, 0.0095],[0.0256, 0.1148, 0.1023],[0.4993, 0.9580, 0.1113],[0.9696, 0.7463, 0.3762],[0.5697, 0.5022, 0.9080],[0.2689, 0.6162, 0.6816],[0.3899, 0.2993, 0.4746],[0.1197, 0.1217, 0.6917],[0.8282, 0.8638, 0.4286],[0.2029, 0.4938, 0.5037],[0.7110, 0.5633, 0.6537],[0.5508, 0.4678, 0.0812],[0.6104, 0.4849, 0.2318],[0.7710, 0.8821, 0.3744],[0.6914, 0.9462, 0.6869],[0.5444, 0.0155, 0.7039],[0.9441, 0.8959, 0.8529],[0.6763, 0.5171, 0.9406],[0.1294, 0.6113, 0.5955],[0.3806, 0.7946, 0.3526],[0.2259, 0.4360, 0.6901],[0.6300, 0.2691, 0.9785],[0.2094, 0.9159, 0.7973]])

        创建嵌入后,可以使用索引序列为每个标记选择适当的嵌入。原始序列的形状为 (6, ),值为 [11, 23, 21, 22, 5, 15]。

# embed the sequence
embedded_sequence = embeddings[sequence]embedded_sequence
tensor([[0.7110, 0.5633, 0.6537],[0.2094, 0.9159, 0.7973],[0.2259, 0.4360, 0.6901],[0.6300, 0.2691, 0.9785],[0.5697, 0.5022, 0.9080],[0.6914, 0.9462, 0.6869]])

        现在,六个标记中的每一个都被一个 3 元素向量替换;新形状为 (6, 3)。

        由于这些令牌中的每一个都有三个组件,因此它们可以在三个维度上映射。虽然此图显示了一个未经训练的嵌入矩阵,但经过训练的嵌入矩阵会像前面提到的书籍示例一样将相似的单词彼此靠近。

# visualize the embeddings in 3 dimensions
x, y, z = embedded_sequences[:, 0], embedded_sequences[:, 1], embedded_sequences[:, 2] 
ax = plt.axes(projection='3d')
ax.scatter3D(x, y, z)  

2.4 使用 PyTorch 模块进行嵌入

        由于 PyTorch 将用于实现转换器,因此 nn.可以分析嵌入模块。PyTorch将其定义为:

一个简单的查找表,用于存储固定字典和大小的嵌入。

此模块通常用于存储词嵌入并使用索引检索它们。模块的输入是索引列表,输出是相应的词嵌入。

        这准确地描述了在前面的示例中使用索引而不是独热向量时所执行的操作。

        至少,nn。嵌入需要vocab_size和嵌入维度,随着d_model的发展,将继续对其进行标注。提醒一下,这是模型维度的缩写。

        下面的代码创建了一个形状为 (24, 3) 的嵌入矩阵。

# vocab size
vocab_size = len(stoi) # 24# embedding dimensions
d_model = 3# create the embeddings
lut = nn.Embedding(vocab_size, d_model) # look-up table (lut)# view the embeddings
lut.state_dict()['weight']
tensor([[-0.3959,  0.8495,  1.4687],[ 0.2437, -0.3289, -0.5475],[ 0.9787,  0.7395,  2.0918],[-0.4663,  0.4056,  1.2655],[-1.0054,  1.4883, -0.1254],[-0.1028, -1.1913,  0.0523],[-0.2654, -1.0150,  0.4967],[-0.4653, -1.9941, -1.7128],[ 0.3894, -0.9368,  1.5543],[-1.1358, -0.2493,  0.6290],[-1.4935,  1.1509, -1.8723],[-0.0421,  1.2857, -0.4009],[-0.2699, -0.8918, -1.0352],[-1.3443,  0.4688,  0.1536],[ 0.3638,  0.1003, -0.2809],[ 1.4208, -0.0393,  0.7823],[-0.4473, -0.4605,  1.2681],[ 1.1315, -1.4704,  0.2809],[ 0.4270, -0.2067, -0.7951],[-1.0129,  0.0706, -0.3417],[ 1.4999, -0.2527,  0.4287],[-1.9280, -0.6485,  0.4660],[ 0.0670, -0.5822,  0.0996],[-0.7058,  0.2849,  1.1725]], grad_fn=<EmbeddingBackward0>)

        如果将与之前相同的索引序列 [11, 23, 21, 22, 5, 15] 传递给它,则输出将是一个 (6, 3) 矩阵,其中每个标记由其三维嵌入向量表示。索引必须采用张量的形式,数据类型为整数或长整型。

indices = torch.Tensor(sequence).long()embeddings = lut(indices)embeddings

        输出将是:

tensor([[ 0.7584,  0.2332, -1.2062],[-0.2906, -1.2168, -0.2106],[ 0.1837, -0.9425, -1.9011],[-0.7708, -1.1671,  0.2051],[ 1.5548,  1.0912,  0.2006],[-0.8765,  0.8829, -1.3169]], grad_fn=<EmbeddingBackward0>)

三、变压器中的嵌入层

        在原始论文中,嵌入层用于编码器和解码器。对nn的唯一补充。嵌入模块是一个标量。嵌入权重乘以 √(d_model)。这有助于在下一步中将嵌入添加到位置编码时保留基本含义。这实质上使位置编码相对较小,并减少了其对嵌入的影响。这个堆栈溢出线程更多地讨论了它。

        为了实现这一点,可以创建一个类;它将被称为嵌入,并利用PyTorch的nn。嵌入模块。此实现基于带注释的转换器。

class Embeddings(nn.Module):def __init__(self, vocab_size: int, d_model: int):"""Args:vocab_size:     size of vocabularyd_model:        dimension of embeddings"""# inherit from nn.Modulesuper().__init__()   # embedding look-up table (lut)                          self.lut = nn.Embedding(vocab_size, d_model)   # dimension of embeddings self.d_model = d_model                          def forward(self, x: Tensor):"""Args:x:              input Tensor (batch_size, seq_length)Returns:embedding vector"""# embeddings by constant sqrt(d_model)return self.lut(x) * math.sqrt(self.d_model)  

四、 前向传递

        此嵌入类的工作方式与 nn 相同。嵌入。下面的代码演示了它与前面示例中使用的单个序列的用法。

lut = Embeddings(vocab_size, d_model)lut(indices)
tensor([[-1.1189,  0.7290,  1.0581],[ 1.7204,  0.2048,  0.2926],[-0.5726, -2.6856,  2.4975],[-0.7735, -0.7224, -2.9520],[ 0.2181,  1.1492, -1.2247],[ 0.1742, -0.8531, -1.7319]], grad_fn=<MulBackward0>)

        到目前为止,每个嵌入中只使用了一个序列。但是,模型通常使用一批序列进行训练。这实质上是一个序列列表,这些序列被转换为它们的索引,然后嵌入。这可以在下图中看到。

# list of sequences (3, )
sequences = ["I wonder what will come next!","This is a basic example paragraph.","Hello, what is a basic split?"]

        虽然前面的示例很简陋,但它适用于序列批次。上图中显示的示例是具有三个序列的批处理;标记化后,每个序列由六个标记表示。标记化序列的形状为 (3, 6),与 (batch_size, seq_length) 相关。基本上,三个,六个字的句子。

# tokenize the sequences
tokenized_sequences = [tokenize(seq) for seq in sequences]
tokenized_sequences
[['i', 'wonder', 'what', 'will', 'come', 'next'],['this', 'is', 'a', 'basic', 'example', 'paragraph'],['hello', 'what', 'is', 'a', 'basic', 'split']]

        然后可以使用词汇表将这些标记化序列转换为其索引表示形式。

# index the sequences 
indexed_sequences = [[stoi[word] for word in seq] for seq in tokenized_sequences]indexed_sequences
[[11, 23, 21, 22, 5, 15], [20, 13,  0,  3, 7, 17], [10, 21, 13,  0, 3, 18]]

最后,这些索引序列可以转换为可以通过嵌入层传递的张量。

# convert the sequences to a tensor
tensor_sequences = torch.tensor(indexed_sequences).long()lut(tensor_sequences)
tensor([[[ 0.1348, -1.3131,  2.8429],[ 0.2866,  3.3650, -2.8529],[ 0.0985,  1.6396,  0.0191],[-3.8233, -1.5447,  0.5320],[-2.2879,  1.0203,  1.5838],[ 0.4574, -0.4881,  1.2095]],[[-1.7450,  0.2474,  2.4382],[ 0.2633,  0.3366, -0.4047],[ 0.2921, -1.6113,  1.1765],[-0.0132,  0.5255, -0.7268],[-0.5208, -0.9305, -1.1688],[ 0.4233, -0.7000,  0.2346]],[[ 1.6670, -1.7899, -1.1741],[ 0.0985,  1.6396,  0.0191],[ 0.2633,  0.3366, -0.4047],[ 0.2921, -1.6113,  1.1765],[-0.0132,  0.5255, -0.7268],[-0.4935,  3.2629, -0.6152]]], grad_fn=<MulBackward0>)

        输出将是一个 (3, 6, 3) 矩阵,它与 (batch_size、seq_length、d_model) 相关。本质上,每个索引令牌都被其相应的三维嵌入向量所取代。

        在进入下一节之前,了解此数据的形状(batch_size、seq_length d_model)非常重要:

  • batch_size与一次提供的序列数相关,通常为 16、32 或 64。
  • seq_length与标记化后每个序列中的单词或标记数相关。
  • d_model与嵌入每个令牌后的模型大小相关。

有关位置编码的文章是本系列的下一篇。

请不要忘记点赞和关注更多!:)

五、引用

  1. 图片来源:Will Koehrsen
  2. PyTorch 的嵌入模块
  3. 堆栈溢出讨论
  4. 带注释的变压器
  5. 变压器从零开始

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/110720.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微商城分销系统免费源码_微商城分销系统设计功能开发_OctShop

要使用微商城分销系统源码来搭建或制作自己的微商城分销系统平台&#xff0c;那么&#xff0c;首先你需要知道什么是分销&#xff1f;通俗点讲就是买家或消费者成为商家或平台的分销商&#xff0c;通过推荐给分享好友&#xff0c;或其他的各种推广方式&#xff0c;如二维码&…

生产环境部署与协同开发 Git

目录 一、前言——Git概述 1.1 Git是什么 1.2 为什么要使用Git 什么是版本控制系统 1.3 Git和SVN对比 SVN集中式 Git分布式 1.4 Git工作流程 四个工作区域 工作流程 1.5 Git下载安装 1.6 环境配置 设置用户信息 查看配置信息 二、git基础 2.1 本地初始化仓库 ​编辑…

分段三次hermit插值

保形三次hermit插值 一、算法实现 一、插值函数建立 设函数 y F ( x ) yF(x) yF(x)在区间 [ a , b ] [a,b] [a,b]上有定义&#xff0c;且已知在离散点 a x 0 < x 1 < . . . < x n b ax_0<x_1<...<x_n b ax0​<x1​<...<xn​b上的值 y 0 , y…

Linux 查看当前文件夹下的文件大小

1.直接查看: ll 或者 ls -la #查看文件大小&#xff0c;以kb为单位 ll#查看文件大小&#xff0c;包含隐藏的文件&#xff0c;以kb为单位 ls -la2.以 M 或者 G 为单位查看&#xff0c;根据文件实际大小进行合适的单位展示 du -sh *

k8s集群搭建

文章目录 前言一、前置准备1.1 虚拟机准备1.2 关闭swap分区1.3 将桥接的IPv4流量传递到iptables链1.4 开启ipvs 二、容器化环境和组件安装2.1 docker安装2.2 设置docker加速镜像器2.4 设置yum镜像源2.5 安装kubeadm、kubelet和kubectl 三、集群搭建3.1 安装k8s所需镜像3.2 在ha…

LAMP 配置与应用

LAMP 架构的组成 LAM(M)P&#xff1a; L&#xff1a;linux A&#xff1a;apache (httpd) M&#xff1a;mysql, mariadb P&#xff1a;php, perl, python apache的功能&#xff1a; 第一&#xff1a;处理http的请求、构建响应报文等自身服务&#xff1b; 第二&#xff1a…

【C#学习笔记】数据类中常用委托及接口——以List<T>为例

文章目录 List\<T\>/LinkedList \<T\>为什么是神&#xff1f;&#xff08;泛型为什么是神&#xff09;一些常见&#xff0c;通用的委托和接口ComparisonEnumerator List<T>/LinkedList <T>为什么是神&#xff1f;&#xff08;泛型为什么是神&#xff0…

数据结构(Java实现)-栈和队列

栈&#xff1a;一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和删除元素操作。 先进后出 栈的使用 栈的模拟实现 上述的主要代码 public class MyStack {private int[] elem;private int usedSize;public MyStack() {this.elem new int[5];}Overridepublic …

-9501 MAL系统没有配置或者服务器不是企业版(dm8达梦数据库)

dm8达梦数据库 -9501 MAL系统没有配置或者服务器不是企业版&#xff09; 环境介绍1 环境检查2 问题原因 环境介绍 搭建主备集群时&#xff0c;遇到报错-9501 MAL系统没有配置或者服务器不是企业版 1 环境检查 检查dmmal.ini配置文件权限正确 dmdba:dinstall&#xff0c;内容正…

3.RabbitMQ 架构以及 通信方式

一、RabbitMQ的架构 RabbitMQ的架构可以查看官方地址 可以看出RabbitMQ中主要分为三个角色&#xff1a; Publisher&#xff1a;消息的发布者&#xff0c;将消息发布到RabbitMQ中的ExchangeRabbitMQ服务&#xff1a;Exchange接收Publisher的消息&#xff0c;并且根据Routes策…

安装虚拟机

软硬件准备 软件&#xff1a;推荐使用VMwear&#xff0c;我用的是VMwear 12 镜像&#xff1a;CentOS7 ,如果没有镜像可以在官网下载 &#xff1a;http://isoredirect.centos.org/centos/7/isos/x86_64/CentOS-7-x86_64-DVD-1804.iso 硬件&#xff1a;因为是在宿主机上运行虚拟…

Android屏幕显示 android:screenOrientation configChanges 处理配置变更

显示相关 屏幕朝向 https://developer.android.com/reference/android/content/res/Configuration.html#orientation 具体区别如下&#xff1a; activity.getResources().getConfiguration().orientation获取的是当前设备的实际屏幕方向值&#xff0c;可以动态地根据设备的旋…

Maven之hibernate-validator 高版本问题

hibernate-validator 高版本问题 hibernate-validator 的高版本&#xff08;邮箱注解&#xff09;依赖于高版本的 el-api&#xff0c;tomcat 8 的 el-api 是 3.0&#xff0c;满足需要。但是 tomcat 7 的 el-api 只有 2.2&#xff0c;不满足其要求。 解决办法有 2 种&#xff…

RocketMQ mqadmin java springboot python 调用笔记

命令 mqadmin命令列表 yeqiangyeqiang-MS-7B23:/opt/rocketmq-all-5.1.3-bin-release$ sh bin/mqadmin The most commonly used mqadmin commands are:updateTopic Update or create topicdeleteTopic Delete topic from broker and NameServer.…

【深度学习_TensorFlow】过拟合

写在前面 过拟合与欠拟合 欠拟合&#xff1a; 是指在模型学习能力较弱&#xff0c;而数据复杂度较高的情况下&#xff0c;模型无法学习到数据集中的“一般规律”&#xff0c;因而导致泛化能力弱。此时&#xff0c;算法在训练集上表现一般&#xff0c;但在测试集上表现较差&…

亿发浙江生产工厂信息化建设管理平台,实现生产智能化、数字化

在全球化、科技深刻变革的时代&#xff0c;浙江省信息化建设正迎来新的发展机遇。以物联网、人工智能大数据、为代表的新技术应用&#xff0c;为人类社会带来了智能、便捷&#xff0c;也标志着新一代信息化浪潮已经到来。特别是在生产型企业中&#xff0c;智能制造是生产型企业…

运用Python解析HTML页面获取资料

在网络爬虫的应用中&#xff0c;我们经常需要从HTML页面中提取图片、音频和文字资源。本文将介绍如何使用Python的requests库和BeautifulSoup解析HTML页面&#xff0c;获取这些资源。 一、环境准备 首先&#xff0c;确保您已经安装了Python环境。接下来&#xff0c;我们需要安…

HUT23级训练赛

目录 A - tmn学长的字符串1 B - 帮帮神君先生 C - z学长的猫 D - 这题用来防ak E - 这题考察FFT卷积 F - 这题考察二进制 G - 这题考察高精度 H - 这题考察签到 I - 爱派克斯&#xff0c;启动! J - tmn学长的字符串2 K - 秋奕来买瓜 A - tmn学长的字符串1 思路&#x…

CSS中如何实现背景图片的平铺和定位?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 平铺背景图片⭐ 背景图片定位⭐ 同时设置平铺和定位⭐ 写在最后 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几何带你启航前端之旅 欢迎来到前端入门之旅&#xff01;这个专栏是…

3D点云处理:基于2D边缘提取的方法提取3D点云边缘(占位待补充)

文章目录 0. 实现效果 微信&#xff1a;dhlddx B站演示视频 0. 实现效果