AIRNet模型使用与代码分析(All-In-One Image Restoration Network)

AIRNet提出了一种较为简易的pipeline,以单一网络结构应对多种任务需求(不同类型,不同程度)。但在效果上看,ALL-In-One是不如One-By-One的,且本文方法的亮点是batch内选择patch进行对比学习。在与sota对比上,仅是Denoise任务精度占优,在Derain与Dehaze任务上,效果不如One-By-One的MPRNet方法。本博客对AIRNet的关键结构实现,loss实现,data_patch实现进行深入分析,并对模型进行推理使用。

其论文的详细可以阅读:https://blog.csdn.net/a486259/article/details/139559389?spm=1001.2014.3001.5501

项目地址:https://blog.csdn.net/a486259/article/details/139559389?spm=1001.2014.3001.5501

项目依赖:torch、mmcv-full
安装mmcv-full时,需要注意torch所对应的cuda版本,要与系统中的cuda版本一致。

1、模型结构

AirNet的网络结构如下所示,输入图像x交由CBDE提取到嵌入空间z,z与x输入到DGRN模块的DGG block中逐步优化,最终输出预测结果。
在这里插入图片描述
模型代码在net\model.py

from torch import nnfrom net.encoder import CBDE
from net.DGRN import DGRNclass AirNet(nn.Module):def __init__(self, opt):super(AirNet, self).__init__()# Encoderself.E = CBDE(opt)  #编码特征值# Restorerself.R = DGRN(opt) #特征解码def forward(self, x_query, x_key):if self.training:fea, logits, labels, inter = self.E(x_query, x_key)restored = self.R(x_query, inter)return restored, logits, labelselse:fea, inter = self.E(x_query, x_query)restored = self.R(x_query, inter)return restored

1.1 CBDE模块

CBDE模块的功能是在模块内进行对比学习,核心是MoCo. Moco论文地址:https://arxiv.org/pdf/1911.05722

class CBDE(nn.Module):def __init__(self, opt):super(CBDE, self).__init__()dim = 256# Encoderself.E = MoCo(base_encoder=ResEncoder, dim=dim, K=opt.batch_size * dim)def forward(self, x_query, x_key):if self.training:# degradation-aware represenetion learningfea, logits, labels, inter = self.E(x_query, x_key)return fea, logits, labels, interelse:# degradation-aware represenetion learningfea, inter = self.E(x_query, x_query)return fea, inter

ResEncoder所对应的网络结构如下所示
在这里插入图片描述

在AIRNet中的CBDE模块里的MoCo模块的关键代码如下,其在内部自行完成了正负样本的分配,最终输出logits, labels用于计算对比损失的loss。但其所优化的模块实际上是ResEncoder。MoCo模块只是在训练阶段起作用,在推理阶段是不起作用的。

class MoCo(nn.Module):def forward(self, im_q, im_k):"""Input:im_q: a batch of query imagesim_k: a batch of key imagesOutput:logits, targets"""if self.training:# compute query featuresembedding, q, inter = self.encoder_q(im_q)  # queries: NxCq = nn.functional.normalize(q, dim=1)# compute key featureswith torch.no_grad():  # no gradient to keysself._momentum_update_key_encoder()  # update the key encoder_, k, _ = self.encoder_k(im_k)  # keys: NxCk = nn.functional.normalize(k, dim=1)# compute logits# Einstein sum is more intuitive# positive logits: Nx1l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)# negative logits: NxKl_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])# logits: Nx(1+K)logits = torch.cat([l_pos, l_neg], dim=1)# apply temperaturelogits /= self.T# labels: positive key indicatorslabels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()# dequeue and enqueueself._dequeue_and_enqueue(k)return embedding, logits, labels, interelse:embedding, _, inter = self.encoder_q(im_q)return embedding, inter

1.2 DGRN模块

DGRN模块的实现代码如下所示,可以看到核心是DGG模块,其不断迭代优化输入图像。

