Pytorch 注意力机制解析与代码实现

目录

  • 什么是注意力机制
  • 1、SENet的实现
  • 2、CBAM的实现
  • 3、ECA的实现
  • 4、CA的实现

什么是注意力机制

注意力机制是深度学习常用的一个小技巧,它有多种多样的实现形式,尽管实现方式多样,但是每一种注意力机制的实现的核心都是类似的,就是注意力。

注意力机制的核心重点就是让网络关注到它更需要关注的地方。

当我们使用卷积神经网络去处理图片的时候,我们会更希望卷积神经网络去注意应该注意的地方,而不是什么都关注,我们不可能手动去调节需要注意的地方,这个时候,如何让卷积神经网络去自适应的注意重要的物体变得极为重要。

注意力机制就是实现网络自适应注意的一个方式。

一般而言,注意力机制可以分为通道注意力机制,空间注意力机制,以及二者的结合。

1、SENet的实现

SE注意力模块是一种通道注意力模块,SE模块能对输入特征图进行通道特征加强,且不改变输入特征图的大小

  1. SE模块的S(Squeeze):对输入特征图的空间信息进行压缩

  2. SE模块的E(Excitation):学习到的通道注意力信息,与输入特征图进行结合,最终得到具有通道注意力的特征图

  3. SE模块的作用是在保留原始特征的基础上,通过学习不同通道之间的关系,提高模型的表现能力。在卷积神经网络中,通过引入SE模块,可以动态地调整不同通道的权重,从而提高模型的表现能力。

实现方式:
1、对输入进来的特征层进行全局平均池化。
2、然后进行两次全连接,第一次全连接神经元个数较少,第二次全连接神经元个数和输入特征层相同。
3、在完成两次全连接后,我们再取一次Sigmoid将值固定到0-1之间,此时我们获得了输入特征层每一个通道的权值(0-1之间)。
4、在获得这个权值后,我们将这个权值乘上原输入特征层即可。

在这里插入图片描述

实现代码

import torch
from torch import nnclass SEAttention(nn.Module):def __init__(self, channel=512, reduction=16):super().__init__()# 对空间信息进行压缩self.avg_pool = nn.AdaptiveAvgPool2d(1)# 经过两次全连接层,学习不同通道的重要性self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, x):# 取出batch size和通道数b, c, _, _ = x.size()# b,c,w,h -> b,c,1,1 -> b,c 压缩与通道信息学习y = self.avg_pool(x).view(b, c)# b,c->b,c->b,c,1,1y = self.fc(y).view(b, c, 1, 1)# 激励操作return x * y.expand_as(x)if __name__ == '__main__':input = torch.randn(50, 512, 7, 7)se = SEAttention(channel=512, reduction=8)output = se(input)print(input.shape)print(output.shape)

SE模块是一个即插即用的模块,在上图中左边是在一个卷积模块之后直接插入SE模块,右边是在ResNet结构中添加了SE模块。

在这里插入图片描述

2、CBAM的实现

CBAM将通道注意力机制和空间注意力机制进行一个结合,相比于SENet只关注通道的注意力机制可以取得更好的效果。其实现示意图如下所示,CBAM会对输入进来的特征层,分别进行通道注意力机制的处理和空间注意力机制的处理
在这里插入图片描述
下图是通道注意力机制和空间注意力机制的具体实现方式:
图像的上半部分为通道注意力机制,通道注意力机制的实现可以分为两个部分,我们会对输入进来的单个特征层,分别进行全局平均池化和全局最大池化。之后对平均池化和最大池化的结果,利用共享的全连接层进行处理,我们会对处理后的两个结果进行相加,然后取一个sigmoid,此时我们获得了输入特征层每一个通道的权值(0-1之间)。在获得这个权值后,我们将这个权值乘上原输入特征层即可。

图像的下半部分为空间注意力机制,我们会对输入进来的特征层,在每一个特征点的通道上取最大值和平均值。之后将这两个结果进行一个堆叠,利用一次通道数为1的卷积调整通道数,然后取一个sigmoid,此时我们获得了输入特征层每一个特征点的权值(0-1之间)。在获得这个权值后,我们将这个权值乘上原输入特征层即可。
在这里插入图片描述
实现代码:

