自然语言处理---Self Attention自注意力机制

Self-attention介绍

Self-attention是一种特殊的attention,是应用在transformer中最重要的结构之一。attention机制,它能够帮助找到子序列和全局的attention的关系,也就是找到权重值wi。Self-attention相对于attention的变化,其实就是寻找权重值的wi过程不同。

  • 为了能够产生输出的向量yi,self-attention其实是对所有的输入做了一个加权平均的操作,这个公式和上面的attention是一致的。
  • j代表整个序列的长度,并且j个权重的相加之和等于1。值得一提的是,这里的 wij并不是一个需要神经网络学习的参数,它是来源于xi和xj的之间的计算的结果(这里wij的计算发生了变化)。它们之间最简单的一种计算方式,就是使用点积的方式。
  • xi和xj是一对输入和输出。对于下一个输出的向量yi+1,有一个全新的输入序列和一个不同的权重值。

  • 这个点积的输出的取值范围在负无穷和正无穷之间,所以要使用一个softmax把它映射到[0,1] 之间,并且要确保它们对于整个序列而言的和为1。
  • 以上这些就是self-attention最基本的操作。

Self-attention和Attention使用方法

根据他们之间的重要区别,可以区分在不同任务中的使用方法:

  • 在神经网络中,通常来说会有输入层(input),应用激活函数后的输出层(output),在RNN当中会有状态(state)。如果attention (AT) 被应用在某一层的话,它更多的是被应用在输出或者是状态层上,而当使用self-attention(SA),这种注意力的机制更多的实在关注input上。
  • Attention (AT) 经常被应用在从编码器(encoder)转换到解码器(decoder)。比如说,解码器的神经元会接受一些AT从编码层生成的输入信息。在这种情况下,AT连接的是**两个不同的组件**(component),编码器和解码器。但是如果用**SA**,它就不是关注的两个组件,它只是在关注应用的**那一个组件**。那这里就不会去关注解码器了,就比如说在Bert中,使用的情况,就没有解码器。
  • SA可以在一个模型当中被多次的、独立的使用(比如说在Transformer中,使用了18次;在Bert当中使用12次)。但是,AT在一个模型当中经常只是被使用一次,并且起到连接两个组件的作用。
  • SA比较擅长在一个序列当中,寻找不同部分之间的关系。比如说,在词法分析的过程中,能够帮助去理解不同词之间的关系。AT却更擅长寻找两个序列之间的关系,比如说在翻译任务当中,原始的文本和翻译后的文本。这里也要注意,在翻译任务重,SA也很擅长,比如说Transformer。
  • AT可以连接两种不同的模态,比如说图片和文字。SA更多的是被应用在同一种模态上,但是如果一定要使用SA来做的话,也可以将不同的模态组合成一个序列,再使用SA。
  • 其实有时候大部分情况,SA这种结构更加的general,在很多任务作为降维、特征表示、特征交叉等功能尝试着应用,很多时候效果都不错。

Self-attetion实现步骤

  • 这里实现的注意力机制是现在比较流行的点积相乘的注意力机制
  • self-attention机制的实现步骤
    • 第一步: 准备输入
    • 第二步: 初始化参数
    • 第三步: 获取key,query和value
    • 第四步: 给input1计算attention score
    • 第五步: 计算softmax
    • 第六步: 给value乘上score
    • 第七步: 给value加权求和获取output1
    • 第八步: 重复步骤4-7,获取output2,output3

1. 准备输入

# 这里随机设置三个输入, 每个输入的维度是一个4维向量
import torch
x = [[1, 0, 1, 0], # Input 1[0, 2, 0, 2], # Input 2[1, 1, 1, 1]  # Input 3
]
x = torch.tensor(x, dtype=torch.float32)

2. 初始化参数

# 每一个输入都有三个表示,分别为key(橙黄色),query(红色),value(紫色)。
# 每一个表示,希望是一个3维的向量。由于输入是4维,所以参数矩阵为 4*3 维。

# 为了能够获取这些表示,每一个输入(绿色)要和key,query和value相乘

# 在例子中,使用如下的方式初始化这些参数。
w_key = [[0, 0, 1],[1, 1, 0],[0, 1, 0],[1, 1, 0]
]
w_query = [[1, 0, 1],[1, 0, 0],[0, 0, 1],[0, 1, 1]
]
w_value = [[0, 2, 0],[0, 3, 0],[1, 0, 3],[1, 1, 0]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)print("w_key: \n", w_key)
print("w_query: \n", w_query)
print("w_value: \n", w_value)

3. 获取key,query和value

# 使用向量化获取keys的值
                    [0, 0, 1]
[1, 0, 1, 0]    [1, 1, 0]    [0, 1, 1]
[0, 2, 0, 2] x [0, 1, 0] = [4, 4, 0]
[1, 1, 1, 1]    [1, 1, 0]    [2, 3, 1]

# 使用向量化获取values的值
                    [0, 2, 0]
