【自然语言处理】Encoder-Decoder模型中Attention机制的引入

Encoder-Decoder 模型中引入 Attention 机制,是为了改善基本Seq2Seq模型的性能,特别是当处理长序列时,传统的Encoder-Decoder模型容易面临信息压缩的困难。Attention机制可以帮助模型动态地选择源序列中相关的信息,从而提高翻译等任务的质量。

一、为什么需要Attention机制?

在基本的 Encoder-Decoder 模型中,Encoder将整个源句子的所有信息压缩成一个固定大小的向量(上下文向量),然后Decoder使用这个向量来生成目标序列。这个单一的上下文向量对于较短的句子可能足够,但对于较长的句子,模型可能无法有效捕捉到整个句子中所有重要的信息。这样容易导致信息丢失,尤其是当句子很长时,Decoder在生成目标词时可能无法获取到源句子的细节信息。

二、Attention机制的核心思想

Attention机制的核心思想是:在每个时间步生成目标单词时,Decoder不再依赖于固定的上下文向量,而是能够通过“注意力”权重,动态地从源句子的所有隐状态中选择最相关的部分。这样,Decoder每生成一个目标词时,能够更好地“关注”源句子中与当前生成词最相关的部分。

三、Attention机制的工作流程

在每一步解码时,Attention机制会根据Decoder的当前状态计算出一组权重,表示源句子中各个位置的隐状态对当前解码步骤的重要性。这些权重用于加权源句子的隐状态,以得到一个上下文向量,这个上下文向量会与当前Decoder的隐状态一起用于生成下一个目标词。由于它跨越两个序列:源语言序列(编码器输出)作为 Key 和 Value;目标语言序列(解码器的当前状态)作为 Query,因此也叫交叉注意力

Attention的具体步骤如下:

  1. 计算注意力权重

    • 对于Decoder的每一步(生成每个目标词时),通过Decoder的当前隐状态和源句子每个时间步的隐状态来计算注意力权重。
    • 这些权重表示源句子中每个位置的重要性,可以使用加性Attention点积Attention来计算。
  2. 计算上下文向量

    • 通过将注意力权重与源句子的隐状态进行加权平均,得到一个新的上下文向量。
    • 这个上下文向量包含了源句子中当前对Decoder最重要的信息。
  3. 解码下一步

    • 将新的上下文向量与当前Decoder的隐状态结合,用于生成当前的目标词。

四、Attention机制的公式

对于每个时间步 t:

  1. 计算注意力得分:通常使用Decoder当前的隐状态 ht 和源句子每个位置的隐状态 hs 计算注意力得分,可以通过以下公式计算:

在这里插入图片描述

常见的 score 函数有加性(Bahdanau Attention)和点积(Luong Attention):

  • 加性Attention:使用一个简单的前馈网络对 ht 和 hs 进行线性变换并加和。
  • 点积Attention:直接计算 ht 和 hs 的点积。
  1. 计算注意力权重:对得分 et,s​ 进行Softmax操作,得到权重:

在这里插入图片描述

这些权重 αt,s 表示源句子中各个位置对当前解码的影响力。

  1. 计算上下文向量:使用注意力权重对源句子的隐状态进行加权平均,得到上下文向量 ct:

在这里插入图片描述

  1. 生成下一个词:将上下文向量 ct 与Decoder的隐状态 ht 结合,生成下一个词。

五、引入Attention机制的Encoder-Decoder代码实现

以下是一个带有 Attention 机制的 Encoder-Decoder 模型的简化实现,使用 PyTorch 进行构建。