import torch
from torch import nn
from torch.nn import initclass ChannelAttention(nn.Module):def __init__(self, channel, reduction=16):super().__init__()self.maxpool = nn.AdaptiveMaxPool2d(1)self.avgpool = nn.AdaptiveAvgPool2d(1)self.se = nn.Sequential(nn.Conv2d(channel, channel // reduction, 1, bias=False),nn.ReLU(),nn.Conv2d(channel // reduction, channel, 1, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):max_result = self.maxpool(x)avg_result = self.avgpool(x)max_out = self.se(max_result)avg_out = self.se(avg_result)output = self.sigmoid(max_out + avg_out)return outputclass SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super().__init__()self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2)self.sigmoid = nn.Sigmoid()def forward(self, x):max_result, _ = torch.max(x, dim=1, keepdim=True)avg_result = torch.mean(x, dim=1, keepdim=True)result = torch.cat([max_result, avg_result], 1)output = self.conv(result)output = self.sigmoid(output)return outputclass CBAMAttention(nn.Module):def __init__(self, channel=512, reduction=16, kernel_size=7):super().__init__()self.ca = ChannelAttention(channel=channel, reduction=reduction)self.sa = SpatialAttention(kernel_size=kernel_size)def forward(self, x):b, c, _, _ = x.size()out = x * self.ca(x)out = out * self.sa(out)return outif __name__ == '__main__':input = torch.randn(50, 512, 7, 7)kernel_size = input.shape[2]cbam = CBAMAttention(channel=512, reduction=16, kernel_size=kernel_size)output = cbam(input)print(input.shape)print(output.shape)

3、ECA的实现

ECANet是也是通道注意力机制的一种实现形式。ECANet可以看作是SENet的改进版
ECANet的作者认为SENet对通道注意力机制的预测带来了副作用,捕获所有通道的依赖关系是低效并且是不必要的。
在ECANet的论文中,作者认为卷积具有良好的跨通道信息获取能力。

ECA模块的思想是非常简单的,它去除了原来SE模块中的全连接层,直接在全局平均池化之后的特征上通过一个1D卷积进行学习。

既然使用到了1D卷积,那么1D卷积的卷积核大小的选择就变得非常重要了,了解过卷积原理的同学很快就可以明白,1D卷积的卷积核大小会影响注意力机制每个权重的计算要考虑的通道数量。用更专业的名词就是跨通道交互的覆盖率。

如下图所示,左图是常规的SE模块,右图是ECA模块。ECA模块用1D卷积替换两次全连接。

在这里插入图片描述

代码实现:

import torch
from torch import nn
from torch.nn import initclass ECAAttention(nn.Module):def __init__(self, kernel_size=3):super().__init__()self.gap = nn.AdaptiveAvgPool2d(1)self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)self.sigmoid = nn.Sigmoid()def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):y = self.gap(x)  # bs,c,1,1y = y.squeeze(-1).permute(0, 2, 1)  # bs,1,cy = self.conv(y)  # bs,1,cy = self.sigmoid(y)  # bs,1,cy = y.permute(0, 2, 1).unsqueeze(-1)  # bs,c,1,1return x * y.expand_as(x)if __name__ == '__main__':input = torch.randn(50, 512, 7, 7)eca = ECAAttention(kernel_size=3)output = eca(input)print(input.shape)print(output.shape)

4、CA的实现

在这里插入图片描述
该文章的作者认为现有的注意力机制(如CBAM、SE)在求取通道注意力的时候,通道的处理一般是采用全局最大池化或平均池化,这样会损失掉物体的空间信息。作者期望在引入通道注意力机制的同时,引入空间注意力机制,作者提出的注意力机制将位置信息嵌入到了通道注意力中。

CA注意力的实现如图所示,可以认为分为两个并行阶段:

将输入特征图分别在为宽度和高度两个方向分别进行全局平均池化,分别获得在宽度和高度两个方向的特征图。假设输入进来的特征层的形状为[C, H, W],在经过宽方向的平均池化后,获得的特征层shape为[C, H, 1],此时我们将特征映射到了高维度上;在经过高方向的平均池化后,获得的特征层shape为[C, 1, W],此时我们将特征映射到了宽维度上。

然后将两个并行阶段合并,将宽和高转置到同一个维度,然后进行堆叠,将宽高特征合并在一起,此时我们获得的特征层为:[C, 1, H+W],利用卷积+标准化+激活函数获得特征。

之后再次分开为两个并行阶段,再将宽高分开成为:[C, 1, H]和[C, 1, W],之后进行转置。获得两个特征层[C, H, 1]和[C, 1, W]。

然后利用1x1卷积调整通道数后取sigmoid获得宽高维度上的注意力情况。乘上原有的特征就是CA注意力机制。

