pytorch代码实现之空间通道重组卷积SCConv

空间通道重组卷积SCConv

空间通道重组卷积SCConv,全称Spatial and Channel Reconstruction Convolution,CPR2023年提出,可以即插即用,能够在减少参数的同时提升性能的模块。其核心思想是希望能够实现减少特征冗余从而提高算法的效率。一般压缩模型的方法分为三种,分别是network pruning, weight quantization, low-rank factorization以及knowledge distillation,虽然这些方法能够达到减少参数的效果,但是往往都会导致模型性能的衰减。另一种方法就是在构建模型时利用特殊的模块或操作减少模型参数,获得轻量级的网络模型,这种方法能够在保证性能的同时达到参数减少的效果。

原文地址:SCConv: Spatial and Channel Reconstruction Convolution for Feature Redundancy

作者提出的SCConv包含两部分,分别是Spatial Reconstruction Unit (SRU)和Channel Reconstruction Unit (CRU),下面是SSConv的总体结构。
SCConv结构原理图
可以看出,SCConv模块设计,对于输入的特征图先利用1x1的卷积改变为适合的通道数,之后便分别是SRU和CRU两个模块对于特征图进行处理,最后在通过1x1的卷积将特征通道数恢复并进行残差操作。
SRU模块结构
CRU模块结构

代码实现如下:

import torch
import torch.nn.functional as F
import torch.nn as nn class GroupBatchnorm2d(nn.Module):def __init__(self, c_num:int, group_num:int = 16, eps:float = 1e-10):super(GroupBatchnorm2d,self).__init__()assert c_num    >= group_numself.group_num  = group_numself.gamma      = nn.Parameter( torch.randn(c_num, 1, 1)    )self.beta       = nn.Parameter( torch.zeros(c_num, 1, 1)    )self.eps        = epsdef forward(self, x):N, C, H, W  = x.size()x           = x.view(   N, self.group_num, -1   )mean        = x.mean(   dim = 2, keepdim = True )std         = x.std (   dim = 2, keepdim = True )x           = (x - mean) / (std+self.eps)x           = x.view(N, C, H, W)return x * self.gamma + self.betaclass SRU(nn.Module):def __init__(self,oup_channels:int, group_num:int = 16,gate_treshold:float = 0.5 ):super().__init__()self.gn             = GroupBatchnorm2d( oup_channels, group_num = group_num )self.gate_treshold  = gate_tresholdself.sigomid        = nn.Sigmoid()def forward(self,x):gn_x        = self.gn(x)w_gamma     = self.gn.gamma/sum(self.gn.gamma)reweigts    = self.sigomid( gn_x * w_gamma )# Gateinfo_mask   = reweigts>=self.gate_tresholdnoninfo_mask= reweigts<self.gate_tresholdx_1         = info_mask * xx_2         = noninfo_mask * xx           = self.reconstruct(x_1,x_2)return xdef reconstruct(self,x_1,x_2):x_11,x_12 = torch.split(x_1, x_1.size(1)//2, dim=1)x_21,x_22 = torch.split(x_2, x_2.size(1)//2, dim=1)return torch.cat([ x_11+x_22, x_12+x_21 ],dim=1)class CRU(nn.Module):'''alpha: 0<alpha<1'''def __init__(self, op_channel:int,alpha:float = 1/2,squeeze_radio:int = 2 ,group_size:int = 2,group_kernel_size:int = 3,):super().__init__()self.up_channel     = up_channel   =   int(alpha*op_channel)self.low_channel    = low_channel  =   op_channel-up_channelself.squeeze1       = nn.Conv2d(up_channel,up_channel//squeeze_radio,kernel_size=1,bias=False)self.squeeze2       = nn.Conv2d(low_channel,low_channel//squeeze_radio,kernel_size=1,bias=False)#upself.GWC            = nn.Conv2d(up_channel//squeeze_radio, op_channel,kernel_size=group_kernel_size, stride=1,padding=group_kernel_size//2, groups = group_size)self.PWC1           = nn.Conv2d(up_channel//squeeze_radio, op_channel,kernel_size=1, bias=False)#lowself.PWC2           = nn.Conv2d(low_channel//squeeze_radio, op_channel-low_channel//squeeze_radio,kernel_size=1, bias=False)self.advavg         = nn.AdaptiveAvgPool2d(1)def forward(self,x):# Splitup,low  = torch.split(x,[self.up_channel,self.low_channel],dim=1)up,low  = self.squeeze1(up),self.squeeze2(low)# TransformY1      = self.GWC(up) + self.PWC1(up)Y2      = torch.cat( [self.PWC2(low), low], dim= 1 )# Fuseout     = torch.cat( [Y1,Y2], dim= 1 )out     = F.softmax( self.advavg(out), dim=1 ) * outout1,out2 = torch.split(out,out.size(1)//2,dim=1)return out1+out2class ScConv(nn.Module):def __init__(self,op_channel:int,group_num:int = 16,gate_treshold:float = 0.5,alpha:float = 1/2,squeeze_radio:int = 2 ,group_size:int = 2,group_kernel_size:int = 3,):super().__init__()self.SRU = SRU( op_channel, group_num            = group_num,  gate_treshold        = gate_treshold )self.CRU = CRU( op_channel, alpha                = alpha, squeeze_radio        = squeeze_radio ,group_size           = group_size ,group_kernel_size    = group_kernel_size )def forward(self,x):x = self.SRU(x)x = self.CRU(x)return xif __name__ == '__main__':x       = torch.randn(1,32,16,16)model   = ScConv(32)print(model(x).shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/127913.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