import torch
import torch.nn as nn# Encoder模型
class Encoder(nn.Module):def __init__(self, input_size, embedding_dim, hidden_size):super(Encoder, self).__init__()self.embedding = nn.Embedding(input_size, embedding_dim)self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)def forward(self, src):embedded = self.embedding(src)  # [batch_size, src_len, embedding_dim]outputs, (hidden, cell) = self.lstm(embedded)  # [batch_size, src_len, hidden_size]return outputs, hidden, cell# Attention模型
class Attention(nn.Module):def __init__(self, hidden_size):super(Attention, self).__init__()self.attn = nn.Linear(hidden_size * 2, hidden_size)self.v = nn.Parameter(torch.rand(hidden_size))def forward(self, hidden, encoder_outputs):src_len = encoder_outputs.shape[1]hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))  # [batch_size, src_len, hidden_size]energy = torch.sum(self.v * energy, dim=2)  # [batch_size, src_len]return torch.softmax(energy, dim=1)  # [batch_size, src_len]# Decoder模型
class Decoder(nn.Module):def __init__(self, output_size, embedding_dim, hidden_size):super(Decoder, self).__init__()self.embedding = nn.Embedding(output_size, embedding_dim)self.lstm = nn.LSTM(embedding_dim + hidden_size, hidden_size, batch_first=True)self.fc_out = nn.Linear(hidden_size * 2, output_size)self.attention = Attention(hidden_size)def forward(self, input_token, hidden, cell, encoder_outputs):input_token = input_token.unsqueeze(1)  # [batch_size, 1]embedded = self.embedding(input_token)  # [batch_size, 1, embedding_dim]# 计算注意力权重attn_weights = self.attention(hidden[-1], encoder_outputs)  # [batch_size, src_len]# 使用注意力权重对encoder输出进行加权平均attn_applied = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)  # [batch_size, 1, hidden_size]# 将注意力上下文向量和嵌入层输入拼接lstm_input = torch.cat((embedded, attn_applied), dim=2)  # [batch_size, 1, embedding_dim + hidden_size]# 通过LSTMoutput, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))  # [batch_size, 1, hidden_size]# 生成最终输出output = torch.cat((output.squeeze(1), attn_applied.squeeze(1)), dim=1)  # [batch_size, hidden_size * 2]prediction = self.fc_out(output)  # [batch_size, output_size]return prediction, hidden, cell# Seq2Seq模型
class Seq2Seq(nn.Module):def __init__(self, encoder, decoder, device):super(Seq2Seq, self).__init__()self.encoder = encoderself.decoder = decoderself.device = devicedef forward(self, src, tgt, teacher_forcing_ratio=0.5):batch_size = tgt.shape[0]target_len = tgt.shape[1]target_vocab_size = self.decoder.fc_out.out_featuresoutputs = torch.zeros(batch_size, target_len, target_vocab_size).to(self.device)encoder_outputs, hidden, cell = self.encoder(src)input_token = tgt[:, 0]for t in range(1, target_len):output, hidden, cell = self.decoder(input_token, hidden, cell, encoder_outputs)outputs[:, t, :] = outputtop1 = output.argmax(1)input_token = tgt[:, t] if torch.rand(1).item() < teacher_forcing_ratio else top1return outputs
代码说明:
  1. Encoder

    • 编码源句子,生成隐状态和输出序列。
    • 输出序列会在注意力机制中使用。
  2. Attention

    • Attention 模型根据当前隐状态和Encoder输出计算注意力权重。
  3. Decoder

    • 使用Attention得到的注意力权重对Encoder输出进行加权平均,得到上下文向量。
    • Decoder在当前时间步会将 当前输入(上一个时间步生成的词)、上一个时间步的隐状态 和 注意力上下文向量 拼接起来,输入到LSTM或GRU中,更新隐状态并生成当前时间步的输出。
  4. Seq2Seq

    • 将Encoder和Decoder结合,逐步生成目标序列。
    • 使用了教师强制机制来控制训练时的输入。
Decoder代码详细解释:
  1. attn_weights = self.attention(hidden[-1], encoder_outputs):

    • hidden[-1] 是Decoder当前时间步的最后一层隐状态(对于多层LSTM来说)。encoder_outputs 是Encoder所有时间步的输出。
    • 调用 self.attention 计算当前时间步的注意力权重。
  2. attn_applied = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs):

    • attn_weights 是注意力权重,形状为 [batch_size, src_len]
    • unsqueeze(1) 将其变为 [batch_size, 1, src_len],然后与 encoder_outputs(形状为 [batch_size, src_len, hidden_size])进行批量矩阵乘法(torch.bmm)。
    • 这样得到的结果 attn_applied 是加权后的上下文向量,形状为 [batch_size, 1, hidden_size],表示根据注意力权重加权后的源句子信息。
  3. torch.cat((embedded, attn_applied), dim=2):

    • 将Decoder的当前输入(嵌入表示)和上下文向量拼接在一起,输入到LSTM中。

六、总结:

