Transformer模型详细步骤

Transformer模型是nlp任务中不能绕开的学习任务,我将从数据开始,每一步骤都列举出来,然后对应重点的代码进行讲解

-------------------------------------------------------------------------------------------------------------

Transformer模型是基于注意力机制的一种深度学习架构,最早由Vaswani等人在2017年提出,主要用于自然语言处理(NLP)任务。它不同于传统的循环神经网络(RNN)或卷积神经网络(CNN),因为它完全依赖于注意力机制,不需要通过时间步长来处理序列数据,从而可以更高效地并行处理数据。

大名鼎鼎的transform一经出现就席卷了各个方面,

transform原论文:

Attention Is All You Need

论文网址:https://arxiv.org/pdf/1706.03762

核心组件介绍

Transformer模型主要包括以下几个部分:

  1. 输入嵌入(Input Embeddings)
  2. 位置编码(Positional Encoding)
  3. 多头自注意力机制(Multi-Head Self-Attention)
  4. 前馈神经网络(Feedforward Neural Network)
  5. 编码器(Encoder)和解码器(Decoder)结构
  6. 输出层(Output Layer)

以句子“我喜欢小狗”为例,详细展示Transformer模型中的每一步及其对应的矩阵变化。假设每个单词的嵌入维度为4,句子长度为4。

步骤 1:输入嵌入

假设通过一个词嵌入矩阵(Embedding Matrix)将每个词转化为一个4维的嵌入向量,嵌入后得到的矩阵X∈R^4×4,词嵌入矩阵是通过预训练得到的。例如,Word2Vec和GloVe等模型已经在大规模文本语料上训练好,提供了每个单词的嵌入向量。

Word2Vec 简介

Word2Vec 是一种将词汇表示为向量的技术,它通过神经网络模型将词映射到连续向量空间中,能够捕捉词与词之间的语义关系。Word2Vec 有两种主要的模型结构:

  1. CBOW(Continuous Bag of Words):基于上下文词来预测中心词。
  2. Skip-gram:基于中心词来预测上下文词。

Word2Vec 的原理

Word2Vec 的核心思想是基于词的上下文来学习词向量。在大规模语料库中,词汇共现的模式可以用来推测它们之间的语义相似性。词向量模型旨在使得语义相似的词在向量空间中彼此接近。

CBOW 模型

CBOW 通过上下文词预测中心词。例如,在“我喜欢小狗”这个句子中,假设要预测“喜欢”,则上下文词为“我”和“小狗”。

上下文词嵌入平均值:对于给定的上下文词​,计算它们词嵌入向量的平均值:

预测中心词的概率分布:我们用上下文的平均向量通过 softmax 函数来预测中心词的概率分布:

通过计算每个词向量和上下文向量的内积来衡量词语匹配的可能性,最后通过 softmax 归一化成概率分布。

损失函数:CBOW 模型的目标是最大化所有中心词的预测概率,通常使用交叉熵损失:

Skip-gram 模型

Skip-gram 模型与 CBOW 模型相反,它使用中心词预测上下文词。例如,给定中心词“喜欢”,我们预测上下文词“我”和“小狗”。

公式:

Skip-gram 模型的计算步骤与 CBOW 类似,只是这里使用中心词来预测每个上下文词。

上下文词的预测:给定中心词,我计算上下文词的概率:

损失函数:Skip-gram 模型的目标是最大化上下文词的预测概率,损失函数为:

在实际应用中,由于词汇表 V的大小可能非常大,直接计算 softmax 的开销非常高。为了解决这个问题,Word2Vec 引入了 负采样 技术。

负采样的主要思想是,只对正样本(真实的上下文词)和一小部分负样本(随机选择的非上下文词)进行训练,而不是对整个词汇表计算 softmax。负采样的损失函数为:

Word2Vec 的训练

  1. 初始化词向量矩阵,通常是随机生成的。
  2. 通过优化损失函数(如交叉熵或负采样)来更新词向量。
  3. 最终训练完成后,模型会输出每个词在向量空间中的表示,语义相似的词在向量空间中距离较近。

示例:

假设条件:

  • 词汇表大小 V=5(假设词汇表中只有 "我"、"喜欢"、"小狗"、"吃"、"饭" 五个词)。
  • 词向量维度 d=4
  • Skip-gram 模型,中心词为 "喜欢",上下文词为 "我" 和 "小狗"。

1. 初始化嵌入矩阵

首先,词嵌入矩阵是随机初始化的,用来表示词汇表中每个词的向量表示。假设嵌入矩阵:

2. Skip-gram 中心词和上下文词

在 Skip-gram 模型中,中心词 "喜欢" 的嵌入向量是通过查找嵌入矩阵

