注意力机制(四)(多头注意力机制)

​🌈 个人主页十二月的猫-CSDN博客
🔥 系列专栏 🏀《深度学习基础知识》

      相关专栏: 《机器学习基础知识》

                         🏐《机器学习项目实战》
                         🥎《深度学习项目实战(pytorch)》

💪🏻 十二月的寒冬阻挡不了春天的脚步,十二点的黑夜遮蔽不住黎明的曙光 

目录

回顾

注意力机制与RNN、LSTM的对比 

总论

RNN

LSTM

注意力机制

多头注意力机制 

核心思想介绍

如何运用多头注意力机制

1、定义多组参数矩阵W(一般是八组),生成多组Q、K、V

 2、利用多组参数分别训练得到多组结果Z(上下文向量)

3、将多组输出拼接后乘以矩阵W0降低维度

 多头流程图

试图解释

Pytorch代码实现

代码解释 

两种实现思想对比

总结


回顾

在上一篇注意力机制(三)(不同注意力机制对比)-CSDN博客,重点讲了针对QKV来源不同制造的注意力机制的一些变体,包括交叉注意力、自注意力等。这里再对注意力机制理解中的核心要点进行归纳整理

1、注意力机制规定的是对QKV的处理,并不指定QKV的来源

2、注意力机制和RNN、LSTM本身是同级的,都可以用于独立解决时间序列的问题

3、注意力机制相比于RNN、LSTM来说能够学习句子内部的句法特征和语义特征

4、注意力机制能够进行并行运算,解决了RNN、LSTM串行运算的问题。但是也存在注意力机制运算量非常大的问题

5、注意力机制解决了RNN以及LSTM中有由于梯度消失梯度爆炸造成的长期依赖问题

注意力机制与RNN、LSTM的对比 

总论

RNN(递归神经网络):它能够处理序列数据,因为它具有循环的结构,可以保留之前的信息。这种结构模拟了人类思考时的持续性,即我们对当前事物的理解是建立在之前信息的基础上的。然而,传统的RNN在处理长距离依赖时会遇到困难,因为随着时间的推移,梯度可能会消失或爆炸,导致网络难以学习和记住长期的信息。

LSTM(长短时记忆网络):它是RNN的一种改进型,专门设计来解决长期依赖问题。LSTM通过引入三个门(遗忘门、输入门、输出门)的结构来控制信息的流动,从而有效地保存长期的信息。这种结构使得LSTM能够更好地查询较长一段时间内的信息,因为它可以通过“门”结构来决定哪些信息需要被记住或遗忘。

注意力机制:它是一种允许模型在处理序列时动态地关注不同部分信息的方法。注意力机制可以与RNN和LSTM结合使用,也可以独立使用。它的优点是可以提高模型对序列中重要信息的敏感度,从而提高模型的性能。注意力机制通过计算注意力权重来分配对不同时间步的关注,这使得模型在每个时间步都能够考虑到整个序列的信息。

RNN

        由于RNN在反向传播过程中涉及到矩阵的连乘,这可能导致梯度指数级减小(梯度消失)或增大(梯度爆炸)。梯度消失会使得网络难以学习和传递长期的依赖关系,因为梯度变得太小,以至于反向传播时权重更新几乎停滞(遗忘,不再考虑这个信息)。相反,梯度爆炸会导致梯度过大,使得网络训练不稳定甚至发散

LSTM

        LSTM就是应对RNN长期依赖问题而产生的。LSTM利用遗忘门、输出输入门以及特殊结构——记忆单元使得长期的信息能够被模型记忆并传递下来从而一定程度上解决了RNN因梯度消失、爆炸出现的长期依赖问题

        但是LSTM在序列很长时仍然会出现梯度消失梯度爆炸的问题,并且LSTM并不能实现并行运算。同时,LSTM在应对并不能很好的捕捉一个序列后面的信息(因为因果卷积的问题),所有导致LSTM并不能很好的理解句子的语法和句法特征

注意力机制

        注意力机制相比于前两个模型的特点在于:1、能够进行并行运算;2、能够完美解决长期依赖问题;3、由于对句子有全面的相似度计算,能够更好理解句子的句法特征和语义特征

上图体现了注意力机制对句子句法特征的理解

 上图体现了注意力机制对句子语法特征的理解

多头注意力机制 

核心思想介绍

前文,我们介绍了自注意力机制:自注意力的QKV是同源的。同源的好处就是更容易发现序列内部的信息,但是也存在一些可以改进的地方。

