人工智能|深度学习——多模态条件机制 Cross Attention 原理及实现

一、引入

虽然之前写过 Attention 的文章,但现在回头看之前写的一些文章,感觉都好啰嗦,正好下一篇要写的 Stable Diffusion 中有 cross-attention,索性就再单拎出来简单说一下 Attention 吧,那么这篇文章的作用有两个:第一是为 Stable Diffusion 做补充,第二是为后续的 Vision Transformer 和 Swin Transformer 做铺垫。

为了保证篇幅开头的完整性,还得啰嗦一下 Transformer,它最开始提出是针对nlp领域的,在此之前除了seq2seq这种encoder-decoder架构,大家主要还是用的rnn、lstm这种时序网络,像rnn系列网络它是有问题的,首先就是它记忆的长度是有限的,其次是无法并行化计算,也就是必须要先计算xt时刻的数据才能计算时刻xt+1,这就导致效率低下。针对这些问题,Google就提出了 Transformer,在 Transformer 中有两个非常重要的模块:Self Attention 和 Multi-Head Attention,本文会先介绍 Attention 的基本思想,然后再对 Self Attention 和 Multi-Head Attention 进行概述,最后再讲本文的主题 Cross Attention,其实 Cross Attention 非常简单,不要被它的名字吓到,一定要理解透彻前面的 Multi-Head Attention。

二、Attention 思想

注意力机制的核心目标是从众多信息中选择出对当前任务目标更关键的信息,将注意力放在上面。其本质思想就是【从大量信息中】【有选择的筛选出】【少量重要信息】并将注意力【聚焦到这些重要信息上】,【忽略大多不重要的信息】。聚焦的过程体现在【权重系数】的计算上,权重越大越聚焦于其对应的value值上。即权重代表了信息的重要性,而value是其对应的信息。

  • Q是Query,是输入的信息,即当前任务的目标,用于和key进行匹配;

  • K和V分别是Key和Value,一般是相同的数据,比如原始文本经过Embedding后的表征;

  • 通过计算Q与K之间的相关性,得到权重a,再将权重a进行类似于softmax的归一化操作,表示不同的key对于Q的重要程度,或者说权重a越大,我们就会把更多的注意力放到其对应的value上;

  • 用权重a再与对应的Value相乘,意思是我们从Value中提取到的重要信息,或者说是对Q有用的信息;

  • 加权后的结果再求和就得到了针对Query的Attention输出,用新的输出代替原来的Q参与之后的一系列运算。

我们以机器翻译为例进一步加深理解,假设有文本“汤姆追逐杰瑞”,方便起见我们规定词库单词就为tom、chase、jerry,当我们对“汤姆”进行翻译的时候,套用上述 Attention 机制:

三、Self Attention

我们可以观察上面的传统 Attention 机制,我们可以发现每个词只表示自身的含义,不包含上下文的语义信息。而 Self Attention 则顾名思义,它指的是关注输入序列元素之间的关系,也就是说每个元素都有自己的Q、K、V,经过 Self Attention 对词向量进行重构后,使得词向量即包含自己的信息,又综合考虑了上下文的语义信息,如下图所示:

四、Multi-Head Attention

图片

在理解了 Self Attention 之后,Multi-Head Attention 就很容易了,它相当于 h 个不同的 Self Attention 的集成,说白了就是对其的堆叠。

Multi-Head Attention的优点:

  • 多头保证了我们可以注意到不同子空间的信息,捕捉到更加丰富的特征信息。

  • 能够捕捉到特征的多样性,说白了就是因为有多头,可以从多个角度去理解内容。

    • 换句话说,经过注意力之后的矩阵会有自己理解的语义信息,那么多个头就会有多个不同的理解。

  • 通过注意力可以充分的解读上下文的语义信息,能够充分的带入到一个场景中做理解。

五、Padding Mask

在做注意力的时候,我们还需要进行 padding mask 来消除 padding 部分的影响,因为有softmax的存在,padding项的注意力也会作为x^{i}进行缩放,首先对padding项添加注意力本身就不合理,其次它作为x^{i}就相当于也会产生权重,所以要消除 padding 带来的影响。

具体而言,我们会在输入序列中定位到 padding 的位置,然后标记为1,其余标记为0,然后构建一个与 attention 矩阵同维度的mask矩阵,其中填充位置对应元素为1,其它位置对应元素为0,关键代码如下:

# 构建padding mask矩阵
pad_mask = input_ids.eq(0) # 逻辑矩阵pad_mask:将填充位置标记为True,其他位置标记为False [batch_size, seq_len]
# 增加维度,和 QK^T 后的att权重维度等同 [batch_szie, seq_len, seq_len]
pad_mask = pad_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len) # (batch_size, num_heads, seq_len, seq_len)
att_weights = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # 点积操作# 因为是多头,所以mask矩阵维度要扩充到4维  [batch_size, seq_len, seq_len] -> [batch_size, nums_head, seq_len, seq_len]
pad_mask = pad_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
att_weights.masked_fill_(pad_mask, float('-inf')) # 将填充位置对应的元素设置为负无穷
att_weights = torch.softmax(att_weights, dim=-1) # 在最后一个维度上进行softmaxcontext = torch.matmul(att_weights, V) # (batch_size, num_heads, seq_len, emb_dim)

六、Cross Attention

理解上面了,cross-attention就更简单了,它用于处理两个不同模态序列之间的关联,在多模态场景中用于将文本和图像等不同类型的数据进行交互处理:

七、代码实现

7.1 self-attention

class SelfAttention(nn.Module):def __init__(self, emb_dim):super(SelfAttention, self).__init__()self.emb_dim = emb_dimself.Wq = nn.Linear(emb_dim, emb_dim, bias=False)self.Wk = nn.Linear(emb_dim, emb_dim, bias=False)self.Wv = nn.Linear(emb_dim, emb_dim, bias=False)self.fc = nn.Linear(emb_dim, emb_dim)def forward(self, x, pad_mask=None):# [batch_szie, seq_len, emb_dim] = [3, 5, 512]Q = self.Wq(x)K = self.Wk(x)V = self.Wv(x)att_weights = torch.bmm(Q, K.transpose(1, 2))   # [batch_szie, seq_len, seq_len] = [3, 5, 5]att_weights = att_weights / math.sqrt(self.emb_dim)if pad_mask is not None:att_weights = att_weights.masked_fill(pad_mask, -1e9)att_weights = F.softmax(att_weights, dim=-1)output = torch.bmm(att_weights, V)   # [batch_szie, seq_len, emb_dim] = [3, 5, 512]output = self.fc(output)return output, att_weights

7.2 Multi-Head Attention

class MultiHeadAttention(nn.Module):def __init__(self, emb_dim, num_heads, att_dropout=0.0):super(MultiHeadAttention, self).__init__()self.emb_dim = emb_dimself.num_heads = num_headsself.att_dropout = att_dropoutassert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads"self.depth = emb_dim // num_headsself.Wq = nn.Linear(emb_dim, emb_dim, bias=False)self.Wk = nn.Linear(emb_dim, emb_dim, bias=False)self.Wv = nn.Linear(emb_dim, emb_dim, bias=False)self.fc = nn.Linear(emb_dim, emb_dim)def forward(self, x, pad_mask=None):# [batch_szie, seq_len, emb_dim] = [3, 5, 512]batch_size = x.size(0)# [batch_szie, seq_len, emb_dim] = [3, 5, 512]Q = self.Wq(x)K = self.Wk(x)V = self.Wv(x)# 分头 [batch_szie, num_heads, seq_len, depth] = [3, 8, 5, 512/8=64]Q = Q.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)K = K.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)V = V.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)# [batch_szie, num_heads, seq_len, seq_len] = [3, 8, 5, 5]att_weights = torch.matmul(Q, K.transpose(-2, -1))att_weights = att_weights / math.sqrt(self.depth)if pad_mask is not None:# 因为是多头,所以mask矩阵维度要扩充到4维  [batch_size, seq_len, seq_len] -> [batch_size, nums_head, seq_len, seq_len]pad_mask = pad_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)att_weights = att_weights.masked_fill(pad_mask, -1e9)att_weights = F.softmax(att_weights, dim=-1)# 自己的多头注意力效果没有torch的好,我猜是因为它的dropout给了att权重,而不是fcif self.att_dropout > 0.0:att_weights = F.dropout(att_weights, p=self.att_dropout)# [batch_szie, num_heads, seq_len, depth] = [3, 8, 5, 64]output = torch.matmul(att_weights, V)# 不同头的结果拼接 [batch_szie, seq_len, emb_dim] = [3, 5, 512]output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.emb_dim)output = self.fc(output)return output, att_weights

7.3 Cross_MultiAttention