实现代码:

import torch
from torch import nnclass CAAttention(nn.Module):def __init__(self, channel, reduction=16):super(CAAttention, self).__init__()self.conv_1x1 = nn.Conv2d(in_channels=channel, out_channels=channel // reduction, kernel_size=1, stride=1,bias=False)self.relu = nn.ReLU()self.bn = nn.BatchNorm2d(channel // reduction)self.F_h = nn.Conv2d(in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1,bias=False)self.F_w = nn.Conv2d(in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1,bias=False)self.sigmoid_h = nn.Sigmoid()self.sigmoid_w = nn.Sigmoid()def forward(self, x):_, _, h, w = x.size()x_h = torch.mean(x, dim=3, keepdim=True).permute(0, 1, 3, 2)x_w = torch.mean(x, dim=2, keepdim=True)x_cat_conv_relu = self.relu(self.bn(self.conv_1x1(torch.cat((x_h, x_w), 3))))x_cat_conv_split_h, x_cat_conv_split_w = x_cat_conv_relu.split([h, w], 3)s_h = self.sigmoid_h(self.F_h(x_cat_conv_split_h.permute(0, 1, 3, 2)))s_w = self.sigmoid_w(self.F_w(x_cat_conv_split_w))out = x * s_h.expand_as(x) * s_w.expand_as(x)return outif __name__ == '__main__':input = torch.randn(50, 512, 7, 7)eca = CAAttention(512, reduction=16)output = eca(input)print(input.shape)print(output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/179038.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Kubernetes部署】二进制部署单Master Kurbernetes集群 超详细

二进制部署K8s 一、基本架构和系统初始化操作1.1 基本架构1.2 系统初始化操作 二、部署etcd集群2.1 证书签发Step1 下载证书制作工具Step2 创建k8s工作目录Step3 编写脚本并添加执行权限Step4 生成CA证书、etcd 服务器证书以及私钥 2.2 启动etcd服务Step1 上传并解压代码包Step…

面试算法51:节点值之和最大的路径

题目 在二叉树中将路径定义为顺着节点之间的连接从任意一个节点开始到达任意一个节点所经过的所有节点。路径中至少包含一个节点,不一定经过二叉树的根节点,也不一定经过叶节点。给定非空的一棵二叉树,请求出二叉树所有路径上节点值之和的最…

分体式离子风刀和整体式离子风刀分别有哪些优缺点

离子风刀是一种利用高速旋转的离子风扇产生的离子风来清洁和干燥物体表面的设备。根据离子风扇的安装方式,离子风刀可以分为分体式离子风刀和整体式离子风刀。下面是它们各自的优缺点: 分体式离子风刀的优点: 安装方便:分体式离子…

电压放大器可用于什么场合

电压放大器是电子器件中常见的一种放大器类型,它可以将输入信号的电压放大到更大的幅度,以满足特定应用的需求。电压放大器广泛应用于多个领域和场合,下面将详细介绍一些使用电压放大器的场景。 音频放大器:音频放大器是电压放大器…

(1)(1.11) LeddarTech Leddar One

文章目录 前言 1 连接到自动驾驶仪 2 参数说明 前言 Leddar One 激光雷达(Leddar One Lidar)是一款重量轻、价格合理的激光雷达,测距 40m,更新频率 70hz,光束为 3 度漫射。有关详细信息,请参阅数据表(datasheet)。 &#xff0…

2023-11-03 LeetCode每日一题(填充每个节点的下一个右侧节点指针 II)

2023-11-03每日一题 一、题目编号 117. 填充每个节点的下一个右侧节点指针 II二、题目链接 点击跳转到题目位置 三、题目描述 给定一个二叉树: struct Node { int val; Node *left; Node *right; Node *next; } 填充它的每个 next 指针,让这个指针…

ArcGIS Pro怎么生成高程点

一般情况下,我们从公开渠道获取到的高程数据都是DEM数据,但是如果要用到CAD等软件内则需要用到高程点,那么如何从DEM提取高程点呢,这里为大家介绍一下生成方法,希望能对你有所帮助。 数据来源 本教程所使用的数据是…

喝酒聚会摇色子小程序源码系统+石头剪刀布+大转盘 带完整的部署教程

来咯来咯,大家都知道摇色子是一种古老而受欢迎的饮酒游戏。在当代年轻人的聚会中,常常都使用摇骰子这种方法来喝酒的。今天罗峰要给大家介绍是一款非常受欢迎的小程序源码系统喝酒聚会摇色子小程序源码系统,还有石头剪刀布,大转盘…

怎么让重要文件自动备份到OneDrive?

可以让文件自动备份到OneDrive吗? OneDrive是比较受欢迎的云存储之一,主要用于存储文件和个人数据,随时随地都能够在多个设备(例如Android、台式机或笔记本电脑、平板电脑等)之间同步和共享文件。 因此&…

SSL证书在网购中的重要性

近年来,互联网的快速发展使得线上服务范围不断延伸,这其中网络购物更是在全球范围内都呈现上升趋势。然而病毒攻击,网络钓鱼攻击和恶意软件攻击无处不在,网上购物的安全性受到极大威胁。为了保护网络购物的安全,构建可…

Vue项目运行时报错:‘vue-cli-service‘ 不是内部或外部命令,也不是可运行的程序 或批处理文件

报错原因及解决 1.package.json 文件中未定义依赖项vue/cli-service,因此在 npm install 之后并没有安装vue/cli-service 依赖; 解决:项目目录下执行命令,npm i -D vue/cli-service。2.第1步排查后,还是报同样的错&a…

【腾讯云HAI域探秘】利用HAI搭建AI绘画应用,随心所欲,畅享创作乐趣

【腾讯云HAI域探秘】利用HAI搭建AI绘画应用,随心所欲,畅享创作乐趣 1️⃣基于HAI部署的StableDiffusionWebUI快速进行AI绘画(1)创建并启动StableDiffusion应用服务器(2)使用StableDiffusionWebUI进行AI绘画…

Dubbo中的负载均衡算法之一致性哈希算法

Dubbo中的负载均衡算法之一致性哈希算法 哈希算法 假设这样一个场景,我们申请一台服务器来缓存100万的数据,这个时候是单机环境,所有的请求都命中到这台服务器。后来业务量上涨,我们的数据也从100万上升到了300万,原…

【蓝桥每日一题]-二分精确(保姆级教程 篇4) #kotori的设备 #银行贷款 #一元三次方程求解