WebDAV之π-Disk派盘 + 天悦日记

天悦日记是一款清爽简约的日记记录工具,通过天悦日记app随时随地快速写日记,更有智能数据统计分析报表,多端同步多种备份,本地备份和基于WebDAV协议的云端备份。跨平台使用,支持多设备、多平台无差别使用。天悦日记将每一天经历都清晰记录在手机,一目了然知道曾经的经历,…

Linux初探 - 概念上的理解和常见指令的使用

目录 Linux背景 Linux发展史 GNU 应用场景 发行版本 从概念上认识Linux 操作系统的概念 用户的概念 路径与目录 Linux下的文件 时间戳的概念 常规权限 特殊权限 Shell的概念 常用指令 ls tree stat clear pwd echo cd touch mkdir rmdir rm cp mv …

uboot顶层Makefile前期所做工作说明四

一. uboot顶层 Makefile文件 uboot 顶层 Makefile&#xff0c;就是 uboot源码工程的根目录下的 Makefile文件。 本文继续对 uboot顶层 Makefile的前期准备工作进行介绍。续上一篇文章内容的学习&#xff0c;如下&#xff1a; uboot顶层Makefile前期所做工作说明三_凌肖战的博…

DAMO-YOLO训练自己的数据集,使用onnxruntime推理部署

DAMO-YOLO训练自己的数据集&#xff0c;使用onnxruntime推理部署 DAMO-YOLO 是阿里达摩院智能计算实验室开发的一种兼顾速度与精度的目标检测算法&#xff0c;在高精度的同时&#xff0c;保持了很高的推理速度。 DAMO-YOLO 是在 YOLO 框架基础上引入了一系列新技术&#xff0…

Java的环境配置

目录 window系统安装java下载JDK配置环境变量JAVA_HOME 设置PATH设置CLASSPATH 设置测试JDK是否安装成功 Linux&#xff0c;UNIX&#xff0c;Solaris&#xff0c;FreeBSD环境变量设置流行JAVA开发工具使用 Eclipse 运行第一个 Java 程序 window系统安装java 下载JDK 首先我们…

爬虫进阶-反爬破解5(selenium的优势和点击操作+chrome的远程调试能力+通过Chrome隔离实现一台电脑登陆多个账号)

目录 一、selenium的优势和点击操作 二、chrome的远程调试能力 三、通过Chrome隔离实现一台电脑登陆多个账号 一、selenium的优势和点击操作 1.环境搭建 工具&#xff1a;Chrome浏览器chromedriverselenium win用户&#xff1a;chromedriver.exe放在python.exe旁边 MacO…

Unity汉化一个插件 制作插件汉化工具

我是编程一个菜鸟&#xff0c;英语又不好&#xff0c;有的插件非常牛&#xff01;我想学一学&#xff0c;页面全是英文&#xff0c;完全不知所措&#xff0c;我该怎么办啊...尝试在Unity中汉化一个插件 效果&#xff1a; 思路&#xff1a; 如何在Unity中把一个自己喜欢的插件…

新装Ubuntu系统的一些配置

背景&#xff1a; 最近办公要在Ubuntu系统上进行&#xff0c;于是自己安装了一个Ubuntu22.04系统&#xff0c;记录下新系统做的一些基本配置。 环境 &#xff1a; 系统&#xff1a;Ubuntu-22.04内核&#xff1a;6.2.0-26-generic架构&#xff1a;x86_64 一、 配置root密码 新…

Centos7 完全断网离线环境下安装MySQL 8.0.33 图文教程

