pytorch代码实现之SAConv卷积

SAConv卷积

SAConv卷积模块是一种精度更高、速度更快的“即插即用”卷积,目前很多方法被提出用于降低模型冗余、加速模型推理速度,然而这些方法往往关注于消除不重要的滤波器或构建高效计算单元,反而忽略了特征内部的模式冗余。
原文地址:Split to Be Slim: An Overlooked Redundancy in Vanilla Convolution

由于同一层内的许多特征具有相似却不平等的表现模式。然而,这类具有相似模式的特征却难以判断是否存在冗余或包含重要的细节信息。因此,不同于直接移除不确定的冗余特征方案,提出了一种基于Split的卷积计算单元(称之为SPConv),它运训存在相似模型冗余且仅需非常少的计算量。

SPConv结构图

首先,将输入特征拆分为representative部分与uncertain部分;然后,对于representative部分特征采用相对多的计算复杂度操作提取重要信息,对于uncertain部分采用轻量型操作提取隐含信息;最后,为重新校准与融合两组特征,作者采用了无参特征融合模块。该文所提SPConv是一种“即插即用”型模块,可用于替换现有网络中的常规卷积。

​无需任何技巧,在GPU端的精度与推理速度方面,基于SPConv的网络均可取得SOTA性能。该文主要贡献包含下面几个方面:
(1)重新对常规卷积中的特征冗余问题进行了再思考,提出了将输入分成两部分:representative与uncertain,分别针对两部分进行不同的信息提取;
(2)设计了一种“即插即用”型SPConv模块,它可以无缝替换现有网络中的常规卷积,且在精度与GPU推理速度上均可能优于SOTA性能,同时具有更少的FLOPs和参数量。

代码实现

class ConvAWS2d(nn.Conv2d):def __init__(self,in_channels,out_channels,kernel_size,stride=1,padding=0,dilation=1,groups=1,bias=True):super().__init__(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation,groups=groups,bias=bias)self.register_buffer('weight_gamma', torch.ones(self.out_channels, 1, 1, 1))self.register_buffer('weight_beta', torch.zeros(self.out_channels, 1, 1, 1))def _get_weight(self, weight):weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,keepdim=True).mean(dim=3, keepdim=True)weight = weight - weight_meanstd = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1)weight = weight / stdweight = self.weight_gamma * weight + self.weight_betareturn weightdef forward(self, x):weight = self._get_weight(self.weight)return super()._conv_forward(x, weight, None)def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,missing_keys, unexpected_keys, error_msgs):self.weight_gamma.data.fill_(-1)super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,missing_keys, unexpected_keys, error_msgs)if self.weight_gamma.data.mean() > 0:returnweight = self.weight.dataweight_mean = weight.data.mean(dim=1, keepdim=True).mean(dim=2,keepdim=True).mean(dim=3, keepdim=True)self.weight_beta.data.copy_(weight_mean)std = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1)self.weight_gamma.data.copy_(std)class SAConv2d(ConvAWS2d):def __init__(self,in_channels,out_channels,kernel_size,s=1,p=None,g=1,d=1,act=True,bias=True):super().__init__(in_channels,out_channels,kernel_size,stride=s,padding=autopad(kernel_size, p),dilation=d,groups=g,bias=bias)self.switch = torch.nn.Conv2d(self.in_channels,1,kernel_size=1,stride=s,bias=True)self.switch.weight.data.fill_(0)self.switch.bias.data.fill_(1)self.weight_diff = torch.nn.Parameter(torch.Tensor(self.weight.size()))self.weight_diff.data.zero_()self.pre_context = torch.nn.Conv2d(self.in_channels,self.in_channels,kernel_size=1,bias=True)self.pre_context.weight.data.fill_(0)self.pre_context.bias.data.fill_(0)self.post_context = torch.nn.Conv2d(self.out_channels,self.out_channels,kernel_size=1,bias=True)self.post_context.weight.data.fill_(0)self.post_context.bias.data.fill_(0)self.bn = nn.BatchNorm2d(out_channels)self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())def forward(self, x):# pre-contextavg_x = torch.nn.functional.adaptive_avg_pool2d(x, output_size=1)avg_x = self.pre_context(avg_x)avg_x = avg_x.expand_as(x)x = x + avg_x# switchavg_x = torch.nn.functional.pad(x, pad=(2, 2, 2, 2), mode="reflect")avg_x = torch.nn.functional.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)switch = self.switch(avg_x)# sacweight = self._get_weight(self.weight)out_s = super()._conv_forward(x, weight, None)ori_p = self.paddingori_d = self.dilationself.padding = tuple(3 * p for p in self.padding)self.dilation = tuple(3 * d for d in self.dilation)weight = weight + self.weight_diffout_l = super()._conv_forward(x, weight, None)out = switch * out_s + (1 - switch) * out_lself.padding = ori_pself.dilation = ori_d# post-contextavg_x = torch.nn.functional.adaptive_avg_pool2d(out, output_size=1)avg_x = self.post_context(avg_x)avg_x = avg_x.expand_as(out)out = out + avg_xreturn self.act(self.bn(out))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/129377.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux查端口占用的几种方式

