【图像去噪】论文复现:代替ReLU!Pytorch实现即插即用激活函数模块xUnit,并插入到DnCNN中实现xDnCNN!

请先看【专栏介绍文章】:【图像去噪(Image Denoising)】关于【图像去噪】专栏的相关说明,包含适配人群、专栏简介、专栏亮点、阅读方法、定价理由、品质承诺、关于更新、去噪概述、文章目录、资料汇总、问题汇总(更新中)

本文亮点:

  • 实现三种xUnit模块:xUnit(论文默认)、xUnitS(轻量化)、xUnitD(密集型),随取随用
  • xUnit模块可以加入到任意去噪模型中,替代ReLU激活函数
  • 测试结果与论文中所述观点基本一致

文章目录

  • 前言
  • 一、xUnit结构实现
  • 二、xDnCNN结构实现
  • 三、结果展示


前言

论文题目:xUnit: Learning a Spatial Activation Function for Efficient Image Restoration —— xUnit:学习空间激活函数进行高效图像恢复

论文地址:xUnit: Learning a Spatial Activation Function for Efficient Image Restoration

论文源码:https://github.com/kligvasser/xUnit

对应的论文精读:【图像去噪】论文精读:xUnit: Learning a Spatial Activation Function for Efficient Image Restoration

只需要源码中的xUnit结构实现,并不需要其他的。本文将xUnit模块插入到DnCNN中实现xDnCNN。

一、xUnit结构实现

xUnit结构回顾:
在这里插入图片描述

  • xUnit默认结构:BN+RL+CD+BN+GS
  • xUnit轻量结构:CD+BN+GS
  • xUnit密集结构:CD+BN+RL+CD+BN+GS

代码实现如下,命名为activations.py:

import torch.nn as nnclass xUnit(nn.Module):def __init__(self, num_features=64, kernel_size=9, batch_norm=False):super(xUnit, self).__init__()# xUnitself.features = nn.Sequential(nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity,nn.ReLU(),nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), groups=num_features),nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity,nn.Sigmoid())def forward(self, x):a = self.features(x)r = x * areturn rclass xUnitS(nn.Module):def __init__(self, num_features=64, kernel_size=9, batch_norm=False):super(xUnitS, self).__init__()# slim xUnitself.features = nn.Sequential(nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), groups=num_features),nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),nn.Sigmoid())def forward(self, x):a = self.features(x)r = x * areturn rclass xUnitD(nn.Module):def __init__(self, num_features=64, kernel_size=9, batch_norm=False):super(xUnitD, self).__init__()# dense xUnitself.features = nn.Sequential(nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=1, padding=0),nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),nn.ReLU(),nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), groups=num_features),nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),nn.Sigmoid())def forward(self, x):a = self.features(x)r = x * areturn rclass Identity(nn.Module):def __init__(self,):super(Identity, self).__init__()def forward(self, x):return x

二、xDnCNN结构实现

先回顾以下DnCNN的网络结构:
在这里插入图片描述
代码实现如下:

class DnCNN(nn.Module):def __init__(self, num_layers=17, num_features=64):super(DnCNN, self).__init__()layers = [nn.Sequential(nn.Conv2d(3, num_features, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True))]for i in range(num_layers - 2):layers.append(nn.Sequential(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),nn.BatchNorm2d(num_features),nn.ReLU(inplace=True)))layers.append(nn.Conv2d(num_features, 3, kernel_size=3, padding=1))self.layers = nn.Sequential(*layers)self._initialize_weights()def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)elif isinstance(m, nn.BatchNorm2d):nn.init.ones_(m.weight)nn.init.zeros_(m.bias)def forward(self, inputs):y = inputsresidual = self.layers(y)return y - residual

上面是一个17层的DnCNN实现,使用xUnit代替DnCNN中的ReLU,同时减少卷积层数为9层,称作xDnCNN。

实现如下:

from torch import nn
from activations import xUnit, xUnitD, xUnitSclass xDnCNN(nn.Module):def __init__(self, num_layers=9, num_features=64):super(xDnCNN, self).__init__()layers = [nn.Sequential(nn.Conv2d(3, num_features, kernel_size=3, stride=1, padding=1),xUnit(num_features, batch_norm=True))]for i in range(num_layers - 2):layers.append(nn.Sequential(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),nn.BatchNorm2d(num_features),xUnit(num_features, batch_norm=True)))layers.append(nn.Conv2d(num_features, 3, kernel_size=3, padding=1))self.layers = nn.Sequential(*layers)self._initialize_weights()def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)elif isinstance(m, nn.BatchNorm2d):nn.init.ones_(m.weight)nn.init.zeros_(m.bias)def forward(self, inputs):y = inputsresidual = self.layers(y)return y - residual