3. 预测上下文词

使用中心词的嵌入向量来预测上下文词。Skip-gram 的目标是让中心词和真实上下文词的相似度最大化,同时最小化中心词与负样本(随机选取的词)的相似度。预测上下文词的概率可以通过计算中心词嵌入和上下文词嵌入的点积:

对于 "我":

对于”狗“:

通过点积可以计算出中心词和上下文词之间的相似度分数。为了得到概率,我们通常通过 softmax 函数对这些相似度进行归一化。

4. Softmax 计算

为了预测 "我" 和 "小狗" 的概率,需要计算 softmax:

5. 负采样

为了简化计算,Skip-gram 模型引入了负采样。假设我们随机选择 "吃"作为负样本,计算它们与中心词 "喜欢" 的点积:

接下来,将正样本和负样本的结果输入到 sigmoid 函数中进行优化。

设单词xi 是输入的单词,其嵌入向量为 e(xi)

  • "我" -> 0
  • "喜欢" -> 1
  • "小狗" -> 2

词向量矩阵的学习过程如下:

  1. 初始化:开始时,嵌入矩阵的每个单词向量可以是随机初始化的。
  2. 前向传播:通过嵌入矩阵将单词转换为向量,并传递到模型的下一层。
  3. 损失计算:模型的输出与目标标签计算损失。
  4. 反向传播:通过计算损失函数的梯度来更新嵌入矩阵。
  5. 优化:使用优化算法(如SGD、Adam)来更新嵌入矩阵,直到模型收敛。
生成方式:

词向量矩阵初始时随机生成,但在训练过程中会根据反向传播更新。其目的是让相似意义的单词在向量空间中靠得更近。具体来说:

假设矩阵如下所示:

"我" -> [0.2, 0.4, 0.1, 0.3]

"喜欢" -> [0.6, 0.8, 0.5, 0.9]

"小狗" -> [0.7, 0.2, 0.9, 0.1]

 

步骤 2:位置编码(Positional Encoding)

由于Transformer没有时间步长的概念,因此需要加入位置信息来帮助模型理解序列顺序。位置编码使用一些数学公式,比如正弦和余弦函数,将位置信息加入到嵌入向量中。位置编码矩阵是固定的,不需要训练。它根据输入序列的位置和维度生成。计算公式为:

 其中 pos 是位置,i 是维度索引,d 是嵌入的总维度。

  • 每个位置 pospospos 对应的向量由正弦余弦函数的组合构成。位置越靠前的单词,它的编码数值变化越剧烈,越往后的单词,数值变化就会越缓慢。这种设计让不同位置的编码在各个维度上有所区分。
  • 正弦和余弦函数的周期性特性也使得模型可以容易地捕捉到不同单词之间的位置差异。

假设句子是“我喜欢小狗”,词嵌入矩阵初始化为:

 根据位置编码公式,计算第 0、1、2 个位置的编码

然后我们将词向量和位置编码向量相加,得到最终输入 Transformer 的向量: 

位置编码与词向量的关系

  • 词向量:是通过嵌入矩阵(通常是随机初始化后经过训练得到的)来表示单词的语义。词向量中没有位置信息。

  • 位置编码:是为了让模型知道每个词在句子中的位置。它不会改变词向量的语义,而是将位置信息叠加到词向量中。位置编码矩阵是固定的,不需要训练

#对应代码
class PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super(PositionalEncoding, self).__init__()pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)self.register_buffer('pe', pe)def forward(self, x):return x + self.pe[:x.size(1), :]

 

步骤 3:多头自注意力机制(Multi-Head Self-Attention)

这是Transformer的核心部分,自注意力机制计算的是句子中每个单词和其他单词的相关性。首先将输入嵌入分别映射到三个不同的空间:查询(Query)、键(Key)和值(Value)。然后计算每对单词之间的注意力权重,最终通过这些权重加权求和得到每个单词的新表示。

 缩放因子 ​:这是一个常数,确保内积的尺度合适。它不会被训练,也不是随机生成的。

class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()assert d_model % num_heads == 0self.d_k = d_model // num_headsself.num_heads = num_heads# 需要训练的矩阵:W_Q, W_K, W_Vself.W_Q = nn.Linear(d_model, d_model)self.W_K = nn.Linear(d_model, d_model)self.W_V = nn.Linear(d_model, d_model)self.W_O = nn.Linear(d_model, d_model)  # 最终的输出权重矩阵def forward(self, X):batch_size, seq_len, d_model = X.shape# 线性变换得到 Q, K, VQ = self.W_Q(X)  # (batch_size, seq_len, d_model)K = self.W_K(X)V = self.W_V(X)# 将 Q, K, V 分成多个头Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)# 计算缩放点积注意力attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)attention_weights = torch.nn.functional.softmax(attention_scores, dim=-1)attention_output = torch.matmul(attention_weights, V)# 将多个头的输出合并attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)return self.W_O(attention_output)  # 通过线性层输出