Attention机制的引入,允许Decoder在生成每个目标词时,能够动态地根据源句子的不同部分调整注意力,使得模型能够处理更长的序列,并提高生成结果的准确性。Attention机制在机器翻译等任务中取得了显著的效果,并且为之后的Transformer等模型的出现奠定了基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/448524.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

了解AI绘画扩散原理-更好掌握AI绘画工具

AI绘画正在成为一种热门的创作工具&#xff0c;壁纸、模特、真人转二次元、艺术字、二维码、设计图、老照片修复、高清修复等&#xff0c;越来越多的使用场景&#xff0c;AI绘画让没有美术基础的人也能够借助工具获得自己想要的美术图片。 AI绘画的核心是“生成模型”&#xf…

插件分享|沉浸式翻译

在这个全球化的时代&#xff0c;语言不再是交流的障碍。但你是否曾经因为一篇外文网页、一份PDF文档或是一段视频字幕而苦恼不已&#xff1f;现在&#xff0c;一款名为“沉浸式翻译”的网页翻译插件&#xff0c;将彻底改变你的翻译体验&#xff01;&#xff08;文末附安装地址&…

开源医疗大模型Llama3-Aloe-8B-Alpha,性能超越 MedAlpaca 和 PMC-LLaMA

前言 近年来&#xff0c;大型语言模型 (LLM) 在医疗领域展现出巨大潜力&#xff0c;能够帮助医生和研究人员更快地获取信息、分析数据&#xff0c;并提高医疗服务效率。然而&#xff0c;目前市场上大多数医疗 LLM 都是闭源模型&#xff0c;限制了其在学术研究和应用领域的推广…

基于Arduino的仿生面具

DIY 万圣节恐怖惊喜&#xff1a;自制动态眼动和声音感应的仿生面具 引言 万圣节即将来临&#xff0c;你是否准备好制作一些既诡异又迷人的装饰来增添节日气氛呢&#xff1f;今天&#xff0c;我们将一起探索如何使用3D打印、伺服电机、PIR传感器和DFPlayer MP3模块来制作一个动…

【黑马redis高级篇】持久化

//来源[01,05]分布式缓存 除了黑马&#xff0c;还参考了别的。 目录 1.单点redis问题及解决方案2.为什么需要持久化&#xff1f;3.Redis持久化有哪些方式呢&#xff1f;为什么我们需要重点学RDB和AOF&#xff1f;4.RDB4.1 定义4.2 触发方式4.2.1手动触发save4.2.2被动触发bgsa…

STM32 ADC学习日记

STM32 ADC学习日记 1. ADC简介 ADC 即模拟数字转换器&#xff0c;英文详称 Analog-to-digital converter&#xff0c;可以将外部的模拟信号转换为数字信号。 STM32F103 系列芯片拥有 3 个 ADC&#xff08;C8T6 只有 2 个&#xff09;&#xff0c;这些 ADC 可以独立使用&…

《中国林业产业》是什么级别的期刊?是正规期刊吗?能评职称吗?

​问题解答 问&#xff1a;《中国林业产业》是不是核心期刊&#xff1f; 答&#xff1a;不是&#xff0c;是知网收录的正规学术期刊。 问&#xff1a;《中国林业产业》级别&#xff1f; 答&#xff1a;国家级。主管单位&#xff1a;国家林业和草原局 …

【Linux】system V进程间通信--共享内存,消息队列,信号量

目录 共享内存 基本原理 创建共享内存 共享内存创建好后&#xff0c;我们可以查询共享内存&#xff0c;验证一下是否创建成功&#xff1b; 删除共享内存 共享内存的挂接 实现通信 消息队列&#xff08;了解&#xff09; 消息队列概念 消息队列接口 操作指令 信号量…

从MySQL到OceanBase离线数据迁移的实践

本文作者&#xff1a;玉璁&#xff0c;OceanBase 生态产品技术专家。工作十余年&#xff0c;一直在基础架构与中间件领域从事研发工作。现负责OceanBase离线导数产品工具的研发工作&#xff0c;致力于为 OceanBase 建设一套完善的生态工具体系。 背景介绍 在互联网与云数据库技…

番外篇 | 史上最全的关于CV的一些经典注意力机制代码汇总

前言:Hello大家好,我是小哥谈。注意力是人类认知系统的核心部分,它允许我们在各种感官输入中筛选和专注于特定信息。这一能力帮助我们处理海量的信息,关注重要的事物,而不会被次要的事物淹没。受到人类认知系统的启发,计算机科学家开发了注意力机制,这种机制模仿人类的这…