class DGRN(nn.Module):def __init__(self, opt, conv=default_conv):super(DGRN, self).__init__()self.n_groups = 5n_blocks = 5n_feats = 64kernel_size = 3# head modulemodules_head = [conv(3, n_feats, kernel_size)]self.head = nn.Sequential(*modules_head)# bodymodules_body = [DGG(default_conv, n_feats, kernel_size, n_blocks) \for _ in range(self.n_groups)]modules_body.append(conv(n_feats, n_feats, kernel_size))self.body = nn.Sequential(*modules_body)# tailmodules_tail = [conv(n_feats, 3, kernel_size)]self.tail = nn.Sequential(*modules_tail)def forward(self, x, inter):# headx = self.head(x)# bodyres = xfor i in range(self.n_groups):res = self.body[i](res, inter)res = self.body[-1](res)res = res + x# tailx = self.tail(res)return x

在这里插入图片描述
DGG模块的结构示意如下所示
在这里插入图片描述
DGG代码实现如下所示,DGG模块内嵌DGB模块,DGB模块内嵌DGM模块,DGM模块内嵌SFT_layer模块与DCN_layer(可变性卷积)
在这里插入图片描述

2、loss实现

AIRNet中提到的loss如下所示,其中Lrec是L1 loss,Lcl是Moco模块实现的对比损失。
在这里插入图片描述
AIRNet的loss实现代码在train.py中,CE loss是针对CBDE(Moco模块)的输出进行计算,l1 loss是针对修复图像与清晰图片。

    # Network Constructionnet = AirNet(opt).cuda()net.train()# Optimizer and Lossoptimizer = optim.Adam(net.parameters(), lr=opt.lr)CE = nn.CrossEntropyLoss().cuda()l1 = nn.L1Loss().cuda()# Start trainingprint('Start training...')for epoch in range(opt.epochs):for ([clean_name, de_id], degrad_patch_1, degrad_patch_2, clean_patch_1, clean_patch_2) in tqdm(trainloader):degrad_patch_1, degrad_patch_2 = degrad_patch_1.cuda(), degrad_patch_2.cuda()clean_patch_1, clean_patch_2 = clean_patch_1.cuda(), clean_patch_2.cuda()optimizer.zero_grad()if epoch < opt.epochs_encoder:_, output, target, _ = net.E(x_query=degrad_patch_1, x_key=degrad_patch_2)contrast_loss = CE(output, target)loss = contrast_losselse:restored, output, target = net(x_query=degrad_patch_1, x_key=degrad_patch_2)contrast_loss = CE(output, target)l1_loss = l1(restored, clean_patch_1)loss = l1_loss + 0.1 * contrast_loss# backwardloss.backward()optimizer.step()

这里可以看出来,AIRNet首先是训练CBDE模块,最后才训练CBDE模块+DGRN模块。

3、TrainDataset

TrainDataset的实现代码在utils\dataset_utils.py中,首先找到__getitem__函数进行分析。以下代码为关键部分,删除了大部分在逻辑上重复的部分。TrainDataset一共支持5种数据类型,‘denoise_15’: 0, ‘denoise_25’: 1, ‘denoise_50’: 2,是不需要图像对的(在代码里面自动对图像添加噪声);‘derain’: 3, ‘dehaze’: 4是需要图像对进行训练的。

class TrainDataset(Dataset):def __init__(self, args):super(TrainDataset, self).__init__()self.args = argsself.rs_ids = []self.hazy_ids = []self.D = Degradation(args)self.de_temp = 0self.de_type = self.args.de_typeself.de_dict = {'denoise_15': 0, 'denoise_25': 1, 'denoise_50': 2, 'derain': 3, 'dehaze': 4}self._init_ids()self.crop_transform = Compose([ToPILImage(),RandomCrop(args.patch_size),])self.toTensor = ToTensor()def __getitem__(self, _):de_id = self.de_dict[self.de_type[self.de_temp]]if de_id < 3:if de_id == 0:clean_id = self.s15_ids[self.s15_counter]self.s15_counter = (self.s15_counter + 1) % self.num_cleanif self.s15_counter == 0:random.shuffle(self.s15_ids)# clean_id = random.randint(0, len(self.clean_ids) - 1)clean_img = crop_img(np.array(Image.open(clean_id).convert('RGB')), base=16)clean_patch_1, clean_patch_2 = self.crop_transform(clean_img), self.crop_transform(clean_img)clean_patch_1, clean_patch_2 = np.array(clean_patch_1), np.array(clean_patch_2)# clean_name = self.clean_ids[clean_id].split("/")[-1].split('.')[0]clean_name = clean_id.split("/")[-1].split('.')[0]clean_patch_1, clean_patch_2 = random_augmentation(clean_patch_1, clean_patch_2)degrad_patch_1, degrad_patch_2 = self.D.degrade(clean_patch_1, clean_patch_2, de_id)clean_patch_1, clean_patch_2 = self.toTensor(clean_patch_1), self.toTensor(clean_patch_2)degrad_patch_1, degrad_patch_2 = self.toTensor(degrad_patch_1), self.toTensor(degrad_patch_2)self.de_temp = (self.de_temp + 1) % len(self.de_type)if self.de_temp == 0:random.shuffle(self.de_type)return [clean_name, de_id], degrad_patch_1, degrad_patch_2, clean_patch_1, clean_patch_2