class Cross_MultiAttention(nn.Module):def __init__(self, in_channels, emb_dim, num_heads, att_dropout=0.0, aropout=0.0):super(Cross_MultiAttention, self).__init__()self.emb_dim = emb_dimself.num_heads = num_headsself.scale = emb_dim ** -0.5assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads"self.depth = emb_dim // num_headsself.proj_in = nn.Conv2d(in_channels, emb_dim, kernel_size=1, stride=1, padding=0)self.Wq = nn.Linear(emb_dim, emb_dim)self.Wk = nn.Linear(emb_dim, emb_dim)self.Wv = nn.Linear(emb_dim, emb_dim)self.proj_out = nn.Conv2d(emb_dim, in_channels, kernel_size=1, stride=1, padding=0)def forward(self, x, context, pad_mask=None):''':param x: [batch_size, c, h, w]:param context: [batch_szie, seq_len, emb_dim]:param pad_mask: [batch_size, seq_len, seq_len]:return:'''b, c, h, w = x.shapex = self.proj_in(x)   # [batch_size, c, h, w] = [3, 512, 512, 512]x = rearrange(x, 'b c h w -> b (h w) c')   # [batch_size, h*w, c] = [3, 262144, 512]Q = self.Wq(x)  # [batch_size, h*w, emb_dim] = [3, 262144, 512]K = self.Wk(context)  # [batch_szie, seq_len, emb_dim] = [3, 5, 512]V = self.Wv(context)Q = Q.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)  # [batch_size, num_heads, h*w, depth]K = K.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)  # [batch_size, num_heads, seq_len, depth]V = V.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)# [batch_size, num_heads, h*w, seq_len]att_weights = torch.einsum('bnid,bnjd -> bnij', Q, K)att_weights = att_weights * self.scaleif pad_mask is not None:# 因为是多头,所以mask矩阵维度要扩充到4维  [batch_size, h*w, seq_len] -> [batch_size, nums_head, h*w, seq_len]pad_mask = pad_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)att_weights = att_weights.masked_fill(pad_mask, -1e9)att_weights = F.softmax(att_weights, dim=-1)out = torch.einsum('bnij, bnjd -> bnid', att_weights, V)out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.emb_dim)   # [batch_size, h*w, emb_dim]print(out.shape)out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w)   # [batch_size, c, h, w]out = self.proj_out(out)   # [batch_size, c, h, w]return out, att_weights

7.4 Cross Attention

class CrossAttention(nn.Module):def __init__(self, in_channels, emb_dim, att_dropout=0.0, aropout=0.0):super(CrossAttention, self).__init__()self.emb_dim = emb_dimself.scale = emb_dim ** -0.5self.proj_in = nn.Conv2d(in_channels, emb_dim, kernel_size=1, stride=1, padding=0)self.Wq = nn.Linear(emb_dim, emb_dim)self.Wk = nn.Linear(emb_dim, emb_dim)self.Wv = nn.Linear(emb_dim, emb_dim)self.proj_out = nn.Conv2d(emb_dim, in_channels, kernel_size=1, stride=1, padding=0)def forward(self, x, context, pad_mask=None):''':param x: [batch_size, c, h, w]:param context: [batch_szie, seq_len, emb_dim]:param pad_mask: [batch_size, seq_len, seq_len]:return:'''b, c, h, w = x.shapex = self.proj_in(x)   # [batch_size, c, h, w] = [3, 512, 512, 512]x = rearrange(x, 'b c h w -> b (h w) c')   # [batch_size, h*w, c] = [3, 262144, 512]Q = self.Wq(x)  # [batch_size, h*w, emb_dim] = [3, 262144, 512]K = self.Wk(context)  # [batch_szie, seq_len, emb_dim] = [3, 5, 512]V = self.Wv(context)# [batch_size, h*w, seq_len]att_weights = torch.einsum('bid,bjd -> bij', Q, K)att_weights = att_weights * self.scaleif pad_mask is not None:# [batch_size, h*w, seq_len]att_weights = att_weights.masked_fill(pad_mask, -1e9)att_weights = F.softmax(att_weights, dim=-1)out = torch.einsum('bij, bjd -> bid', att_weights, V)   # [batch_size, h*w, emb_dim]out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w)   # [batch_size, c, h, w]out = self.proj_out(out)   # [batch_size, c, h, w]print(out.shape)return out, att_weights

7.5 main