至此,xUnit模块嵌入了模型中,除了DnCNN,其他有ReLU激活的模型都可以以此方法替代。

DnCNN于xDnCNN结构对比:

  • DnCNN:[Conv+ReLU]+15个[Conv+BN+ReLU]+Conv,共17个卷积层,16个ReLU,15个BN
  • xDnCNN:[Conv+xUnit(BN+ReLU+Conv+BN+Sigmoid)]+7个[Conv+BN+xUnit(BN+ReLU+Conv+BN+Sigmoid)] + Conv,共17个卷积层,8个ReLU,23个BN

区别为:xDnCNN少了8个ReLU,多了8个BN,并且xUnit中的Conv卷积核为9×9,而其他均为3×3。

虽然模型性能区别不能简单地以模块数量多少而论,但也能从中发现一些端倪。

  • 卷积层个数相同,卷积核越大,参数不一定越多。也受整体层数影响,卷积核更大,整体层数更少,虽然在单层的参数量更多,但总体的参数量更少。
  • 本质上,是增大感受野以增强特征提取能力(论文图5、图9)。只是套了这么一个xUnit的壳,实际上就是改变结构,只不过把这一堆统称为xUnit。
  • 给我们的启示:尝试把一堆组件绑在一起作为一个整体,调整其中的某个参数(e.g.卷积核,就可以减少整体层数了),看看能不能有所提升。

xUnit的作用:减少模型参数、性能几乎不变、纹理细节提升!

三、结果展示

性能对比:

Methods DnCNN-S xDnCNN
parameters559363(559K)29.08
σ=50306947(307K)29.03

视觉展示(论文图7):

在这里插入图片描述


至此本文结束。

如果本文对你有所帮助,请点赞收藏,创作不易,感谢您的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/411074.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

文章生成用这三款伪原创软件效果好

在当今信息爆炸的时代,无论是网站运营者、博主、作家还是学生,对文章的需求量越来越大。他们需要用大理的的原创文章来满足他们工作需求。然而,对于许多人来说,写作一篇优质的文章并非易事。这就产生了一种需求,那就是…

3 Python开发工具:VSCode+插件

本文是 Python 系列教程第 3 篇,完整系列请查看 Python 专栏。 Visual Studio Code的安装非常简单,就不放这里增加文章篇幅了。 相比PyCharm,VSCode更加轻量,启动速度快。并且搭配Python插件就能实现和Pycharm一样的代码提示、高…

如何将平淡无奇的产品推向市场?借助ChatGPT,仅需3秒即可化身短视频创意策划大师,助你的产品一夜成名!

毫无趣味的产品要如何宣传?用ChatGPT,3秒钟成为创意短视频策划高手,让你的产品出圈!© 由 ZAKER 提供 最近,全红婵最爱的小乌龟火了。 制作小乌龟的某位义乌商家在接受采访时,表示自己有了甜蜜的烦恼…

力扣刷题(2)