可以看出TrainDataset返回的数据有:degrad_patch_1, degrad_patch_2, clean_patch_1, clean_patch_2。

3.1 clean_patch分析

通过以下代码可以看出 clean_patch_1, clean_patch_2是来自于同一个图片,然后基于crop_transform变化,变成了2个对象

            clean_img = crop_img(np.array(Image.open(clean_id).convert('RGB')), base=16)clean_patch_1, clean_patch_2 = self.crop_transform(clean_img), self.crop_transform(clean_img)# clean_name = self.clean_ids[clean_id].split("/")[-1].split('.')[0]clean_name = clean_id.split("/")[-1].split('.')[0]clean_patch_1, clean_patch_2 = random_augmentation(clean_patch_1, clean_patch_2)

crop_transform的定义如下,可见是随机进行crop

crop_transform = Compose([ToPILImage(),RandomCrop(args.patch_size),])

random_augmentation的实现代码如下,可以看到只是随机对图像进行翻转或旋转,其目的是尽可能使随机crop得到clean_patch_1, clean_patch_2差异更大,避免裁剪出高度相似的patch。

def random_augmentation(*args):out = []flag_aug = random.randint(1, 7)for data in args:out.append(data_augmentation(data, flag_aug).copy())return out
def data_augmentation(image, mode):if mode == 0:# originalout = image.numpy()elif mode == 1:# flip up and downout = np.flipud(image)elif mode == 2:# rotate counterwise 90 degreeout = np.rot90(image)elif mode == 3:# rotate 90 degree and flip up and downout = np.rot90(image)out = np.flipud(out)elif mode == 4:# rotate 180 degreeout = np.rot90(image, k=2)elif mode == 5:# rotate 180 degree and flipout = np.rot90(image, k=2)out = np.flipud(out)elif mode == 6:# rotate 270 degreeout = np.rot90(image, k=3)elif mode == 7:# rotate 270 degree and flipout = np.rot90(image, k=3)out = np.flipud(out)else:raise Exception('Invalid choice of image transformation')return out

3.2 degrad_patch分析

degrad_patch来自于clean_patch,可以看到是通过D.degrade进行转换的。

degrad_patch_1, degrad_patch_2 = self.D.degrade(clean_patch_1, clean_patch_2, de_id)

D.degrade相关的代码如下,可以看到只是对图像添加噪声。难怪AIRNet在图像去噪上效果最好。

class Degradation(object):def __init__(self, args):super(Degradation, self).__init__()self.args = argsself.toTensor = ToTensor()self.crop_transform = Compose([ToPILImage(),RandomCrop(args.patch_size),])def _add_gaussian_noise(self, clean_patch, sigma):# noise = torch.randn(*(clean_patch.shape))# clean_patch = self.toTensor(clean_patch)noise = np.random.randn(*clean_patch.shape)noisy_patch = np.clip(clean_patch + noise * sigma, 0, 255).astype(np.uint8)# noisy_patch = torch.clamp(clean_patch + noise * sigma, 0, 255).type(torch.int32)return noisy_patch, clean_patchdef _degrade_by_type(self, clean_patch, degrade_type):if degrade_type == 0:# denoise sigma=15degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=15)elif degrade_type == 1:# denoise sigma=25degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=25)elif degrade_type == 2:# denoise sigma=50degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=50)return degraded_patch, clean_patchdef degrade(self, clean_patch_1, clean_patch_2, degrade_type=None):if degrade_type == None:degrade_type = random.randint(0, 3)else:degrade_type = degrade_typedegrad_patch_1, _ = self._degrade_by_type(clean_patch_1, degrade_type)degrad_patch_2, _ = self._degrade_by_type(clean_patch_2, degrade_type)return degrad_patch_1, degrad_patch_2

