MultiHeadAttention在Tensorflow中的实现原理


前言

通过这篇文章,你可以学习到Tensorflow实现MultiHeadAttention的底层原理。


一、MultiHeadAttention的本质内涵

1.Self_Atention机制

MultiHeadAttention是Self_Atention的多头堆嵌,有必要对Self_Atention机制进行一次深入浅出的理解,这也是MultiHeadAttention的核心所在。

Self_Attention并不直接使用输入向量,而是先将其进行映射,使得输入向量在每个位置上产生一个query和context,context充当字典。在context的每个位置都提供一个key和value向量。

query:尝试去获取某类信息的序列。

context:包含key序列和value序列,是query感兴趣的内容。

最终输出的形状将与query序列相同。

一个常见的类比是,这种操作就像字典查询。一个模糊的、可区分的、矢量的字典查询。

如下是一个普通的 python 字典类型数据,有 3 个键和 3 个值,并被传递给一个query——"What color is it ?"。这个query会与key="color"最契合,最终得到查询结果value="blue"

query是你要尝试去找的东西。key表示字典里有哪些信息,而value就是这些信息。当你在正则字典中查找一个query时,字典会找到匹配的key,并返回其相关的value。这个查询要么有一个匹配的键,要么没有。你可以想象一个模糊的字典,其中的键不一定要完全匹配。如果你在上面的字典中查找 query—"What species is it ?",也许你希望它返回 key="type",value="pickup",因为那是与query最匹配的key和value。

注意力层就像这样做了一个模糊查找,但它不仅仅是在寻找最好的key,而是根据query与每个key的匹配程度来组合这些value。

那是如何做到这一点的呢?在注意力层中,query、key和value都是向量。注意力层不是简单地做哈希查找,而是结合query和key向量来确定它们的匹配程度——计算query和key的向量点积,再将所有匹配程度经过Softmax映射完后,即得到 "注意力得分"。最终该层返回所有value的加权平均值,以 "注意力分数 "为权重。

对于一段具体的文本来说,每一个字都会引发一个疑问query,并提供一个关键值key和一个目标内容value。每个query都会去寻找感兴趣的key,并按注意力分数提取并组合value,

图中越粗的红线对应的attention权重更大,query与key的紧密程度也越近。attention权重如此分布也是很符合情理的,要想回答query =“他是谁?”,我们很大可能会在“是”后面寻找答案,因为“爱人”提供的信息最多,所以它俩的attention权重最大。

总的来说,Self_Attention模拟的是一个符合人脑思维逻辑的研究过程。每当遇到一些新的信息,我们总会产生一定量的疑问(query),为了解决疑问,我们需要在信息中捕捉关键字(key),进而凝练出该关键字中所蕴涵的答案(value)。特定的疑问(query)需要联系特定的关键字(key),进而得出最终答案,这个最终答案往往是折合了不同value而得来的。

2.MultiHead_Atention机制

在不同情景中,字引发的query是不同的,例如,

“他是男的,已婚。”

query可以是”他的性别是什么?”,或者”他结婚了吗?”。不同的query会产生不同的注意力分数。单一的Self_attention无法捕捉多层面query和key之间的依赖关系,因为它只进行一次attention的分配。意在解决此类局限性,MultiHead_Atention会计算多次Self_attention。

利用MultiHead_Atention机制,可以为每一个输入学习到一个信息量丰富的向量表示。

二、使MultiHeadAttention在TensorFlow中的代码实现

1.参数说明

TensorFlow中是用tf.keras.layers.MultiHeadAttention()实现的。它的参数分为两类,一种参数为初始化参数,存在于__init__方法中;另一种为调用参数,存在于call方法中。

主要的初始化参数:

num_heads:Self_Attention的层数

key_dim:query和key多头映射层的输出shape在axis=-1上的长度。因为后续需要计算query和key的点积,所以需要保证query和key在最后一个轴上的长度相等。

value_dim:value多头映射层的输出shape在axis=-1上的长度。如果不指定,则默认等于key_dim

output_shape:  指定输入经过整个MultiHeadAttention层后的输出shape,默认与进入query多头映射层的输入shape相同

主要的call方法参数:

'''  B即Batch_size,每一批中的样本数;

    T是query的个数,即一段序列产生的疑问个数;S是value和key的个数,即一段序列产生的关键字和关键信息的个数,序列产生的key和value是成对出现的,所以value映射层 和key映射层的输入张量在axis=1处的长度S相同。T和S是可以随意指定的,只需在样本集进入Embedding层之前,先通过一个dense层进行T和S的指定(T和S等于各自dense层中的神经元个数)。例如,文本集shape=(B, S),经过一个具有T个神经元的Dense层→shape=(B, T),再经Embedding层→shape=(B, T, dim),得到query映射层的输入张量。当然,如果不愿如此麻烦,可直接将经Embedding层得到shape=(B, S, dim)的张量作为query映射层的输入;

    dim通常是Embedding向量的长度(每个字对应一个Embedding向量)'''

query:输入query多头映射层且shape为(B, T, dim)的张量

value:输入value多头映射层且shape为(B, S, dim)的张量

key:输入key多头映射层且shape为(B, S, dim)的张量,如果未指定,则key=value

use_causal_mask:布尔值,是否开启causal_mask(因果掩码)机制

2.整体结构

tf.keras.layers.MultiHeadAttention类中call()方法的逻辑过程就是MultiHeadAttention的前向传播过程,我将其提炼成以下三部分,

        ''' 多头映射层 '''query = self._query_dense(query)key = self._key_dense(key)value = self._value_dense(value)''' 注意力层 '''attention_output, attention_scores = self._compute_attention(query, key, value, attention_mask, training)''' 输出层 '''attention_output = self._output_dense(attention_output)

3.多头映射层

由query多头映射层—query_dense,value多头映射层—value_dense,key多头映射层—key_dense组成。

每个映射层执行的张量运算是一样的,张量运算逻辑为,

                                                   ' abc , cde -> abde '               



该层的训练参数总数为,

4.注意力层

计算query与key之间的内积,张量运算逻辑为,

                                              ' aecd, abcd -> acbe '

内积能够反映向量之间的相关程度,内积结果越大则相关性越大,联系也越紧密。得到query和key的内积后,为了得到attention分数,需要将内积结果进行softmax映射。

sttention_scores张量可视作一个B行num_heads列的矩阵,其矩阵中的元素均是T行S列的注意力分数矩阵。当输入是大序列(比如音频序列)时,TransFormer需要维护的注意力分数矩阵将呈n^{2}曲线式增长,这种庞大的数据量将会对TransFormer训练和推理的效率和速度产生严重的影响,在内存上的要求也会成n^{2}曲线式增长。


最后利用attention分数对value进行加权叠加,张量运算逻辑为,

                                                        'acbe,aecd->abcd' 

  



注意:如果指定use_causal_mask=True引入Causal_Mask(因果掩码)机制,则在softmax映射时,会传入一个左下三角为True右上三角为False的,布尔类型的,且与attention_scores.shape相同掩码张量,此时掩码张量中为False的对应位置(对应attention内积张量)将会被softmax忽略。如此一来就会导致每个query只会与当前及其以前的key进行内积,并不会考虑未来的key。进而导致在每个query处产生的新value只会是当前value与过往value在sttention分数上的加权叠加。这样的结构是因果的,符合在预测中结果会对输入产生影响的因果逻辑。因果掩码会在Decoder中使用。


注意力层无可训练的参数。

5.输出映射层

属于MultiHeadAttention的最后一层,负责将注意力层得到的value在sttention分数上的加权叠加后的张量进行输出映射。张量运算逻辑为,

                                                       ' abcd, cde -> abe '



该层训练参数总共为,


验证

import tensorflow as tflayer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.keras.Input(shape=[9, 16])
source = tf.keras.Input(shape=[4, 16])
output_tensor, weights = layer(query=target, value=source,return_attention_scores=True)''' 手动计算训练参数总数 '''
sum = 16*2*2*3+2*2*3+2*2*16+16
print(f'手动计算的训练参数总数为 : {sum}')
print(f'训练参数总共为 : {layer.count_params()}')
print(f'输出shape为 : {output_tensor.shape}')
print(f'注意力分数shape为 : {weights.shape}')手动计算的训练参数总数为 : 284
训练参数总共为 : 284
输出shape为 : (None, 9, 16)
注意力分数shape为 : (None, 2, 9, 4)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/315086.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AJAX——案例