# coding:utf-8
# @Email: wangguisen@donews.com
# @Time: 2023/3/22 22:58
# @File: att_test.py
'''
Self Attention
Multi-Head Attention
Cross Attention
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange, repeat
from torch.nn import MultiheadAttentionif __name__ == '__main__':'''''''''假设词表映射后输入 batch_size = 3seq_len = max_len = 5pad = 0emb_dim = 512'''batch_size = 3seq_len = 5emb_dim = 512# 本例子则词表大小为 301vocab_size = 301input_ids = torch.tensor([[100, 200, 300, 300, 0],[22, 33, 44, 0, 0],[66, 55, 66, 30, 0]], dtype=torch.long)pad_mask = input_ids.eq(0)  # 逻辑矩阵pad_mask:将填充位置标记为True,其他位置标记为False# pad_mask = pad_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len)  # [batch_size, seq_len, seq_len] = [3, 5, 5]inputs = nn.Embedding(vocab_size, embedding_dim=emb_dim)(input_ids)   # [batch_szie, seq_len, emb_dim] = [3, 5, 512]# self_att = SelfAttention(emb_dim=emb_dim)# self_att(inputs, pad_mask=pad_mask)# multi_att = MultiHeadAttention(emb_dim=emb_dim, num_heads=8)# multi_att(inputs, pad_mask=pad_mask)# 定义图片数据  [batch_size, c, h, w]input_img = torch.randn((3, 3, 512, 512))pad_mask = pad_mask.unsqueeze(1).expand(batch_size, 512*512, seq_len)# cross_att = Cross_MultiAttention(in_channels=3, emb_dim=emb_dim, num_heads=8, att_dropout=0.0, aropout=0.0)# cross_att(x=input_img, context=inputs, pad_mask=pad_mask)cross_att = CrossAttention(in_channels=3, emb_dim=emb_dim, att_dropout=0.0, aropout=0.0)cross_att(x=input_img, context=inputs, pad_mask=pad_mask)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/368441.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

46. 全排列

计算数组的全排列 给定一个不含重复数字的数组 nums,返回其所有可能的全排列。你可以按任意顺序返回答案。 示例 1: 输入:nums [1,2,3] 输出:[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]] 示例 2: 输入&a…

护网蓝队面试

一、sql注入分类 **原理:**没有对用户输入项进行验证和处理直接拼接到查询语句中 查询语句中插⼊恶意SQL代码传递后台sql服务器分析执行 **从注入参数类型分:**数字型注入、字符型注入 **从注入效果分:**报错注入、布尔注入、延时注入、联…

day04-组织架构

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 1.组织架构-树组件应用树形组件-用层级结构展示信息,可展开或折叠。 2.组织架构-树组件自定义结构3.组织架构-获取组织架构数据4.组织架构-递归转化树形…

02-部署LVS-DR群集

1.LVS-DR工作原理 LVS-DR模式,Director Server作为群集的访问入口,不作为网购使用,节点Director Server 与 Real Server 需要在同一个网络中,返回给客户端的数据不需要经过Director Server 为了响应对整个群集的访问,…

Docker镜像加速配置

由于当前运营商网络问题,可能会导致您拉取 Docker Hub 镜像变慢,索引可以配置阿里云镜像加速器。阿里云登录 - 欢迎登录阿里云,安全稳定的云计算服务平台 每个人镜像地址都不一样,需要登陆阿里云自行查看,地址在上面&a…

llama-factory训练RLHF-PPO模型

理论上RLHF(强化学习)效果比sft好,也更难训练。ppo有采用阶段,步骤比较多,训练速度很慢. 记录下工作中使用llama-factory调试rlhf-ppo算法流程及参数配置,希望对大家有所帮助. llama-factory版本: 0.8.2 一 rlhf流程 ppo训练流程图如下, 会…

油猴Safari浏览器插件:Tampermonkey for Mac 下载

Tampermonkey 是一个强大的浏览器扩展,用于运行用户脚本,这些脚本可以自定义和增强网页的功能。它允许用户在网页上执行各种自动化任务,比如自动填写表单、移除广告、改变页面布局等。适用浏览器: Tampermonkey 适用于多数主流浏览…

Golang | Leetcode Golang题解之第201题数字范围按位与

题目&#xff1a; 题解&#xff1a; func rangeBitwiseAnd(m int, n int) int {for m < n {n & (n - 1)}return n }

二叉树与堆相关的时间复杂度问题

目录 满二叉树与完全二叉树高度h和树中节点个数N的关系 向上调整算法&#xff1a; 介绍&#xff1a; 复杂度推导&#xff1a; 向下调整算法&#xff1a; 介绍&#xff1a; 复杂度推导&#xff1a; 向上调整建堆&#xff1a; 介绍&#xff1a; 复杂度推导&#xff1a;…

Web Based Quiz System v1.0 SQL 注入漏洞(CVE-2022-32991)

前言 CVE-2022-32991 是一个影响 Web Based Quiz System v1.0 的 SQL 注入漏洞。这个漏洞存在于 welcome.php 文件中的 eid 参数处。攻击者可以通过此漏洞在数据库中执行任意 SQL 语句&#xff0c;从而获取、修改或删除数据库中的数据。 具体细节如下&#xff1a; 攻击向量&…

无人机森林火灾解决方案

森林火灾解决方案 森林火灾特点 森林火灾发生突然、蔓延迅速、难以控制&#xff0c;应对难度系 数很高&#xff0c;扑救工作十分困难 救援面临的挑战 • 林区交通不便&#xff0c; 山高坡陡&#xff0c; 沟壑纵横&#xff0c;难以及时侦查、 定位、扑灭 • 火灾发生的区域…

基于opencv的斜光测距及python实现

1.前言 最近做了一个基于opencv的斜光测距的小项目&#xff0c;东西不多&#xff0c;但是很有意思&#xff0c;值得拿出来学一学。项目里面需要比较精确的定位功能&#xff0c;将前人matlab代码移植到python上&#xff0c;并且做了一些优化&#xff0c;简化逻辑(毕竟我是专业的…

马工程刑法期末复习笔记重点2

马工程刑法期末复习笔记重点2

【JavaWeb程序设计】环境配置和Web工程的创建

目录 一、安装JDK、Tomcat&#xff0c;进行测试Tomcat能否正常启动。 二、修改Tomcat端口为8976&#xff0c;重新进行测试 三、使用集成开发环境Intelligent Idea&#xff0c;绑定JDK和Tomcat&#xff0c;建立站点&#xff0c;并测试 四、编写一个简单的html页面&#xff0…

微信小程序遮罩层显示

效果展示&#xff1a; wxml页面&#xff1a; <view classmodal-mask wx:if{{showModal}}><view class"modal-container"><view classmodal-content></view><view classmodal-footer bindtap"closeImage">//这个/images/ind…

SpringMVC的架构有什么优势?——控制器(一)

文章目录 控制器(Controller)1. 控制器(Controller)&#xff1a;2. 请求映射(Request Mapping)&#xff1a;3. 参数绑定(Request Parameters Binding)&#xff1a;4. 视图解析器(View Resolver)&#xff1a;5. 数据绑定(Data Binding)&#xff1a;6. 表单验证(Form Validation)…

Workerman在线客服系统源码,附搭建教程

源码介绍&#xff1a; Workerman在线客服系统源码。 workerman是一个高性能的PHP socket 服务器框架&#xff0c;workerman基于PHP多进程以及libevent事件轮询库&#xff0c;PHP开发者只要实现一两个接口&#xff0c;便可以开发出自己的网络应用&#xff0c;例如Rpc服务、聊天…

leetCode.98. 验证二叉搜索树

leetCode.98. 验证二叉搜索树 题目描述 代码 /*** Definition for a binary tree node.* struct TreeNode {* int val;* TreeNode *left;* TreeNode *right;* TreeNode() : val(0), left(nullptr), right(nullptr) {}* TreeNode(int x) : val(x), left(n…

秋招Java后端开发冲刺——并发篇2(JMM与锁机制)

本文对Java的内存管理模型、volatile关键字和锁机制进行详细阐述&#xff0c;包括synchronized关键字、Lock接口及其实现类ReentrantLock、AQS等的实现原理和常见方法。 一、JMM&#xff08;Java内存模型&#xff09; 1. 介绍 JMM定义了共享内存中多线程程序读写操作的行为规…

51单片机第18步_将TIM0用作13位定时器

本章重点学习将TIM0用作13位定时器。 1、定时器0工作在模式0框图 2、定时器0工作在模式0举例 1、Keil C51中有一些关键字&#xff0c;需要牢记&#xff1a; interrupt 0&#xff1a;指定当前函数为外部中断0&#xff1b; interrupt 1&#xff1a;指定当前函数为定时器0中断…