[1, 0, 1, 0]    [0, 3, 0]    [1, 2, 3] 
[0, 2, 0, 2] x [1, 0, 3] = [2, 8, 0]
[1, 1, 1, 1]    [1, 1, 0]    [2, 6, 3]

# 使用向量化获取querys的值
                    [1, 0, 1]
[1, 0, 1, 0]    [1, 0, 0]    [1, 0, 2]
[0, 2, 0, 2] x [0, 0, 1] = [2, 2, 2]
[1, 1, 1, 1]    [0, 1, 1]    [2, 1, 3]

# 将query key  value分别进行计算
keys = x @ w_key
querys = x @ w_query
values = x @ w_value
print("Keys: \n", keys)
print("Querys: \n", querys)
print("Values: \n", values)

4. 给input1计算attention score

# 获取input1的attention score,使用点乘来处理所有的key和query,包括自己的key和value。
# 这样就能够得到3个key的表示(因为有3个输入),就获得了3个attention score(蓝色)
                [0, 4, 2]
[1, 0, 2] x [1, 4, 3] = [2, 4, 4]
                [1, 0, 1]

# 注意: 这里只用input1举例。其他的输入的query和input1做相同的操作.

attn_scores = querys @ keys.T
print(attn_scores)

5. 计算softmax

from torch.nn.functional import softmaxattn_scores_softmax = softmax(attn_scores, dim=-1)
print(attn_scores_softmax)
attn_scores_softmax = [[0.0, 0.5, 0.5],[0.0, 1.0, 0.0],[0.0, 0.9, 0.1]
]
attn_scores_softmax = torch.tensor(attn_scores_softmax)
print(attn_scores_softmax)softmax([2, 4, 4]) = [0.0, 0.5, 0.5]

6. 给value乘上score

使用经过softmax后的attention score乘以它对应的value值(紫色),这样就得到了3个weighted values(黄色)

1: 0.0 * [1, 2, 3] = [0.0, 0.0, 0.0]
2: 0.5 * [2, 8, 0] = [1.0, 4.0, 0.0]
3: 0.5 * [2, 6, 3] = [1.0, 3.0, 1.5]

weighted_values = values[:,None] * attn_scores_softmax.T[:,:,None]
print(weighted_values)

7. 给value加权求和获取output1

把所有的weighted values(黄色)进行element-wise的相加。

   [0.0, 0.0, 0.0]

+ [1.0, 4.0, 0.0]

+ [1.0, 3.0, 1.5]

------------------------

= [2.0, 7.0, 1.5]

得到结果向量[2.0, 7.0, 1.5](深绿色)就是ouput1的和其他key交互的query representation

8. 重复步骤4-7,获取output2,output3