例如:对于一个待分析的序列矩阵,它存在许多方面的特征。此时我们要用一个参数矩阵Wq、Wk去分析并学习出序列中的这么多特征。由于参数矩阵的维度是有限的,所以一次性学习多特征的信息必然会造成信息学习的模糊性,所以作者又提出了多头注意力机制

下图为多头注意力机制模型图:

多头注意力机制在以下两个方面提升注意力机制的性能:

  • 它为注意力机制提供了多个投射子空间的可能。它利用多头机制提供了多组的参数矩阵,每组参数矩阵能够通过线性变化将词向量X放入不同的向量空间,从而反映出词向量X的不同特征。多组参数矩阵映射不同向量空间,再将不同向量空间的结果进行整合,如此比单组参数矩阵表示向量特征更加准确
  • 它拓展了模型关注不同位置的能力。多头机制不仅在词向量维度上能够挖掘更多词向量的特征,在词数上也能够同时关注更多的词信息

通过多头注意力机制,我们会为每一个都单独配置QKV的权重矩阵,从而在模型训练中产生不同的QKV矩阵。对于每个头来说,其核心思想和训练方式和自注意力机制是相似的

如何运用多头注意力机制

1、定义多组参数矩阵W(一般是八组),生成多组Q、K、V

多头注意力机制的每一头的处理方式和自注意力机制是相同的,也就是利用输入向量X分别乘上从而得到对应的q、k、v。然后每个头的参数矩阵会在不同初始值的情况下,各自训练自己的参数,最后分别生成不同的Q、K、V(初始值不同最后学习的结果也不同,可以参考梯度下降中的局部最优理解)

下图举了两个头的注意力机制的示意图:

 2、利用多组参数分别训练得到多组结果Z(上下文向量)

 将多组训练得到的参数与V经过mulmat融合得到多组Z。此时的上下文向量Z不仅包含原始的信息也包括对文本上下文注意力的信息,并且这个注意力信息利用多头考虑了多个维度的特征信息

3、将多组输出拼接后乘以矩阵W0降低维度

通过第二步得到多组的Z,为了全面的利用所有Z的信息,我们将Zconcat(拼接在一起),这将得到一个非常长的向量矩阵。由于输出结果Z和输入结果X的向量矩阵维度应该要相同,所以我们利用矩阵W0对结果进行变化降维,得到最终结果

 多头流程图

借用其他大佬翻译的流程图版本

试图解释

深度学习模型都是黑盒子模型,所以没有一个很严谨的解释。这里,我也只能给一个非常模糊且不透彻的解释,希望能帮助大家的理解

假设我们有一句话“The animal didn’t cross the street because it was too tired”

  • 图中绿线和橙线表示两个不同的头
  • 可以看到绿线重点关注的是tired单词,橙线重点关注animal单词。这表明it在高维度上某些特征和animal相似,另外一些特征和tired相似
  • 经过注意力机制调整,it中将包含tire和animal两个单词的信息。模型在分析时,对于it单词也将重点关心tired和animal两个单词

说明上面这句翻译的英语句子中it和animal以及tired的关系度相对较大。我们自己分析这个句子时结果也是这样,因为it就指代animal,同时全句子的重点也就是在animal很tired上

一旦注意力头更多之后,整个模型的解释会变得更难,因此我们不再展开

Pytorch代码实现