4、推理演示

项目中默认包含了All.pth,要单独任务的模型可以到预训练模型下载地址: Google Drive and Baidu Netdisk (password: cr7d). 下载模型放到 ckpt/ 目录下

打开demo.py,将 subprocess.check_output(['mkdir', '-p', opt.output_path]) 替换为os.makedirs(opt.output_path,exist_ok=True),避免在window上报错,具体修改如下所示
在这里插入图片描述

demo.py默认从test\demo目录下读取图片进行测试,可见原始图像如下
在这里插入图片描述
代码运行后的输出结果默认保存在 output\demo目录下,可见对于去雨,去雾,去噪声效果都比较好。
在这里插入图片描述
模型推理时间如下所示,可以看到对一张320, 480的图片,要0.54s
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/351343.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

React基础教程(07):条件渲染

1 条件渲染 使用条件渲染&#xff0c;结合TodoList案例&#xff0c;进行完善&#xff0c;实现以下功能&#xff1a; 当列表中的数据为空的时候&#xff0c;现实提示信息暂无待办事项当列表中存在数据的时候&#xff0c;提示信息消失 这里介绍三种实现方式。 注意这里的Empty是…

考研计组chap3存储系统

目录 一、存储器的基本概念 80 1.按照层次结构 2.按照各种分类 &#xff08;41&#xff09;存储介质 &#xff08;2&#xff09;存取方式 &#xff08;3&#xff09;内存是否可更改 &#xff08;4&#xff09;信息的可保存性 &#xff08;5&#xff09;读出之后data是否…

HTTP协议 快速入门

http概述 无状态性&#xff1a;HTTP是一个无状态协议&#xff0c;这意味着服务器不会在请求之间保存任何会话信息。每个请求都是独立的&#xff0c;服务器不会记住之前的请求。 请求-响应模型&#xff1a;HTTP通信是基于客户端发送请求和服务器返回响应的模型。客户端&#xf…

Python | 中心极限定理介绍及实现

统计学是数据科学项目的重要组成部分。每当我们想从数据集的样本中对数据集的总体进行任何推断&#xff0c;从数据集中收集信息&#xff0c;或者对数据集的参数进行任何假设时&#xff0c;我们都会使用统计工具。 中心极限定理 定义&#xff1a;中心极限定理&#xff0c;通俗…

Android低代码开发 - InputMenuPanelItem详解

我们知道MenuPanel是一个菜单面板容器&#xff0c;它里面可以放各式各样的菜单和菜单组。今天我们就来详细讲解输入菜单这个东西。 InputMenuPanelItem源码 package dora.widget.panel.menuimport android.content.Context import android.text.Editable import android.text…

系统集成项目管理工程师第9章思维导图发布

今天发布系统集成项目管理工程师新版第9章脑图的图片版

Nintex流程平台引入生成式人工智能,实现自动化革新

工作流自动化提供商Nintex宣布在其Nintex流程平台上推出一系列新的人工智能驱动改进。这些增强显著减少了文档化、管理和自动化业务流程所需的时间。这些新特性为Nintex流程平台不断扩展的人工智能能力增添了新的亮点。 Nintex首席产品官Niranjan Vijayaragavan表示&#xff1a…

使用React和GraphQL进行CRUD:完整教程与示例

在本教程中&#xff0c;我们将向您展示如何使用GraphQL和React实现简单的端到端CRUD操作。我们将介绍使用React Hooks读取和修改数据的简单示例。我们还将演示如何使用Apollo Client实现身份验证、错误处理、缓存和乐观UI。 什么是React&#xff1f; React是一个用于构建用户…

用python纯手写一个日历

一、代码 # 月份名称数组 months ["January", "February", "March", "April", "May", "June","July", "August", "September", "October", "November", &qu…

【Python/Pytorch - 网络模型】-- TV Loss损失函数

文章目录 文章目录 00 写在前面01 基于Pytorch版本的TV Loss代码02 论文下载 00 写在前面 在医学图像重建过程中&#xff0c;经常在代价方程中加入TV 正则项&#xff0c;该正则项作为去噪项&#xff0c;对于重建可以起到很大帮助作用。但是对于一些纹理细节要求较高的任务&am…