鸿蒙跨设备协同开发04——跨设备剪切板开发

如果你也对鸿蒙开发感兴趣&#xff0c;加入“Harmony自习室”吧&#xff01;扫描下方名片&#xff0c;关注公众号&#xff0c;公众号更新更快&#xff0c;同时也有更多学习资料和技术讨论群。 1、概述 当用户拥有多台设备时&#xff0c;可以通过跨设备剪贴板的功能&#xff0c…

2. MySQL数据库基础

一、数据库的操作 1. 显示当前的数据库 SHOW DATABASES;2. 创建数据库 语法&#xff1a; CREATE DATABASE [IF NOT EXISTS] db_name [create_specification...];//create_specification包括&#xff1a;[DEFAULT] CHARACTER SET charset_name[DEFAULT] COLLATE collation_n…

【题解】【记忆化递归】——Function

【题解】【记忆化递归】——Function Function题目描述输入格式输出格式输入输出样例输入 #1输出 #1 提示数据规模与约定 1.思路解析2.AC代码 Function 通往洛谷的传送门 题目描述 对于一个递归函数 w ( a , b , c ) w(a,b,c) w(a,b,c) 如果 a ≤ 0 a \le 0 a≤0 或 b ≤…

2025年广西高考报名流程图解(手机端)

广西 2025 年高考报名时间已经确定啦&#xff0c;从 2024 年 10 月 21 日开始&#xff0c;到 10 月 31 日 17:30 结束 &#x1f4bb;【报名路径】 有电脑端和手机端两种选择哦。 电脑端&#xff1a;登录 “广西招生考试院” 网站&#xff08;https://www.gxeea.cn&#xff0…

SQL数据库刷题sql_day34(移动平均值、累计求和)

描述 移动平均值 1.求不同产品 每个月以及截至当前月最近3个月的平均销售额 2.求不同产品截至当前月份的累计销售额 数据准备 mysql CREATE TABLE sales_monthly (product VARCHAR(20),ym VARCHAR(10),amount DECIMAL(10,2) );-- 插入测试数据 INSERT INTO sales_monthly (prod…

厨房老鼠数据集:掀起餐饮卫生监测的科技浪潮

厨房老鼠数据集&#xff1a;掀起餐饮卫生监测的科技浪潮 摘要&#xff1a;本文深入探讨了厨房老鼠数据集在餐饮行业卫生管理中的重要性及其相关技术应用。厨房老鼠数据集通过收集夜间厨房图像、老鼠标注信息以及环境数据&#xff0c;为深度学习模型提供了丰富的训练样本。基于…

目前我国网络安全人才市场状况

网络安全人才市场状况 本章以智联招聘多年来形成的丰富的招聘、求职信息大数据为基础&#xff0c;结合了奇安信集团 在网络安全领域多年来的专业研究经验&#xff0c;相关研究成果具有很强的代表性。对涉及安全人才 的全平台招聘需求与求职简历进行分析&#xff08;注&#xf…

Ajax(web笔记)

文章目录 1.Ajax的概念2.Ajax 的作用3.原生Ajax4.Axios4.1Axios的概念4.2Axios入门 1.Ajax的概念 AsynchronousJavaScriptAndXML&#xff0c;异步的JavaScript和XML 2.Ajax 的作用 数据交换:过Ajax可以给服务器发送请求&#xff0c;并获取服务器响应的数据。异步交互:可以在…

小猿口算辅助工具(nodejs版)

github 地址&#xff1a;https://github.com/pbstar/xyks-helper 实现原理 通过屏幕截图截取到题目区域的两个数字&#xff0c;然后通过 ocr 识别出数字&#xff0c;最后通过计算得出答案&#xff0c;并通过模拟鼠标绘制答案。 依赖插件 node-screenshots&#xff1a;屏幕截…

ai搜索工具免费的有那些?这几年搜索都发生了哪些变化?

前言这几年大家的搜索都发生了哪些变化&#xff1f; 要说疯狂的就属于AI工具了&#xff0c;以前搜索内容有广告自己只能眼巴巴的看着&#xff0c;现在不少人的搜索行为都有所变化&#xff0c;经过自己测试也给大家推荐一些自己使用的AI搜索工具毕竟免费。AI对传统搜索影响在传…