import torch
import torch.nn.functional as F
import torch.nn as nn
import mathdef self_attention(query, key, value, dropout=None, mask=None):"""前置参数(自注意力):输入矩阵X形状为(batch_size, seq_len, d_model)Q = torch.matmul(X, W_Q)K = torch.matmul(X, W_K)V = torch.matmul(X, W_V)自注意力计算::param query: Q:param key: K:param value: V:param dropout: drop比率:param mask: 是否mask:return: 经自注意力机制计算后的值"""# d_k指降维后待查询的词向量维度d_k = query.size(-1)  # 防止softmax未来求梯度消失时的d_k# Q,K相似度计算公式:\frac{Q^TK}{\sqrt{d_k}},score的维度就是词数*词数(每两个词语间的相似度)scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)  # Q,K相似度计算# 判断是否要mask,注:mask的操作在QK之后,softmax之前if mask is not None:"""scores.masked_fill默认是按照传入的mask条件中为1的元素所在的索引,在scores中相同的的索引处替换为value,替换值为-1e9,即-(10^9)"""mask.cuda()  # 将mask放入GPU运算#score此时就是一个torch.tensor对象,可以直接用masked_fill函数scores = scores.masked_fill(mask == 0, -1e9)self_attn_softmax = F.softmax(scores, dim=-1)  # 进行softmax# 判断是否要对相似概率分布进行dropout操作if dropout is not None:self_attn_softmax = dropout(self_attn_softmax)# 注意:返回经自注意力计算后的值,以及进行softmax后的相似度(即相似概率分布)# 词数*词数 * 词数*词向量维度 = 全新的 词数*词向量维度。如果是多头的:头数*词数*词向量维度return torch.matmul(self_attn_softmax, value), self_attn_softmaxclass MultiHeadAttention(nn.Module):  # 继承nn.module类"""多头注意力计算"""def __init__(self, head, d_model, dropout=0.1):""":param head: 头数:param d_model: 词向量的维度,必须是head的整数倍:param dropout: drop比率"""super(MultiHeadAttention, self).__init__() # 先初始化父类的属性(子类要用父类的属性和方法)assert (d_model % head == 0)  # 确保词向量维度是头数的整数倍self.d_k = d_model // head  # 被拆分为多头后的某一头词向量的维度,和自注意力降维后维度是相同的self.head = headself.d_model = d_model"""由于多头注意力机制是针对多组Q、K、V,因此有了下面这四行代码,具体作用是,针对未来每一次输入的Q、K、V,都给予参数进行构建其中linear_out是针对多头汇总时给予的参数"""self.linear_query = nn.Linear(d_model, d_model)  # 进行一个普通的全连接层变化,但不修改维度self.linear_key = nn.Linear(d_model, d_model)self.linear_value = nn.Linear(d_model, d_model)self.linear_out = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(p=dropout)self.attn_softmax = None  # attn_softmax是能量分数, 即句子中某一个词与所有词的相关性分数, softmax(QK^T)def forward(self, query, key, value, mask=None):if mask is not None:"""多头注意力机制的线性变换层是4维,是把query[batch, frame_num, d_model]变成[batch, -1, head, d_k]再1,2维交换变成[batch, head, -1, d_k], 所以mask要在第二维(head维)添加一维,与后面的self_attention计算维度一样具体点将,就是:因为mask的作用是未来传入self_attention这个函数的时候,作为masked_fill需要mask哪些信息的依据针对多head的数据,Q、K、V的形状维度中,只有head是通过view计算出来的,是多余的,为了保证mask和view变换之后的Q、K、V的形状一直,mask就得在head这个维度添加一个维度出来,进而做到对正确信息的mask"""mask = mask.unsqueeze(1)n_batch = query.size(0)  # batch_size大小,假设query的维度是:[10, 32, 512],其中10是batch_size的大小"""下列三行代码都在做类似的事情,对Q、K、V三个矩阵做处理其中view函数是对Linear层的输出做一个形状的重构,其中-1是自适应(自主计算)这里本质用的是对词向量的维度进行了拆分,将不同维度放入不同自注意力模型去训练transopose(1,2)是对前形状的两个维度(索引从0开始)做一个交换,这里处理后的quary(batch,head,词数,词维度)假设Linear成的输出维度是:[10, 32, 512],其中10是batch_size的大小注:这里解释了为什么d_model // head == d_k,如若不是,则view函数做形状重构的时候会出现异常"""query = self.linear_query(query).view(n_batch, -1, self.head, self.d_k).transpose(1, 2)  # [b, 8, 32, 64],head=8key = self.linear_key(key).view(n_batch, -1, self.head, self.d_k).transpose(1, 2)   # [b, 8, 32, 64],head=8value = self.linear_value(value).view(n_batch, -1, self.head, self.d_k).transpose(1, 2)  # [b, 8, 32, 64],head=8# x是通过自注意力机制计算出来的值, self.attn_softmax是相似概率分布x, self.attn_softmax = self_attention(query, key, value, dropout=self.dropout, mask=mask)"""首先,交换“head数”和“词数”,这两个维度,结果为(batch, 词数, head数, d_model/head数)对应代码为:`x.transpose(1, 2).contiguous()`然后将“head数”和“d_model/head数”这两个维度合并,结果为(batch, 词数,d_model)contiguous()是重新开辟一块内存后存储x,然后才可以使用.view方法,否则直接使用.view方法会报错"""x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.head * self.d_k)return self.linear_out(x)

代码解释 