步骤 4:前馈神经网络(Feedforward Neural Network)

每个注意力层后面接一个前馈神经网络,通常由两个线性变换和一个ReLU激活函数组成。

前馈神经网络由两个全连接层构成,各自有权重和偏置矩阵:

 

class FeedForward(nn.Module):def __init__(self, d_model, d_ff):super(FeedForward, self).__init__()self.linear1 = nn.Linear(d_model, d_ff)self.relu = nn.ReLU()self.linear2 = nn.Linear(d_ff, d_model)def forward(self, x):return self.linear2(self.relu(self.linear1(x)))

步骤 5:编码器(Encoder)和解码器(Decoder)

Transformer的编码器由多个层堆叠而成,每一层都包含一个多头自注意力机制和前馈神经网络。解码器除了这些模块外,还包含一个额外的注意力层,用于接收编码器的输出。

步骤 6:输出层(Output Layer)

最终的输出通常会通过一个线性层映射到所需的输出维度,比如词汇表大小(用于机器翻译)或分类任务中的类别数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/425414.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

物品识别——基于python语言

目录 1.物品识别 2.模型介绍 3.文件框架 4.代码示例 4.1 camera.py 4.2 interaction.py 4.3 object_detection.py 4.4 main.py 4.5 运行结果 5.总结 1.物品识别 该项目使用Python,OpenCV进行图像捕捉,进行物品识别。我们将使用YOLO&#xff08…

re题(23)BUUFCTF-[FlareOn4]login

BUUCTF在线评测 (buuoj.cn) 下载后打开看到是一个txt和一个html 分别打开看看,txt是提示,html应该就是要破解的网页 打开网页,查看源代码 找到程序,变灰的部分是关键,是指如果是前13个字母就加13,如果是…

小程序开发设计-第一个小程序:注册小程序开发账号②

上一篇文章导航: 小程序开发设计-小程序简介①-CSDN博客https://blog.csdn.net/qq_60872637/article/details/142217803?sharetypeblogdetail&sharerId142217803&sharereferPC&sharesourceqq_60872637&spm1011.2480.3001.8118 须知:不…

C++设计模式——Prototype Pattern原型模式

一,原型模式的定义 原型模式是一种创建型设计模式,它允许通过克隆已有对象来创建新对象,从而无需调用显式的实例化过程。 原型模式的设计,使得它可以创建一个与原型对象相同或类似的新对象,同时又可以减少对象实例化…

Rust Windows下编译 静态链接VCRuntime140.dll

Rust 编译出来的exe默认动态链接VC运行库,分发电脑上需要安装有Microsoft Visual C Redistributable for Visual Studio 2015运行库。 编译时能静态链接进去,就省去客户端未安装运行库的问题。方法如下: 只需在当前根目录下新建.cargo\config.toml&#…

【可视化大屏系列】数据列表自动滚动效果