在Linux中,你可以使用以下几种方式来查看端口的占用情况。 一、使用netstat命令 #安装netstat yum -y install net-tools #检测端口占用 netstat -npl | grep 端口# 几种常规用法 netstat -ntlp //查看当前所有tcp端口 netstat -ntulp | grep 80 //查看所有80端…

java设计模式,简单工厂和抽象工厂有什么区别?

java设计模式,简单工厂和抽象工厂有什么区别? 简单工厂模式: 这个模式本身很简单而且使用在业务较简单的情况下。一般用于小项目或者具体产品很少扩展的情况(这样工厂类才不用经常更改)。 它由三种角色组成&#xf…

MySQL 8.0.34(x64)安装笔记

一、背景 从MySQL 5.6到5.7,再到8.0,版本的跳跃不可谓不大。安装、配置的差别也不可谓不大,特此备忘。 二、过程 (1)获取MySQL 8.0社区版(MySQL Community Server)   从 官网 字样 “MySQL …

卡尔曼滤波公式推导(总结)

假设 小车在t时刻的初始状态可以用Pt(当前位置),Vt(当前速度),Ut表示加速度: 预测: 利用上一个时刻的旧状态和系统的动量模型(如加速度,速度等)…

扫码支付系统_分账收款系统_设计开发OctShop

在当今,移动支付在我们生活中已经是不可或缺的东西,以微信和支付宝为代表的扫码支付系统正在各种线下消费场景中使用,给我们日常的生活购物消费带来了不少的便利。第一些第三方的扫码支付系统更是集合了各种支付渠道,买家或消费者…

yolov5添加ECA注意力机制

ECA注意力机制简介 论文题目:ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks 论文地址:here 基本原理 🐸 ECANet的核心思想是提出了一种不降维的局部跨通道交互策略,有效避免了降维对于通道注意…

一个CVE漏洞预警知识库

CVE 0x01 免责声明 本仓库所涉及的技术、思路和工具仅供安全技术研究,任何人不得将其用于非授权渗透测试,不得将其用于非法用途和盈利,否则后果自行承担。 无exp/poc,部分包含修复方案 0x02 项目导航 2022.12 CVE-2022-3328&a…

Purple Pi OH(Debian/Ubuntu)使用python控制gpio

本文分享的是Purple Pi OH开源主板搭载Debian/Ubuntu系统如何使用python控制gpio。 Purple Pi OH作为一款兼容树莓派的开源主板,采用瑞芯微RK3566 (Cortex-A55) 四核64位超强CPU,主频最高达1.8 GHz,算力高达1Tops,支持INT8/INT16,支持Tensor…

如何保持 SSH 会话不中断?

哈喽大家好,我是咸鱼 不知道小伙伴们有没有遇到过下面的情况: 使用终端(XShell、secureCRT 或 MobaXterm 等)登录 Linux 服务器之后如果有一段时间没有进行交互,SSH 会话就会断开 如果正在执行一些非后台命令&#…