outputs = weighted_values.sum(dim=0)
print(outputs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/166449.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用vscode搭建虚拟机

首先vscode插件安装 名称: Remote - SSH ID: ms-vscode-remote.remote-ssh 说明: Open any folder on a remote machine using SSH and take advantage of VS Codes full feature set. 版本: 0.51.0 VS Marketplace 链接: https://marketplace.visualstudio.com/items?it…

黑豹程序员-架构师学习路线图-百科:MVC的演变终点SpringMVC

MVC发展史 在我们开发小型项目时,我们代码是混杂在一起的,术语称为紧耦合。 如最终写ASP、PHP。里面既包括服务器端代码,数据库操作的代码,又包括前端页面代码、HTML展现的代码、CSS美化的代码、JS交互的代码。可以看到早期编程就…

基于SSM的快递管理系统设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

proteus中仿真arduino的水位测试传感器

一、原理介绍 我们这里使用的水位传感器,只能说是一个小实验用途的水位传感器。我们首先上图 如上图所示,线没有连接,传感器由许5对裸露在外的铜线片作为传感部分,当浸入水中时这些铜线片会被水桥接。 这些被水连接起来的铜线&a…

【NPM】vuex 数据持久化库 vuex-persistedstate

在 GitHub 上找到:vuex-persistedstate。 安装 npm install --save vuex-persistedstate使用 import { createStore } from "vuex"; import createPersistedState from "vuex-persistedstate";const store createStore({// ...plugins: [cr…

android studio打开flutter项目报红

一、android studio打开flutter项目报红,如下图: 二、解决方法: 2.1 在这个build.gradle添加以下代码,如图: 2.2 在build.gradle最顶部添加如下代码: def localProperties new Properties() def localPr…

基于指数分布优化的BP神经网络(分类应用) - 附代码

基于指数分布优化的BP神经网络(分类应用) - 附代码 文章目录 基于指数分布优化的BP神经网络(分类应用) - 附代码1.鸢尾花iris数据介绍2.数据集整理3.指数分布优化BP神经网络3.1 BP神经网络参数设置3.2 指数分布算法应用 4.测试结果…

STM32不使用 cubeMX实现外部中断

这篇文章将介绍如何不使用 cubeMX完成外部中断的配置和实现。 文章目录 前言一、文件加入工程二、代码解析exti.cexti.hmain.c 注意:总结 前言 实验开发板:STM32F103C8T6。所需软件:keil5 , cubeMX 。实验目的:如何不…

AdaBoost:增强机器学习的力量

一、介绍 机器学习已成为现代技术的基石,为从推荐系统到自动驾驶汽车的一切提供动力。在众多机器学习算法中,AdaBoost(自适应增强的缩写)作为一种强大的集成方法脱颖而出,为该领域的成功做出了重大贡献。AdaBoost 是一…

面试题:谈谈过滤器和拦截器的区别?

文章目录 一、拦截器和过滤器的区别二、拦截器和过滤器的代码实现1、拦截器2、过滤器 三、总结1、什么是Filter及其作用介绍2、Filter API介绍3、Filter链与Filter生命周期 四、拦截器五、过滤器和拦截器的区别 一、拦截器和过滤器的区别 1、拦截器(Interceptor)只对action请求…

【鸿蒙软件开发】ArkTS常见组件之单选框Radio和切换按钮Toggle

文章目录 前言一、Radio单选框1.1 创建单选框1.2 添加Radio事件1.3 场景示例二、切换按钮Toggle2.1 创建切换按钮2.2 创建有子组件的Toggle2.3 自定义样式selectedColor属性switchPointColor属性 2.4 添加事件2.5 示例代码 总结 前言 Radio是单选框组件,通常用于提…

SpringCloud之gateway基本使用解读

目录 基本介绍 概述 API网关介绍 路由(Route) 断言(Predicate) 过滤器(Filter) 简单JAVA代码实战 实战架构 teacherservice服务 gateway服务 测试 断言工厂 过滤器工厂 全局过滤器 &#xf…

前端导出数据到Excel(Excel.js导出数据)

库&#xff1a;Excel.js&#xff08;版本4.3.0&#xff09; 和 FileSaver&#xff08;版本2.0.5&#xff09; CDN地址&#xff1a; <script src"https://cdn.bootcdn.net/ajax/libs/exceljs/4.3.0/exceljs.min.js"></script> <script src"http…

什么是SpringMVC?简单好理解!

1、SpringMVC是什么&#xff1f; SpringMVC是一个基于Java的实现了MVC设计模式的请求驱动类型的轻量级web框架&#xff0c;通过把Model&#xff0c;View&#xff0c;Controller分离&#xff0c;将web层进行职责解耦&#xff0c;把复杂的web应用分成逻辑清晰的几部分。简化开发&…

UE5场景逐渐变亮问题

1、显示 -- 关闭眼部适应 2、项目设置 -- 关闭自动曝光 参考&#xff1a; 虚幻5/UE5 场景亮度逐渐变亮完美解决方法 - 哔哩哔哩

专业修图软件 Affinity Photo 2 mac中文版编辑功能

Affinity Photo for Mac是应用在MacOS上的专业修图软件&#xff0c;支持多种文件格式&#xff0c;包括psD、PDF、SVG、Eps、TIFF、JPEG等。 Affinity Photo mac提供了许多高级图像编辑功能&#xff0c;如无限制的图层、非破坏性操作、高级的选择工具、高级的调整层、HDR合成、全…

使用 PyAudio、语音识别、pyttsx3 和 SerpApi 构建简单的基于 CLI 的语音助手

德米特里祖布☀️ 一、介绍 正如您从标题中看到的&#xff0c;这是一个演示项目&#xff0c;显示了一个非常基本的语音助手脚本&#xff0c;可以根据 Google 搜索结果在终端中回答您的问题。 您可以在 GitHub 存储库中找到完整代码&#xff1a;dimitryzub/serpapi-demo-project…

RT-Thread学习笔记(四):RT-Thread Studio工具使用

RT-Thread Studio工具使用 官网详细资料实用操作1. 查看 RT-Thread RTOS API 文档2.打开已创建的工程3.添加头文件路径4. 如何设置生成hex文件5.新建工程 官网详细资料 RT-Thread Studio 用户手册 实用操作 1. 查看 RT-Thread RTOS API 文档 2.打开已创建的工程 如果打开项目…

Breach 1.0 靶机

Breach 1.0 环境配置 设置 VMware 上的仅主机模式网卡&#xff0c;勾选 DHCP 自动分配 IP&#xff0c;将子网改为 192.168.110.0/24 将靶机和 kali 连接到仅主机网卡 信息搜集 存活检测 详细扫描 后台网页扫描 网页信息搜集 Initech被入侵&#xff0c;董事会投票决定引入…

pytorch nn.Embedding 读取gensim训练好的词/字向量(有例子)

最近在跑深度学习模型&#xff0c;发现Embedding随机性太强导致模型结果有出入&#xff0c;因此考虑固定初始随机向量&#xff0c;既提前训练好词/字向量&#xff0c;不多说上代码&#xff01;&#xff01; 1、利用gensim训练字向量&#xff08;词向量自行修改&#xff09; #…