寻找两个正序数组的中位数 寻找两个正序数组的中位数-力扣 思路: 合并两个正序数组找中位数 double findMedianSortedArrays(int* nums1, int nums1Size, int* nums2, int nums2Size) {int arr[nums1Size nums2Size];int n1 0, n2 0;int m 0;int q;//合并两个正序数组w…

Git 远程操作

1. 理解分布式版本控制系统 我们所说的⼯作区,暂存区,版本库等,都是在本地!也就是在笔记本或计算机上。⽽我们的 Git 其实是分布式版本控制系统.可以简单理解为,我们每个⼈的电脑上都是⼀个完整的版本库,这…

Java 中的抽象工厂模式:优雅地掌握对象创建

文章目录 一、概述三、抽象工厂设计模式的意图四、抽象工厂模式的详细解释及实际示例五、Java 中抽象工厂模式的编程示例六、抽象工厂模式类图七、Java 中何时使用抽象工厂模式八、抽象工厂模式 Java 教程九、抽象工厂模式的优点和权衡十、Java 中抽象工厂模式的实际应用十一、…

【Web UI自动化测试】Web UI自动化测试之框架篇(全网最全)

本文大纲截图: UnitTest框架: PyTest框架: 框架: 框架英文单词 framework,为解决一类事情的功能的集合。需要按照框架的规定(套路)去书写代码。 一、UnitTest框架介绍【文末分享自动化测试学…

使用canal增量同步ES索引库数据

Canal增量数据同步利器 Canal介绍 canal主要用途是基于 MySQL 数据库增量日志解析,并能提供增量数据订阅和消费,应用场景十分丰富。 github地址:https://github.com/alibaba/canal 版本下载地址:https://github.com/alibaba/c…

鸿蒙开发:深入浅出Stage模型(UIAbility组件)

🚀一、UIAbility组件 🔎1.概述 HarmonyOS中的Stage模型是一种基于UIAbility组件的应用程序架构。UIAbility是HarmonyOS系统中用于构建用户界面的基本组件之一。它负责处理应用程序界面的显示和交互。 在Stage模型中,每个应用程序都有一个或…

LLM —— 强化学习(RLHF-PPO和DPO)学习笔记

强化学习整体流程 智能体执行动作与环境进行交互,根据奖励R的反馈结果不断进行更新。 价值函数 奖励将会考虑两个方面的奖励,一个当下的奖励,一个是未来的奖励(为了防止陷入局部最优解)。 LLM强化学习 强化学习模型分…

CTF—杂项学习

1 文件操作隐写 1.1 文件类型识别 1.1.1 File命令 当文件没有后缀名或有后缀名而无法打开时,根据识别出的文件类型来修改后缀名即可正常打开文件,file是Linux下的文件识别命令。 file 文件名 使用场景:不知道后缀名,无法打开文件…

【STM32开发笔记】STM32H7S78-DK上的CoreMark移植和优化--兼记STM32上的printf重定向实现及常见问题解决

【STM32开发笔记】STM32H7S78-DK上的CoreMark移植和优化--兼记STM32上的printf重定向实现及常见问题解决 一、CoreMark简介二、创建CubeMX项目2.1 选择MCU2.2 配置CPU时钟2.3 配置串口功能2.4 配置LED引脚2.5 生成CMake项目 三、基础功能支持3.1 支持记录耗时3.2 支持printf输出…

SEO之网站结构优化(十三-网站地图)

** 初创企业搭建网站的朋友看1号文章;想学习云计算,怎么入门看2号文章谢谢支持: ** 1、我给不会敲代码又想搭建网站的人建议 2、“新手上云”能够为你开启探索云世界的第一步 博客:阿幸SEO~探索搜索排名之道 网站无论大小&…

京存分布式赋能EDA应用

合抱之木,生于毫末;九层之台,起于累土;千里之行,始于足下。——《老子德经第六十四章》 EDA(Electronic Design Automation 电子设计自动化)是利用计算机,完成对VLSI (V…

OpenCV绘图函数(8)填充凸多边形函数fillConvexPoly()的使用

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 填充一个凸多边形。 函数 cv::fillConvexPoly 绘制一个填充的凸多边形。这个函数比 fillPoly 函数快得多。它可以填充的不仅仅是凸多边形&#…

25届最近5年自动化考研院校分析

哈尔滨工程大学 目录 一、学校学院专业简介 二、考试科目指定教材 三、近5年考研分数情况 四、近5年招生录取情况 五、最新一年分数段图表 六、初试大纲复试大纲 七、学费&奖学金&就业方向 一、学校学院专业简介 二、考试科目指定教材 1、考试科目介绍 2、指定…

C++ | Leetcode C++题解之第377题组合总和IV

题目&#xff1a; 题解&#xff1a; class Solution { public:int combinationSum4(vector<int>& nums, int target) {vector<int> dp(target 1);dp[0] 1;for (int i 1; i < target; i) {for (int& num : nums) {if (num < i && dp[i - …

《JavaEE进阶》----4.<SpringMVC①简介、基本操作>

本篇博客讲解 MVC思想、及Spring MVC&#xff08;是对MVC思想的一种实现&#xff09;。 Spring MVC的基本操作、学习了六个注解 RestController注解 RequestMappering注解 RequestParam注解 RequestBody注解 PathVariable注解 RequestPart注解 MVC View(视图) 指在应⽤程序中…

四大名著改编的ip大作,一个巨亏2亿,一个狂赚20亿!选择决定成败!

最近讨论热度比较高的当属《红楼梦》和《西游记》了 胡玫导演的《红楼梦之金玉良缘》耗费了18年的心血&#xff0c;投资了2个多亿 却仅仅只有600万票房&#xff0c;还被网友调侃称“一黛不如一黛” 而由《西游记》改编的游戏《黑神话悟空》&#xff0c;研发10年投资6亿&…

【drools】Rulesengine构建及intelj配置

7.57.0.FinalRulesengineApplication 使用maven构建 intelj 打开文件资源管理器实在是太慢了所以直接把pom 扔到其主页识别为maven项目,自动下载maven包管理器 然后解析依赖: 给maven加一个代理 -DproxyHost=127.0.0.1 -DproxyPort=7890 还是卡主