Centos7 完全断网离线环境安装MySQL 8.0.33 图文教程 1.1前言1.2 下载离线安装包1.3 将下载好的离线安装包上传到Centos 7 服务器1.3.1 方式一:联网环境下可利用rz命令进行文件上传1.3.2 方式二:断网环境下使用 XFtp 等软件工具进行上传1.4 解压安装包1.5 执行安装脚本1.6 重…

Linux TCP和UDP协议

目录 TCP协议TCP协议的面向连接1.三次握手2.四次挥手 TCP协议的可靠性1.TCP状态转移——TIME_WAIT 状态TIME_WAIT 状态存在的意义&#xff1a;&#xff08;1&#xff09;可靠的终止TCP连接。&#xff08;2&#xff09;让迟来的TCP报文有足够的时间被识别并被丢弃。 2.应答确认、…

信息安全技术概论-李剑-持续更新

图片和细节来源于 用户 xiejava1018 一.概述 随着计算机网络技术的发展&#xff0c;与时代的变化&#xff0c;计算机病毒也经历了从早期的破坏为主到勒索钱财敲诈经济为主&#xff0c;破坏方式也多种多样&#xff0c;由早期的破坏网络到破坏硬件设备等等 &#xff0c;这也…

类和对象:构造函数,析构函数与拷贝构造函数

1.类的6个默认成员函数 如果一个类中什么成员都没有&#xff0c;简称为空类。 空类中真的什么都没有吗&#xff1f;并不是&#xff0c;任何类在什么都不写时&#xff0c;编译器会自动生成以下6个默认成员函数。 默认成员函数&#xff1a;用户没有显式实现&#xff0c;编译器…

Python之线程Thread(一)

一、什么是线程 线程(Thread)特点: 线程(Thread)是操作系统能够进行运算调度的最小单位。它被包含在进程之中,是进程中的实际运作单位线程是程序执行的最小单位,而进程是操作系统分配资源的最小单位;一个进程由一个或多个线程组成,线程是一个进程中代码的不同执行路线;…

elasticsearch的索引库操作

索引库就类似数据库表&#xff0c;mapping映射就类似表的结构。我们要向es中存储数据&#xff0c;必须先创建“库”和“表”。 mapping映射属性 mapping是对索引库中文档的约束&#xff0c;常见的mapping属性包括&#xff1a; type&#xff1a;字段数据类型&#xff0c;常见的…

【C++漂流记】一文搞懂类与对象中的对象特征

在C中&#xff0c;类与对象是面向对象编程的基本概念。类是一种抽象的数据类型&#xff0c;用于描述对象的属性和行为。而对象则是类的实例&#xff0c;具体化了类的属性和行为。本文将介绍C中类与对象的对象特征&#xff0c;并重点讨论了对象的引用。 文章目录 一、构造函数和…

Python入门教程35:使用email模块发送HTML和图片邮件

smtplib模块实现邮件的发送功能&#xff0c;模拟一个stmp客户端&#xff0c;通过与smtp服务器交互来实现邮件发送的功能&#xff0c;可以理解成Foxmail的发邮件功能&#xff0c;在使用之前我们需要准备smtp服务器主机地址、邮箱账号以及密码信息。 #我的Python教程 #官方微信公…

什么是 DNS 隧道以及如何检测和防止攻击

什么是 DNS 隧道&#xff1f; DNS 隧道是一种DNS 攻击技术&#xff0c;涉及在 DNS 查询和响应中对其他协议或程序的信息进行编码。DNS 隧道通常具有可以锁定目标 DNS 服务器的数据有效负载&#xff0c;允许攻击者管理应用程序和远程服务器。 DNS 隧道往往依赖于受感染系统的…

sklearn中的数据集使用

导库 from sklearn.datasets import load_iris 实现 # 加载数据集 iris load_iris() print(f查看数据集&#xff1a;{iris}) print(f查看数据集的特征&#xff1a;{iris.feature_names}) print(f查看数据集的标签&#xff1a;{iris.target_names}) print(f查看数据集的描述…

linux 安装Docker

# 1、yum 包更新到最新 yum update # 2、安装需要的软件包&#xff0c; yum-util 提供yum-config-manager功能&#xff0c;另外两个是devicemapper驱动依赖的 yum install -y yum-utils device-mapper-persistent-data lvm2 # 3、 设置yum源 yum-config-manager --add-repo h…

Lua01——概述

Lua是啥&#xff1f; 官网 https://www.lua.org Lua这个名字在葡萄牙语中的意思是“美丽的月亮”&#xff0c;诞生于巴西的大学实验室。 这是一个小巧、高效且能够很好的和C语言一起工作的编程语言。 在脚本语言领域中&#xff0c;Lua因为有资格作为游戏开发的备选方案&…