20.1 JSON-JSON接口以及在Go语言中使用JSON

1. 简介 JSON即JavaScript对象表示法(JavaScript Object Notation)&#xff0c;是一种用于存储和交换数据的格式&#xff0c;是一种可供人类阅读和理解的纯文本格式。 JSON既可以键值对的形式&#xff0c;也可以数组的形式&#xff0c;表示数据。 JSON最初是JavaScript的一个…

流媒体传输协议HTTP-FLV、WebSocket-FLV、HTTP-TS 和 WebSocket-TS的详细介绍、应用场景及对比

一、前言 HTTP-FLV、WS-FLV、HTTP-TS 和 WS-TS 是针对 FLV 和 TS 格式视频流的不同传输方式。它们通过不同的协议实现视频流的传输&#xff0c;以满足不同的应用场景和需求。接下来我们对这些流媒体传输协议进行剖析。 二、传输协议 1、HTTP-FLV 介绍&#xff1a;基于 HTTP…

【宠粉赠书】科研绘图神器:MATLAB科技绘图与数据分析

小智送书第二期~ 为了回馈粉丝们的厚爱&#xff0c;今天小智给大家送上一套科研绘图的必备书籍——MATLAB科技绘图与数据分析。下面我会详细给大家介绍这套图书&#xff0c;文末留有领取方式。 图书介绍 《MATLAB科技绘图与数据分析》是一本综合性强、内容丰富的书籍&#x…

实践分享:鸿蒙跨平台开发实例

先来理解什么是跨平台 提到跨平台&#xff0c;要先理解什么是“平台”&#xff0c;这里的平台&#xff0c;就是指应用程序的运行环境&#xff0c;例如操作系统&#xff0c;或者是Web浏览器&#xff0c;具体的像HarmonyOS、Android、iOS、或者浏览器&#xff0c;都可以叫做平台…

Python读取wps中的DISPIMG图片格式

需求&#xff1a; 读出excel的图片内容&#xff0c;这放在微软三件套是很容易的&#xff0c;但是由于wps的固有格式&#xff0c;会出现奇怪的问题&#xff0c;只能读出&#xff1a;类似于 DISPIMG(“ID_2B83F9717AE1XXXX920xxxx644C80DB1”,1) 【该DISPIMG函数只有wps才拥有】 …

阿里云服务器-Linux搭建fastDFS文件服务器

阿里云官网购买服务器&#xff0c;一般会有降价活动&#xff0c;这两天就发现有活动&#xff0c;99计划活动&#xff08;在活动期内&#xff0c;续费都是99元&#xff09; 阿里云官网-云服务器ECS 在这里&#xff0c;我购买了这台服务器&#xff0c;活动期内续费每年99元&…

javaweb 期末复习

1. JDBC数据库连接的实现逻辑与步骤以及JDBC连接配置&#xff08;单列模式&#xff09; public class JDBCUtil {// 这些换成自己的数据库 private static final String DB_URL "jdbc:mysql://localhost:3306/你的数据库名称";private static final String USER &q…

10分钟部署一个个人博客

关于vuepress这里没必要过多介绍&#xff0c;感兴趣的可以直接去官网了解&#xff0c;下面是官网首页地址截图 &#xff1a;https://v2.vuepress.vuejs.org/zh/ 透过这张图&#xff0c;我们也可以大致的对这个框架的特点有一定的认识&#xff0c;这就够了。其他的东西我们在使用…

vue3+ Element-Plus 点击勾选框往input中动态添加多个tag

实现效果&#xff1a; template&#xff1a; <!--产品白名单--><div class"con-item" v-if"current 0"><el-form-item label"平台名称"><div class"contaion" click"onclick"><!-- 生成的标签 …

WPF界面设计

1、使用C#-WPF实现抽屉效果-炫酷漂亮的侧边栏导航菜单-SplitViewMD主题重绘原生控件的美观效果-提供源码Demo下载 码源地址&#xff1a;https://download.csdn.net/download/Prince999999/89424685 2、使用C#-WPF实现抽屉效果-菜单导航功能实现&#xff0c;常规的管理系统应该…