代码带有注释细节方面就不在这里解释了,这里重点来看看整体的思想

代码实现的思想和上面说的多头注意力机制存在一点不同。两者核心的思想是相同的,但是具体实现的方式存在区别

两种实现思想对比

第一种思路:上文说QKV是同时将X映射到不同的维度好几份,将X中的词向量维度通过映射在QKV中得到降维目的,降低计算难度。并且让不同份的QKV参数矩阵能够关注X不同的特征,从而其让X的特征值的寻找能够更加细致彻底(如下图)

这里介绍另外一种思路(代码实现采用的,两者核心思路是一样的) :

第二种思路:这里我们不再将X利用参数矩阵直接映射成的QKV,从而实现降维以及多头的效果。而是将词向量维度分为头数*新词向量维度(即 :词向量维度=头数*新词向量维度),此时也就实现了降维、多头两个效果

如此分离之后,有一种更好的理解方式:将词向量特征分为不同的组,将不同的组分给不同的注意力机制模型学习,从而让模型专注学习每个组对应的词向量特征,从而使得模型学习效果更好

(下图将词向量在特征维度上分为四头)

在代码具体实现时,考虑到两者最终的效果差不多,但是上面的这个算法实现起来效率会差很多(参数计算量更大了),所以我们采用的策略会是第二种思路

总结

撰写文章不易,如果文章能帮助到大家,大家可以点点赞、收收藏呀~

十二月的猫在这里祝大家学业有成、事业顺利、情到财来

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/316096.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

virtualbox 网络设置实现主机和虚拟机互相访问

前言 一般来说,virtualbox 虚拟机的上网模式是 NAT。这样虚拟机可以上网并访问宿主机,但宿主机无法访问虚拟机,也无法 ping 通。下面介绍双网卡模式,实现虚拟机和宿主机能够互相访问 ping 通。 双网卡模式 进入虚拟机的网络设置…

串联超前及对应matlab实现

串联超前校正它的本质是利用相角超前的特性提高系统的相角裕度。传递函数为:下面将以一个实际的例子,使用matlab脚本,实现其校正后的相位裕度≥60。

微服架构基础设施环境平台搭建 -(六)Kubesphere 部署Redis服务 设置访问Redis密码