1.商品分类 需求&#xff1a;尽可能同时展示所有商品分类到页面上 步骤&#xff1a; 获取所有的一级分类数据遍历id&#xff0c;创建获取二级分类请求合并所有二级分类Promise对象等待同时成功后&#xff0c;渲染页面 index.html代码 <!DOCTYPE html> <html lang&qu…

ssh 文件传输:你应该掌握的几种命令行工具

这篇文章主要分享一下我使用过的 ssh 传输文件的进阶路程&#xff0c;从 scp -> lrzsz -> trzsz&#xff0c;希望能给你带来一些帮助&#xff5e; scp scp 命令可以用于在 linux 系统之间复制文件&#xff0c;具体的语法可以参考下图 其实使用起来也还比较方便&#x…

【Docker】Docker 实践(三):使用 Dockerfile 文件构建镜像

Docker 实践&#xff08;三&#xff09;&#xff1a;使用 Dockerfile 文件构建镜像 1.使用 Dockerfile 文件构建镜像2.Dockerfile 文件详解 1.使用 Dockerfile 文件构建镜像 Dockerfile 是一个文本文件&#xff0c;其中包含了一条条的指令&#xff0c;每一条指令都用于构建镜像…

智慧码头港口:施工作业安全生产AI视频监管与风险预警平台方案

一、建设思路 随着全球贸易的快速发展&#xff0c;港口作为连接海洋与内陆的关键节点&#xff0c;其运营效率和安全性越来越受到人们的关注。为了提升港口的运营效率和安全性&#xff0c;智慧港口视频智能监控系统的建设显得尤为重要。 1&#xff09;系统架构设计 系统应该采…

针对icon报错

针对上篇文章生成图标链接中图标报错 C# winfrom应用程序添加图标-CSDN博客 问题&#xff1a;参数“picture”必须是可用作Icon的参数 原因&#xff1a;生成的ico图标类型不匹配 解决方法&#xff1a; 更改导出的ico类型

下载学浪视频,小浪助手一键搞定

小浪助手可以一键获取课程&#xff0c;一键根据课程获取视频列表&#xff0c;而且内置了2大下载器&#xff0c;N_m3u8和逍遥一仙下载器 小浪助手我已经打包好了&#xff0c;有需要的自己取一下 学浪下载工具链接&#xff1a;https://pan.baidu.com/s/1_Sg-EGGXKc4bMW-NPqUqvg…

第55篇:创建Nios II工程之Hello_World<一>

Q&#xff1a;本期我们开始介绍创建Platform Designer系统&#xff0c;并设计基于Nios II Processor的Hello_world工程。 A&#xff1a;设计流程和实验原理&#xff1a;需要用到的IP组件有Clock Source、Nios II Processor、On-Chip Memory、JTAG UART和System ID外设。Nios I…

Maven多模块快速升级超好用Idea插件-MPVP

功能&#xff1a;多模块maven项目快速升级指定版本插件&#xff0c;并提供预览和相关升级模块日志能力。 可快速进行版本升级&#xff0c;进行部署到Maven仓库。 安装&#xff1a; 可在idea插件中心进行安装 / 下载资源拖动安装 MPVP(Maven) - IntelliJ IDEs Plugin | Marke…

构建安全高效的前端权限控制系统

✨✨谢谢大家捧场&#xff0c;祝屏幕前的小伙伴们每天都有好运相伴左右&#xff0c;一定要天天开心哦&#xff01;✨✨ &#x1f388;&#x1f388;作者主页&#xff1a; 喔的嘛呀&#x1f388;&#x1f388; ✨✨ 帅哥美女们&#xff0c;我们共同加油&#xff01;一起进步&am…

数据库变更时,OceanBase如何自动生成回滚 SQL

背景 在开发中&#xff0c;数据的变更与维护工作一般较频繁。当我们执行数据库的DML操作时&#xff0c;必须谨慎考虑变更对数据可能产生的后果&#xff0c;以及变更是否能够顺利执行。若出现意外数据丢失、操作失误或语法错误等情况&#xff0c;我们必须迅速将数据库恢复到变更…