今天讲二分精确题型 目录 题目:kotori的设备 思路: 题目:银行贷款 思路: 题目:一元三次方程求解 思路: 题目:kotori的设备 思路: 求:设备最长使用时间 二分查找&#…

数据结构(超详细讲解!!)第十九节 块链串及串的应用

1.定义 由于串也是一种线性表,因此也可以采用链式存储。由于串的特殊性(每个元素只有一个字符),在具体实现时,每个结点既可以存放一个字符,也可以存放多个字符。每个结点称为块,整个链表称为块链…

C# 发送邮件

1.安装 NuGet 包 2.代码如下 SendMailUtil using MimeKit; using Srm.CMER.Application.Contracts.CmerInfo; namespace Srm.Mail { public class SendMailUtil { public async static Task<string> SendEmail(SendEmialDto sendEmialDto,List<strin…

DC系列 DC:3

DC系列 DC:3 文章目录 DC系列 DC:3调试靶机信息收集IP端口信息收集 框架漏洞利用joomscan扫描工具利用msf工具利用(无法使用)kali漏洞库利用sqlmap利用 文件上传提权 调试靶机 点击虚拟机设置选择CD/DVD点击高级将IDE调成画面中这个选项 信息收集 IP端口信息收集 对自己网…

DI93a HESG440355R3 通过其Achilles级认证提供网络安全

DI93a HESG440355R3 通过其Achilles级认证提供网络安全 施耐德电气宣布推出Modicon M580以太网PAC (ePAC)自动化控制器&#xff0c;该控制器采用开放式以太网标准&#xff0c;通过其Achilles级认证提供网络安全。M580 ePAC使工厂操作员能够设计、实施和运行一个积极利用开放网…

量子计算与量子密码(入门级-少图版)

量子计算与量子密码 写在最前面一些可能带来的有趣的知识和潜在的收获 1、Introduction导言四个特性不确定性&#xff08;自由意志论&#xff09;Indeterminism不确定性Uncertainty叠加原理(线性)superposition (linearity)纠缠entanglement 虚数的常见基本运算欧拉公式&#x…

nodejs国内镜像及切换版本工具nvm

淘宝 NPM 镜像站&#xff08;http://npm.taobao.org&#xff09;已更换域名&#xff0c;新域名&#xff1a; Web 站点&#xff1a;https://npmmirror.com Registry Endpoint&#xff1a;https://registry.npmmirror.com 详见&#xff1a; 【望周知】淘宝 NPM 镜像换域名了&…