微服架构基础设施环境平台搭建 -(六)Kubesphere 部署Redis服务 & 设置访问Redis密码 微服架构基础设施环境平台搭建 系列文章 微服架构基础设施环境平台搭建 -(一)基础环境准备 微服架构基础设施环境平台搭建 -(二…

使用CSS3 + Vue3 + js-tool-big-box工具,实现炫酷五一倒计时动效

时间过得真是飞速,很快又要到一年一度的五一劳动节啦,今年五天假,做好准备了吗?今天我们用CSS3 Vue3 一个前端工具库 js-tool-big-box来实现一个炫酷的五一倒计时动效吧。 目录 1 先制作一个CSS3样式 2 Vue3功能提前准备 3…

MySQL--mysql的安装(压缩包安装保姆级教程)

官网下载:www.mysql.com MySQL :: Download MySQL Community Server (Archived Versions) 1.MySQL下载流程: 第一步:点击download, 下滑找到MySQL community(gpl)Downloads>> 第二步:点…

《小倩》撤档五一,光线动画神话宇宙出师未捷

临近五一假期,光线动画的《小倩》突然宣布撤档,退出了五一档的争夺。 作为光线动画成立以来倾力打造的中国神话宇宙的第一部电影,《小倩》的市场关注度不言而喻,尤其是在今年五一档撞上两部大热日本动画电影,正面交锋…

VPN的基本概念

随着互联网的普及和应用的广泛,网络安全和隐私保护越来越受到人们的关注。在这个信息爆炸的时代,我们的个人信息、数据通信可能会受到各种威胁,如何保护自己的隐私和数据安全成为了一个迫切的问题。而VPN(Virtual Private Network…

pytho爬取南京房源成交价信息并导入到excel

# encoding: utf-8 # File_name: import requests from bs4 import BeautifulSoup import xlrd #导入xlrd库 import pandas as pd import openpyxl# 定义函数来获取南京最新的二手房房子成交价 def get_nanjing_latest_second_hand_prices():cookies {select_city: 320100,li…

人脸识别系统架构

目录 1. 系统架构 1.1 采集子系统 1.2 解析子系统 1.3 存储子系统 1.4 比对子系统 1.5 决策子系统 1.6 管理子系统 1.7 应用开放接口 2. 业务流程 2.1 人脸注册 2.2 人脸验证 2.2.1 作用 2.2.2 特点 2.2.3 应用场景 2.3 人脸辨识 2.3.1 作用 2.3.2 特点 2.3.3…

大珩PPT助手一键颜色设置

大珩PPT助手最新推出的一键设置文字颜色和背景色功能,为用户在创建演示文稿时带来了更便捷、高效的体验。这一功能使用户能够轻松调整演示文稿中文字的颜色和幻灯片的背景色,以满足不同场合和主题的需要。 以下是该功能的几个关键特点和优势&#xff1a…

RAG-Driver: 多模态大语言模型中具有检索增强上下文学习的通用驱动解释

RAG-Driver: 多模态大语言模型中具有检索增强上下文学习的通用驱动解释 摘要Introduction RAG-Driver: Generalisable Driving Explanations with Retrieval-Augmented In-Context Learning in Multi-Modal Large Language Model. 摘要 由“黑箱”模型驱动的机器人需要提供人类…

JAVA实现easyExcel模版导出

easyExcel文档 模板注意&#xff1a; 用 {} 来表示你要用的变量 &#xff0c;如果本来就有"{“,”}" &#xff0c;特殊字符用"{“,”}"代替{} 代表普通变量{.}代表是list的变量 添加pom依赖 <dependency><groupId>com.alibaba</groupId&g…

Docker有哪些常见命令?什么是Docker数据卷?

喜欢就点击上方关注我们吧&#xff01; 哈喽&#xff0c;大家好呀&#xff01;这里是码农后端。上一篇我们介绍了Docker的安装以及腾讯云镜像加速源的配置。本篇将带你学习Docker的常见命令、数据卷及自定义镜像等相关知识。 1、什么是镜像与容器&#xff1f; 利用Docker安装应…

电容的理论基础

目录 1.电容的本质&#xff1a; 2.电容量的大小 2.1电容的单位 2.2电容的决定式 ​编辑3.电容的特点 5.电容器的类型 6.电容实际的电路模型 7.安装方法 ​编辑8.电容值 9.电容的耐压、封装 10.阻抗-频率特性 11.频率特性 12.等效串联电组ESR 13.电容器的温度特性…

MATLAB 数据类型

MATLAB 数据类型 MATLAB 不需要任何类型声明或维度语句。每当 MATLAB 遇到一个新的变量名&#xff0c;它就创建变量并分配适当的内存空间。 如果变量已经存在&#xff0c;那么MATLAB将用新内容替换原始内容&#xff0c;并在必要时分配新的存储空间。 例如&#xff0c; Tota…

命令执行。

命令执行 在该项目的readme中&#xff0c;描述了怎么去调用的flink 通过java原生的runtime来调用flink&#xff0c;下一步就是去看看具体的调用过程了&#xff0c;是否存在可控的参数 找到具体提交命令的类方法CommandRpcClinetAdapterImpl#submitJob() 这里要确定command&am…

TiDB 6.x 新特性解读 | Collation 规则

对数据库而言&#xff0c;合适的字符集和 collation 规则能够大大提升使用者运维和分析的效率。TiDB 从 v4.0 开始支持新 collation 规则&#xff0c;并于 TiDB 6.0 版本进行了更新。本文将深入解读 Collation 规则在 TiDB 6.0 中的变更和应用。 引 这里的“引”&#xff0c;…

用Redis实现获取验证码,外加安全策略

安全策略 一小时内只能获取三次&#xff0c;一天内只能获取五次 Redis存储结构 代码展示 import cn.hutool.core.util.RandomUtil; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.junit.jupiter.api.Test; import org.spri…

SD8942 600KHz、16V、2A同步降压转换器芯片IC

一般说明 该SD8942是一个完全集成&#xff0c;高效率2A同步整流降压转换器。SD8942在宽输出电 流负载范围内以高效率运行。该器件提供两种工作模式&#xff0c;PWM控制和PFM模式开关控制&#xff0c;它允许在更宽的负载范围内的高效率。 该SD8942需要一个现成的标…

Flink面试(1)

1.Flink 的并行度的怎么设置的&#xff1f; Flink设置并行度的几种方式 1.代码中设置setParallelism() 全局设置&#xff1a; 1 env.setParallelism(3);  算子设置&#xff08;部分设置&#xff09;&#xff1a; 1 sum(1).setParallelism(3) 2.客户端CLI设置&#xff0…