Bayes判别示例数据:鸢尾花数据集

使用Bayes判别的R语言实例通常涉及使用朴素贝叶斯分类器。朴素贝叶斯分类器是一种简单的概率分类器&#xff0c;基于贝叶斯定理和特征之间的独立性假设。在R中&#xff0c;我们可以使用e1071包中的naiveBayes函数来实现这一算法。下面&#xff0c;我将通过一个简单的示例展示如…

npm、yarn与pnpm详解

&#x1f525; npm、yarn与pnpm详解 &#x1f516; 一、npm &#x1f50d; 简介&#xff1a; npm是随Node.js一起安装的官方包管理工具&#xff0c;它为开发者搭建了一个庞大的资源库&#xff0c;允许他们在这个平台上搜索、安装和管理项目所必需的各种代码库或模块。 &#…

Intelij Idea Push失败,出现git Authentication failed(验证失败)

目录 1、出现问题的原因 2、解决之法 1、出现问题的原因 能出现这种问题&#xff0c;最主要的原因是链接对上了&#xff0c;但用户验证失败了&#xff0c;即登录失败。 因为服务器转移或者换了git项目链接&#xff0c;导致你忘记了用户名密码&#xff0c;随意输入之后&…

求三个字符数组最大者(C语言)

一、N-S流程图&#xff1b; 二、运行结果&#xff1b; 三、源代码&#xff1b; # define _CRT_SECURE_NO_WARNINGS # include <stdio.h> # include <string.h>int main() {//初始化变量值&#xff1b;int i 0;char str[3][20];char string[20];//循环输入3个字符…

AIGC - SD(中英文本生成图片) + PaddleHub/HuggingFace + stable-diffusion-webui

功能 stable-diffusion(文本生成图片)webui-win搭建&#xff08;开启api界面汉化&#xff09;PaddleHubHuggingFace: SD2&#xff0c;中文-alibaba/EasyNLP stable-diffusion-webui 下载与安装 环境相关下载 python&#xff08;文档推荐&#xff1a;Install Python 3.10.6 …

Ubuntu中的 Everything 搜索软件 ==> fsearch

本文所使用的 Ubuntu 系统版本是 Ubuntu 22.04 ! 在 Windows 中&#xff0c;我经常使用 Everything 来进行文件搜索&#xff0c;搜索效率比 Windows 自带的高出千百倍。 那么在 Ubuntu 系统中&#xff0c;有没有类似的软件呢&#xff1f;那必须有&#xff0c;它就是 FSearch 。…

uniapp app权限说明弹框2024.4.23更新

华为上架被拒绝 用uni-app开发的app&#xff0c;上架华为被拒&#xff0c;问题如下&#xff1a; 您的应用在运行时&#xff0c;未见向用户告知权限申请的目的&#xff0c;向用户索取&#xff08;电话、相机、存储&#xff09;等权限&#xff0c;不符合华为应用市场审核标准。…

C#基础:WPF中常见控件的布局基础

一、用ViewBox实现放缩控件不变 二、布局代码 <Window x:Class"WpfApp1.MainWindow"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/xaml"Title"MainWindow"…

探索矿业数字化平台:实现智能化采矿与管理

随着信息技术的迅猛发展&#xff0c;矿业领域也在逐步实现数字化转型。数字化平台的出现为矿业企业带来了更高效、更智能的采矿与管理方式。本文将探讨矿业数字化平台的意义、特点以及未来发展方向。 ### 1. 数字化平台的意义 传统的矿业生产和管理方式存在诸多问题&#xff…

浏览器渲染机制:重排(Reflow)与重绘(Repaint)以及Vue优化策略

浏览器渲染机制是一个复杂但有序的过程&#xff0c;其目的是将HTML、CSS和JavaScript代码转化为用户可以看到和交互的视觉界面。重排&#xff08;Reflow&#xff09;与重绘&#xff08;Repaint&#xff09;是浏览器渲染过程中对页面元素进行更新的两个重要步骤&#xff0c;理解…