【深度学习】注意力机制(二)

本文介绍一些注意力机制的实现,包括EA/MHSA/SK/DA/EPSA。

【深度学习】注意力机制(一)

【深度学习】注意力机制(三)

目录

一、EA(External Attention)

二、Multi Head Self Attention

三、SK(Selective Kernel Networks)

四、DA(Dual Attention)

五、EPSA(Efficient Pyramid Squeeze Attention)


一、EA(External Attention)

EA可以关注全局的空间信息,论文:论文地址

如下图:

代码如下(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import initclass External_attention(nn.Module):'''Arguments:c (int): The input and output channel number.'''def __init__(self, c):super(External_attention, self).__init__()self.conv1 = nn.Conv2d(c, c, 1)self.k = 64self.linear_0 = nn.Conv1d(c, self.k, 1, bias=False)self.linear_1 = nn.Conv1d(self.k, c, 1, bias=False)self.linear_1.weight.data = self.linear_0.weight.data.permute(1, 0, 2)        self.conv2 = nn.Sequential(nn.Conv2d(c, c, 1, bias=False),norm_layer(c))        for m in self.modules():if isinstance(m, nn.Conv2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.data.normal_(0, math.sqrt(2. / n))elif isinstance(m, nn.Conv1d):n = m.kernel_size[0] * m.out_channelsm.weight.data.normal_(0, math.sqrt(2. / n))elif isinstance(m, _BatchNorm):m.weight.data.fill_(1)if m.bias is not None:m.bias.data.zero_()def forward(self, x):idn = xx = self.conv1(x)b, c, h, w = x.size()n = h*wx = x.view(b, c, h*w)   # b * c * n attn = self.linear_0(x) # b, k, nattn = F.softmax(attn, dim=-1) # b, k, nattn = attn / (1e-9 + attn.sum(dim=1, keepdim=True)) #  # b, k, nx = self.linear_1(attn) # b, c, nx = x.view(b, c, h, w)x = self.conv2(x)x = x + idnx = F.relu(x)return x

二、Multi Head Self Attention

注意力机制的经典,Transformer的基石。论文:论文地址

如下图:

代码如下(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import initclass ScaledDotProductAttention(nn.Module):'''Scaled dot-product attention'''def __init__(self, d_model, d_k, d_v, h,dropout=.1):''':param d_model: Output dimensionality of the model:param d_k: Dimensionality of queries and keys:param d_v: Dimensionality of values:param h: Number of heads'''super(ScaledDotProductAttention, self).__init__()self.fc_q = nn.Linear(d_model, h * d_k)self.fc_k = nn.Linear(d_model, h * d_k)self.fc_v = nn.Linear(d_model, h * d_v)self.fc_o = nn.Linear(h * d_v, d_model)self.dropout=nn.Dropout(dropout)self.d_model = d_modelself.d_k = d_kself.d_v = d_vself.h = hself.init_weights()def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):'''Computes:param queries: Queries (b_s, nq, d_model):param keys: Keys (b_s, nk, d_model):param values: Values (b_s, nk, d_model):param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.:param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).:return:'''b_s, nq = queries.shape[:2]nk = keys.shape[1]q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)if attention_weights is not None:att = att * attention_weightsif attention_mask is not None:att = att.masked_fill(attention_mask, -np.inf)att = torch.softmax(att, -1)att=self.dropout(att)out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)out = self.fc_o(out)  # (b_s, nq, d_model)return out

三、SK(Selective Kernel Networks)

SK是通道注意力机制。论文地址:论文连接

如下图:

代码如下(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import init
from collections import OrderedDictclass SKAttention(nn.Module):def __init__(self, channel=512,kernels=[1,3,5,7],reduction=16,group=1,L=32):super().__init__()self.d=max(L,channel//reduction)self.convs=nn.ModuleList([])for k in kernels:self.convs.append(nn.Sequential(OrderedDict([('conv',nn.Conv2d(channel,channel,kernel_size=k,padding=k//2,groups=group)),('bn',nn.BatchNorm2d(channel)),('relu',nn.ReLU())])))self.fc=nn.Linear(channel,self.d)self.fcs=nn.ModuleList([])for i in range(len(kernels)):self.fcs.append(nn.Linear(self.d,channel))self.softmax=nn.Softmax(dim=0)def forward(self, x):bs, c, _, _ = x.size()conv_outs=[]### splitfor conv in self.convs:conv_outs.append(conv(x))feats=torch.stack(conv_outs,0)#k,bs,channel,h,w### fuseU=sum(conv_outs) #bs,c,h,w### reduction channelS=U.mean(-1).mean(-1) #bs,cZ=self.fc(S) #bs,d### calculate attention weightweights=[]for fc in self.fcs:weight=fc(Z)weights.append(weight.view(bs,c,1,1)) #bs,channelattention_weughts=torch.stack(weights,0)#k,bs,channel,1,1attention_weughts=self.softmax(attention_weughts)#k,bs,channel,1,1### fuseV=(attention_weughts*feats).sum(0)return V

四、DA(Dual Attention)

DA融合了通道注意力和空间注意力机制。论文:论文地址

如下图:

代码(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import init
from model.attention.SelfAttention import ScaledDotProductAttention
from model.attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttentionclass PositionAttentionModule(nn.Module):def __init__(self,d_model=512,kernel_size=3,H=7,W=7):super().__init__()self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2)self.pa=ScaledDotProductAttention(d_model,d_k=d_model,d_v=d_model,h=1)def forward(self,x):bs,c,h,w=x.shapey=self.cnn(x)y=y.view(bs,c,-1).permute(0,2,1) #bs,h*w,cy=self.pa(y,y,y) #bs,h*w,creturn yclass ChannelAttentionModule(nn.Module):def __init__(self,d_model=512,kernel_size=3,H=7,W=7):super().__init__()self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2)self.pa=SimplifiedScaledDotProductAttention(H*W,h=1)def forward(self,x):bs,c,h,w=x.shapey=self.cnn(x)y=y.view(bs,c,-1) #bs,c,h*wy=self.pa(y,y,y) #bs,c,h*wreturn yclass DAModule(nn.Module):def __init__(self,d_model=512,kernel_size=3,H=7,W=7):super().__init__()self.position_attention_module=PositionAttentionModule(d_model=512,kernel_size=3,H=7,W=7)self.channel_attention_module=ChannelAttentionModule(d_model=512,kernel_size=3,H=7,W=7)def forward(self,input):bs,c,h,w=input.shapep_out=self.position_attention_module(input)c_out=self.channel_attention_module(input)p_out=p_out.permute(0,2,1).view(bs,c,h,w)c_out=c_out.view(bs,c,h,w)return p_out+c_out

五、EPSA(Efficient Pyramid Squeeze Attention)

论文:论文地址

如下图:

代码如下(代码连接):

import torch.nn as nnclass SEWeightModule(nn.Module):def __init__(self, channels, reduction=16):super(SEWeightModule, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc1 = nn.Conv2d(channels, channels//reduction, kernel_size=1, padding=0)self.relu = nn.ReLU(inplace=True)self.fc2 = nn.Conv2d(channels//reduction, channels, kernel_size=1, padding=0)self.sigmoid = nn.Sigmoid()def forward(self, x):out = self.avg_pool(x)out = self.fc1(out)out = self.relu(out)out = self.fc2(out)weight = self.sigmoid(out)return weightdef conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1):"""standard convolution with padding"""return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,padding=padding, dilation=dilation, groups=groups, bias=False)def conv1x1(in_planes, out_planes, stride=1):"""1x1 convolution"""return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)class PSAModule(nn.Module):def __init__(self, inplans, planes, conv_kernels=[3, 5, 7, 9], stride=1, conv_groups=[1, 4, 8, 16]):super(PSAModule, self).__init__()self.conv_1 = conv(inplans, planes//4, kernel_size=conv_kernels[0], padding=conv_kernels[0]//2,stride=stride, groups=conv_groups[0])self.conv_2 = conv(inplans, planes//4, kernel_size=conv_kernels[1], padding=conv_kernels[1]//2,stride=stride, groups=conv_groups[1])self.conv_3 = conv(inplans, planes//4, kernel_size=conv_kernels[2], padding=conv_kernels[2]//2,stride=stride, groups=conv_groups[2])self.conv_4 = conv(inplans, planes//4, kernel_size=conv_kernels[3], padding=conv_kernels[3]//2,stride=stride, groups=conv_groups[3])self.se = SEWeightModule(planes // 4)self.split_channel = planes // 4self.softmax = nn.Softmax(dim=1)def forward(self, x):batch_size = x.shape[0]x1 = self.conv_1(x)x2 = self.conv_2(x)x3 = self.conv_3(x)x4 = self.conv_4(x)feats = torch.cat((x1, x2, x3, x4), dim=1)feats = feats.view(batch_size, 4, self.split_channel, feats.shape[2], feats.shape[3])x1_se = self.se(x1)x2_se = self.se(x2)x3_se = self.se(x3)x4_se = self.se(x4)x_se = torch.cat((x1_se, x2_se, x3_se, x4_se), dim=1)attention_vectors = x_se.view(batch_size, 4, self.split_channel, 1, 1)attention_vectors = self.softmax(attention_vectors)feats_weight = feats * attention_vectorsfor i in range(4):x_se_weight_fp = feats_weight[:, i, :, :]if i == 0:out = x_se_weight_fpelse:out = torch.cat((x_se_weight_fp, out), 1)return out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/215367.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用wire重构商品微服务

一.wire简介 Wire 是一个轻巧的Golang依赖注入工具。它由Go Cloud团队开发,通过自动生成代码的方式在编译期完成依赖注入。 依赖注入是保持软件 “低耦合、易维护” 的重要设计准则之一。 此准则被广泛应用在各种开发平台之中,有很多与之相关的优秀工…

Vue3封装一个轮播图组件

先看效果 编写组件代码 CarouselChart.vue <template><div classimg-box><el-button clickpreviousImages v-ifprops.showBtn>←</el-button><div classimg><div styledisplay: flex;gap: 20px idmove><imgclassimg-item v-for(item…

C语言--clock()时间函数【详细介绍】

一.clock()时间函数介绍 在 C/C 中&#xff0c;clock() 函数通常用于处理和测量程序运行时间&#xff08;时钟时间&#xff09;。它是一种数据类型&#xff0c;表示 CPU 执行指定任务所耗费的“时钟计数”数量&#xff0c;单位为“时钟周期”。 这个函数通常包含在 time.h 头文…

每日一题,头歌平台c语言题目

任务描述 题目描述:输入一个字符串&#xff0c;输出反序后的字符串。 相关知识&#xff08;略&#xff09; 编程要求 请仔细阅读右侧代码&#xff0c;结合相关知识&#xff0c;在Begin-End区域内进行代码补充。 输入 一行字符 输出 逆序后的字符串 测试说明 样例输入&…

大语言模型有什么意义?亚马逊训练自己的大语言模型有什么用?

近年来&#xff0c;大语言模型的崭露头角引起了广泛的关注&#xff0c;成为科技领域的一项重要突破。而在这个领域的巅峰之上&#xff0c;亚马逊云科技一直致力于推动人工智能的发展。那么&#xff0c;作为一家全球科技巨头&#xff0c;亚马逊为何会如此注重大语言模型的研发与…

解决夜神模拟器与Android studio自动断开的问题

原因&#xff1a;夜神模拟器的adb版本和Android sdk的adb版本不一致 解决办法&#xff1a; 1.找到android的sdk &#xff08;1&#xff09;File--->Project Structure (2)SDK Location:记下sdk的位置 2.找到sdk中的adb文件 SDK-->platform-tools-->adb.exe 3.复制…

通用的AGI 安全风险

传统安全风险 平台基础设施安全风险 模型与数据层安全风险 应用层安全风险 平台基础设施安全风险 &#xff08;1&#xff09;物理攻击&#xff1a;机房管控不到位 &#xff08;2&#xff09;网络攻击 &#xff08;3&#xff09;计算环境&#xff1a;自身安全漏洞&#xf…

基于大语言模型的复杂任务认知推理算法CogTree

近日&#xff0c;阿里云人工智能平台PAI与华东师范大学张伟教授团队合作在自然语言处理顶级会议EMNLP2023上发表了基于认知理论所衍生的CogTree认知树生成式语言模型。通过两个系统&#xff1a;直觉系统和反思系统来模仿人类产生认知的过程。直觉系统负责产生原始问题的多个分解…

Springboot入门篇

一、概述 Spring是一个开源框架&#xff0c;2003 年兴起的一个轻量级的Java 开发框架&#xff0c;作者Rod Johnson 。Spring是为了解决企业级应用开发的复杂性而创建的&#xff0c;简化开发。 1.1对比 对比一下 Spring 程序和 SpringBoot 程序。如下图 坐标 Spring 程序中的…

【Hadoop_04】HDFS的API操作与读写流程

1、HDFS的API操作1.1 客户端环境准备1.2 API创建文件夹1.3 API上传1.4 API参数的优先级1.5 API文件夹下载1.6 API文件删除1.7 API文件更名和移动1.8 API文件详情和查看1.9 API文件和文件夹判断 2、HDFS的读写流程&#xff08;面试重点&#xff09;2.1 HDFS写数据流程2.2 网络拓…

软件设计师——软件工程(一)

&#x1f4d1;前言 本文主要是【软件工程】——软件设计师——软件工程的文章&#xff0c;如果有什么需要改进的地方还请大佬指出⛺️ &#x1f3ac;作者简介&#xff1a;大家好&#xff0c;我是听风与他&#x1f947; ☁️博客首页&#xff1a;CSDN主页听风与他 &#x1f304…

phpstorm中使用 phpunit 时的配置和代码覆盖率测试注意点

初始化一个composer项目&#xff0c;composer.json配置文件如下 {"name": "zingfront/questions-php","type": "project","require": {"php": "^7.4"},"require-dev": {"phpunit/phpun…

Mac安装DevEco Studio

下载 首先进入鸿蒙开发者官网&#xff0c;顶部导航栏选择开发->DevEco Studio 根据操作系统下载不同版本&#xff0c;其中Mac(X86)为英特尔芯片&#xff0c;Mac(ARM)为M芯片。 安装 下载完毕后&#xff0c;开始安装。 点击Agree 首次使用&#xff0c;请选择Do not impor…

张驰咨询:数据驱动的质量改进,六西格玛绿带在汽车业实践

尊敬的汽车行业同仁们&#xff0c;您是否曾面临生产效率低下、成本不断攀升或顾客满意度下降的困扰&#xff1f;本期专栏&#xff0c;我们将深入探讨如何通过六西格玛绿带培训&#xff0c;在汽车行业中实现过程优化和质量提升。 汽车行业的竞争日趋激烈&#xff0c;致力于提供…

Flask维护者:李辉

Flask维护者&#xff1a;李辉&#xff0c; 最近看b站的flask相关&#xff0c;发现了这个视频&#xff1a;[PyCon China 2023] 濒危 Flask 扩展拯救计划 - 李辉_哔哩哔哩_bilibili 李辉讲他在维护flask之余&#xff0c;开发了apiflask这个依托flask的框架。GitHub - apiflask/a…

融合科技,升级医疗体验——医院陪诊服务的技术创新

随着科技的迅猛发展&#xff0c;医疗服务领域也在积极借助技术手段提升患者体验。本文将探讨如何利用先进的技术代码&#xff0c;将医院陪诊服务推向新的高度。 1. 医疗预约系统的实现 # 通过Python代码实现医疗预约系统 class MedicalAppointment:def __init__(self, patie…

基于轻量级神经网络GhostNet开发构建光伏太阳能电池缺陷图像识别分析系统

工作中经常会使用到轻量级的网络模型来进行开发&#xff0c;所以平时也会常常留意使用和记录&#xff0c;在前面的博文中有过很多相关的实践工作&#xff0c;感兴趣的话可以自行移步阅读即可。 《移动端轻量级模型开发谁更胜一筹&#xff0c;efficientnet、mobilenetv2、mobil…

python的websocket方法教程

WebSocket是一种网络通信协议&#xff0c;它在单个TCP连接上提供全双工的通信信道。在本篇文章中&#xff0c;我们将探讨如何在Python中使用WebSocket实现实时通信。 websockets是Python中最常用的网络库之一&#xff0c;也是websocket协议的Python实现。它不仅作为基础组件在…

NSSCTF web刷题记录7

文章目录 [SDCTF 2022]CURL Up and Read[NUSTCTF 2022 新生赛]Translate [SDCTF 2022]CURL Up and Read 考点&#xff1a;SSRF 打开题目发现是curl命令&#xff0c;提示填入url 尝试http://www.baidu.com&#xff0c;成功跳转 将url的字符串拿去解码&#xff0c;得到json格式数…

低功耗模式的通用 MCU ACM32F0X0 系列,具有高整合度、高抗干扰、 高可靠性的特点

ACM32F0X0 系列是一款支持多种低功耗模式的通用 MCU。集成 12 位 1.6 Msps 高精度 ADC 以及比 较器、运放、触控按键控制器、段式 LCD 控制器&#xff0c;内置高性能定时器、多路 UART、LPUART、SPI、I2C 等丰富的通讯外设&#xff0c;内建 AES、TRNG 等信息安全模块&#xff0…