要实现列表的自动滚动效果,这里提供两种解决方案: 1.vue插件 官方文档:链接: vue-seamless-scroll (1)安装依赖 npm install vue-seamless-scroll --save(2)全局注册(main.js中&a…

【CTF Web】BUUCTF BUU UPLOAD COURSE 1 Writeup(文件上传+PHP+文件包含漏洞)

BUU UPLOAD COURSE 1 1 上课用~ 点击启动靶机。 解法 疑似存在文件包含漏洞。 http://15a5666e-1796-4f76-b892-0b69cf97df8e.node5.buuoj.cn:81/index.php?fileupload.php查看网页源代码。判断是后端检查。 <!DOCTYPE html> <html lang"zh-cn"> &…

多目标优化算法求解LSMOP(Large-Scale Multi-Objective Optimization Problem)测试集,MATLAB代码

LSMOP&#xff08;Large-Scale Multi-Objective Optimization Problem&#xff09;测试集是用于评估大规模多目标优化算法性能的一组标准测试问题。这些测试问题通常具有大量的决策变量和目标函数&#xff0c;旨在模拟现实世界中的复杂优化问题。 LSMOP测试集包含多个子问题&am…

element-plus的面包屑组件el-breadcrumb

面包屑组件主要用来显示当页面路径&#xff0c;以及快速返回之前的页面。 涉及2个组件 el-breadcrumb 和el-breadcrumb-item, el-breadcrumb的spearator指定item的分隔符 el-breadcrumb-item的to和replace属性和vue-router的一致&#xff0c;需要结合vue_router一起使用 用法…

通过python提取PDF文件指定页的图片

整体思路 要从 PDF 文件中提取指定页和指定位置的图片&#xff0c;可以分几个步骤来实现&#xff1a; 1.1 准备所需工具与库 在 Python 中处理 PDF 和图像时&#xff0c;需要使用几个库&#xff1a; PyMuPDF (fitz)&#xff1a;用于读取和处理 PDF 文件&#xff0c;可以精确…

RabbitMQ高级篇,进阶内容

强烈建议在看本篇博客之前快速浏览文章&#xff1a;RabbitMQ基础有这一篇就够了 RabbitMQ高级篇 0. 前言1. 发送者的可靠性1.1 生产者重试机制1.2 生产者确认机制1.3 实现生产者确认 2. MQ的可靠性2.1 MQ持久化2.2 LazyQueue 3. 消费者的可靠性3.1 消费者确认机制3.2 失败重试策…

Web植物管理系统-下位机部分

本节主要展示上位机部分&#xff0c;采用BSP编程&#xff0c;不附带BSP中各个头文件的说明&#xff0c;仅仅是对main逻辑进行解释 main.c 上下位机通信 通过串口通信&#xff0c;有两位数据验证头&#xff08;verify数组中保存对应的数据头 0xAA55) 通信格式 上位发送11字节…

STM32外设之LTDC/DMA2D—液晶显示(野火)

文章目录 显示屏有几种?基本参数控制?显存 LTDC 液晶控制器LTDC 结构框图LTDC 初始化结构体 LTDC_InitTypeDefLTDC 层级初始化结构体 DMA2D 图形加速器DMA2D 初始化结构体 要了解什么 屏幕是什么&#xff0c;有几种屏&#xff0c;有什么组成。 怎么控制&#xff0c;不同屏幕控…

Linux:RPM软件包管理以及Yum软件包仓库

挂载光驱设备 RPM软件包管理 RPM软件包简介 区分软件名和软件包名 软件名&#xff1a;firefox 软件包名&#xff1a;firefox-52.7.0-1.el7.centos.x86_64.rpm 查询软件信息 查询软件&#xff08;参数为软件名&#xff09; ]# rpm -qa #当前系统中所有已安装的软件包 ]# r…

滑坡落石检测数据集

滑坡落石检测数据集 1500张 滑坡落石 带标注 voc yolo 项目背景&#xff1a; 滑坡落石是地质灾害中的一种常见现象&#xff0c;它对人类生活和基础设施构成了严重威胁。及时准确地检测滑坡落石对于预防灾害发生、减少损失至关重要。传统的检测方法往往依赖于人工巡查&#xff…

蓝桥杯—STM32G431RBT6按键的多方式使用(包含软件消抖方法精讲)从原理层面到实际应用(一)

新建工程教程见http://t.csdnimg.cn/JySLg 点亮LED教程见http://t.csdnimg.cn/Urlj5 末尾含所有代码 目录 按键原理图 一、按键使用需要解决的问题 1.抖动 1.什么是抖动 2.抖动类型 3.如何去消除抖动 FIRST.延时函数消抖&#xff08;缺点&#xff1a;浪费CPU资源&#xff…

transformer模型进行英译汉,汉译英

上面是在测试集上的表现 下面是在训练集上的表现 上面是在训练集上的评估效果 这是在测试集上的评估效果,模型是transformer模型,模型应该没问题,以上的是一个源序列没加结束符和加了结束符的情况。 transformer源序列做遮挡填充的自注意力,这就让编码器的输出中每个token的语…

第312题|二重积分求旋转体体积(二)|武忠祥老师每日一题

解题思路&#xff1a;先画出图像&#xff0c;再利用旋转体体积计算公式进行解题。 1. 旋转体体积计算公式&#xff1a; 2.点到直线计算公式&#xff1a; 有了上面两条知识储备之后我们开始计算。 第一步&#xff1a;先计算出点到直线的距离&#xff1a; ymx&#xff0c;y-mx…

web开发 之 HTML、CSS、JavaScript、以及JavaScript的高级框架Vue(学习版2)

一、前言 接下来就是来解决这些问题 二、 Ajax 1.ajax javscript是网页三剑客之一&#xff0c;空用来控制网页的行为的 xml是一种标记语言&#xff0c;是用来存储数据的 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-…

JVM字节码与局部变量表

文章目录 局部变量表javap字节码指令分类 指令指令数据类型前缀加载和存储指令加载常量算术指令其他指令 字节码示例说明 局部变量表 每个线程的帧栈是独立的&#xff0c;每个线程中的方法调用会产生栈帧&#xff0c;栈帧中保存着方法执行的信息&#xff0c;例如局部变量表。 …