模电课设:用Multisim简单了解二极管

1 课设内容 1)测试二极管伏安特性电路; 2)二极管的整流电路及负载对输出电压和纹波的影响; 2 模型搭建 电路一:测试二极管伏安特性的电路如下图所示,结构十分简单,直流电源串联上二极管组成一…

Kafka3.0.0版本——消费者(消费者组原理)

目录 一、消费者组原理1.1、消费者组概述1.2、消费者组图解示例1.3、消费者组注意事项 一、消费者组原理 1.1、消费者组概述 Consumer Group(CG):消费者组,由多个consumer组成。形成一个消费者组的条件,是所有消费者…

娱乐时间 —— 用python将图片转为excel十字绘

最近看蛮多朋友在玩,要么只能画比较简单的,要么非常花时间。想了下本质上就是把excel对应的单元格涂色,如果能知道哪些格子要上什么颜色,用编程来实现图片转为excel十字绘应该是很方便的。 图片的每一个像素点都可以数值化&#x…

(3)MyBatis-Plus待开发

常用注解 TableName MyBatis-Plus在确定操作的表时,由BaseMapper的泛型决定即实体类型决定,且默认操作的表名和实体类型的类名一致,如果不一致则会因找不到表报异常 //向表中插入一条数据 Test public void testInsert(){User user new User(null, &…

python如何学习

功能如此强大、高效的Python,却非常的简单好学,这让学它的同学爱不释手,也让越来越多的互联网企业开始用Python来做主要的开发语言,比如谷歌、Facebook(现Meta)、豆瓣、知乎等知名互联网公司都在使用Python…

【C++】简单理解:将整数(浮点数)转换为字符串(string),将字符串(string)转换为整数(浮点数)方法

用stringstream类&#xff0c;口诀&#xff1a;过滤一下就转化 头文件#include<sstream> 例子&#xff1a;将整数12和浮点数12.34转化为字符串 int main() {int x 12;double d 12.34;string s;//创建一下对象strstringstream str;//过滤一下就转化str << x;st…

day3_C++

day3_C 思维导图用C的类完成数据结构 栈的相关操作用C的类完成数据结构 循环队列的相关操作 思维导图 用C的类完成数据结构 栈的相关操作 stack.h #ifndef STACK_H #define STACK_H#include <iostream> #include <cstring>using namespace std;typedef int datat…

SpringMVC_拦截器

4.拦截器 4.1拦截器概述 概述&#xff1a;一种动态拦截方法调用的机制&#xff0c;在SpringMVC中动态拦截控制器方法的执行实际开发中&#xff0c;静态资源&#xff08;HTML/CSS&#xff09;不需要交给框架处理&#xff0c;需要拦截的是动态资源 4.2图示 图示 4.3案例实现 …

Win11共享文件夹怎么设置

当我们在使用Win11的过程中有时会因为一些操作需要共享文件夹&#xff0c;那么Win11系统该如何设置共享文件夹呢&#xff0c;下面小编就给大家详细介绍一下Win11设置共享文件夹的方法&#xff0c;有需要的小伙伴快来和小编一起看看吧。 Win11设置共享文件夹的方法&#xff1a;…

最后一块石头的重量 II【动态规划】

最后一块石头的重量 II 有一堆石头&#xff0c;用整数数组 stones 表示。其中 stones[i] 表示第 i 块石头的重量。 每一回合&#xff0c;从中选出任意两块石头&#xff0c;然后将它们一起粉碎。假设石头的重量分别为 x 和 y&#xff0c;且 x < y。那么粉碎的可能结果如下&am…

php代理刷访问量(附源码)

众所周知&#xff0c;所谓的访问量就是用户的点击次数。当然&#xff0c;如果真只是单纯记录用户的访问次数&#xff0c;那访问量刷起来也太简单了&#xff0c;不断的刷新网页就行。因此&#xff0c;常规的网站记录访问量是通过ip来的&#xff0c;一个有效ip对应一个访问量。通…