视频去噪网络BSVD的实现

前些天写了视频去噪网络BSVD论文的理解,详情请点击这里,这两个星期动手实践了一下,本篇就来记录一下这个模型的实现。

这个网络的独特之处在于,它的训练和推理在实现上有所差别。在训练阶段,其使用了TSM(Time Shift Module)结构,而在推理时则使用了BBB(Bidirectional Buffer Block)结构。训练时,网络是一个MIMO(多输入多输出)形式,而在推理时,则将其设计成了单输入、单输出的流式形式。推理时,由于网络中存在16个双向buffer,即BBB,因此,前16帧会输出空数据,16帧之后开始正常输出去噪视频帧,到视频序列结束后,还会继续输出16帧的去噪视频帧,也就是,流式推理整体存在16帧的延迟。这在一些对实时性要求不太高的应用中可以推广,但对于实时性要求严格,并且存储资源有限的应用中,就无法有效应用了。

下面,我们就通过对官方代码的理解,来聊一聊BSVD的实现。

官方代码地址:GitHub - ChenyangQiQi/BSVD: [ACM MM 2022] Real-time Streaming Video Denoising with Bidirectional Buffers

BSVD网络采用了两个UNet级联的方式。

1. 训练阶段的网络实现

在训练阶段,网络的实现如下:

class WNet(nn.Module):def __init__(self, chns=[32, 64, 128], mid_ch=3, shift_input=False, stage_num=2, in_ch=4, out_ch=3, norm='bn', act='relu', interm_ch=30, blind=False):# def __init__(self, chns=[32, 64, 128], mid_ch=3, shift_input=False, stage_num=2, in_ch=4, out_ch=3, norm='bn', act='relu', blind=False):super(WNet, self).__init__()self.stage_num = stage_numself.nets_list = nn.ModuleList()for i in np.arange(stage_num):if i == 0:stage_in_ch = in_chelse:stage_in_ch = mid_chif i == (stage_num-1):stage_out_ch = out_chelse:stage_out_ch = mid_ch# self.nets_list.append(DenBlock(chns=chns, out_ch=stage_out_ch, in_ch=stage_in_ch, shift_input=shift_input, norm=norm, act=act, interm_ch=interm_ch))if i == 0:self.nets_list.append(DenBlock(chns=chns, out_ch=stage_out_ch, in_ch=stage_in_ch, shift_input=shift_input, norm=norm, act=act, blind=blind, interm_ch=interm_ch))else:self.nets_list.append(DenBlock(chns=chns, out_ch=stage_out_ch,in_ch=stage_in_ch, shift_input=shift_input, norm=norm, act=act, interm_ch=interm_ch))# self.temp2 = DenBlock(chns=chns, in_ch=mid_ch, shift_input=shift_input)# Init weightsself.reset_params()@staticmethoddef weight_init(m):if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, nonlinearity='relu')def reset_params(self):for _, m in enumerate(self.modules()):self.weight_init(m)def forward(self, x, debug=False):# if debug: x_in = x# x = self.temp1(x)for i in np.arange(self.stage_num):if debug: x_temp1 = xx = self.nets_list[i](x)# if debug: x_temp2 = xreturn x

网络由两个DenBlock组成,每个DenBlock是一个UNet结构:


class DenBlock(nn.Module):""" Definition of the denosing block of FastDVDnet.Inputs of constructor:num_input_frames: int. number of input framesInputs of forward():xn: input frames of dim [N, C, H, W], (C=3 RGB)noise_map: array with noise map of dim [N, 1, H, W]"""def __init__(self, chns=[32, 64, 128], out_ch=3, in_ch=4, shift_input=False, norm='bn', bias=True,  act='relu', interm_ch=30, blind=False):# def __init__(self, chns=[32, 64, 128], out_ch=3, in_ch=4, shift_input=False, norm='bn', bias=True,  act='relu', blind=False):super(DenBlock, self).__init__()self.chs_lyr0, self.chs_lyr1, self.chs_lyr2 = chns# if stage2: in_ch=3if shift_input:self.inc = CvBlock(in_ch=in_ch, out_ch=self.chs_lyr0, norm=norm, bias=bias, act=act)else:self.inc = InputCvBlock(num_in_frames=1, out_ch=self.chs_lyr0, in_ch=in_ch, norm=norm, bias=bias, act=act, interm_ch=interm_ch, blind=blind)# num_in_frames=1, out_ch=self.chs_lyr0, in_ch=in_ch, norm=norm, bias=bias, act=act, blind=blind)self.downc0 = DownBlock(in_ch=self.chs_lyr0, out_ch=self.chs_lyr1, norm=norm, bias=bias, act=act)self.downc1 = DownBlock(in_ch=self.chs_lyr1, out_ch=self.chs_lyr2, norm=norm, bias=bias, act=act)self.upc2 = UpBlock(in_ch=self.chs_lyr2, out_ch=self.chs_lyr1, norm=norm, bias=bias,    act=act)self.upc1 = UpBlock(in_ch=self.chs_lyr1, out_ch=self.chs_lyr0, norm=norm, bias=bias,    act=act)self.outc = OutputCvBlock(in_ch=self.chs_lyr0, out_ch=out_ch, norm=norm, bias=bias,     act=act)self.reset_params()@staticmethoddef weight_init(m):if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, nonlinearity='relu')def reset_params(self):for _, m in enumerate(self.modules()):self.weight_init(m)def forward(self, in1):'''Args:inX: Tensor, [N, C, H, W] in the [0., 1.] rangenoise_map: Tensor [N, 1, H, W] in the [0., 1.] range'''# Input convolution blockx0 = self.inc(in1)# Downsamplingx1 = self.downc0(x0)x2 = self.downc1(x1)# Upsamplingx2 = self.upc2(x2)x1 = self.upc1(x1+x2)# Estimationx = self.outc(x0+x1)# Residualx[:, :3, :, :] = in1[:, :3, :, :] - x[:, :3, :, :]return x

这段代码与论文中的UNet结构相对应(见下图),包含一个输入层,两个下采样层,两个上采样层,一个输出层。

输入层没什么特别可说的,主要是两个Conv2d=>BN=>ReLU的组合;输出层也是常规实现,Con2d=>BN=>ReLU=>Con2d,需要注意的是,作者在实现过程中,BN层是没有使用的,是透传通过。

需要花心思理解的是下采样层和上采样层的实现,因为这两个模块在训练和推理过程中,是有所不同的。

两个模块的初始实现很简单,定义如下:

class DownBlock(nn.Module):'''Downscale + (Conv2d => BN => ReLU)*2'''def __init__(self, in_ch, out_ch, norm='bn', bias=True, act='relu'):super(DownBlock, self).__init__()norm_fn = get_norm_function(norm)act_fn = get_act_function(act)self.convblock = nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=3,padding=1, stride=2, bias=bias),norm_fn(out_ch),act_fn(inplace=True),CvBlock(out_ch, out_ch, norm=norm, bias=bias, act=act))def forward(self, x):return self.convblock(x)class UpBlock(nn.Module):'''(Conv2d => BN => ReLU)*2 + Upscale'''def __init__(self, in_ch, out_ch, norm='bn', bias=True, act='relu'):super(UpBlock, self).__init__()# norm_fn = get_norm_function(norm)self.convblock = nn.Sequential(CvBlock(in_ch, in_ch, norm=norm, bias=bias, act=act),nn.Conv2d(in_ch, out_ch*4, kernel_size=3, padding=1, bias=bias),nn.PixelShuffle(2))return self.convblock(x)

关键在于两者共同调用的子模块CvBlock的实现,在定义时,CvBlock被常规定义为:

class CvBlock(nn.Module):'''(Conv2d => BN => ReLU) x 2'''def __init__(self, in_ch, out_ch, norm='bn', bias=True, act='relu'):super(CvBlock, self).__init__()norm_fn = get_norm_function(norm)act_fn = get_act_function(act)self.c1 = nn.Conv2d(in_ch, out_ch, kernel_size=3,padding=1, bias=bias)self.b1 = norm_fn(out_ch)self.relu1 = act_fn(inplace=True)self.c2 = nn.Conv2d(out_ch, out_ch, kernel_size=3,padding=1, bias=bias)self.b2 = norm_fn(out_ch)self.relu2 = act_fn(inplace=True)def forward(self, x):x = self.c1(x)x = self.b1(x)x = self.relu1(x)x = self.c2(x)x = self.b2(x)x = self.relu2(x)return x

但接下来,上述定义中的c1和c2则被替换成了TSM实现:

其中,shift模块的核心实现代码如下,对输入的channels分别向左和向右移动了一定单位(fold)。

def shift(x, n_segment, shift_type, fold_div=3, stride=1, inplace=False):nt, c, h, w = x.size()n_batch = nt // n_segmentx = x.view(n_batch, n_segment, c, h, w)fold = c // fold_div # 32/8 = 4if inplace:# Due to some out of order error when performing parallel computing. # May need to write a CUDA kernel.print("WARNING: use inplace shift. it has bugs")raise NotImplementedError  else:out = torch.zeros_like(x)if not 'toFutureOnly' in shift_type:out[:, :-stride, :fold] = x[:, stride:, :fold]  # backward (left shift)out[:, stride:, fold: 2 * fold] = x[:, :-stride, fold: 2 * fold]  # forward (right shift)else:out[:, stride:, : 2 * fold] = x[:, :-stride, : 2 * fold] # right shift onlyout[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shiftreturn out.view(nt, c, h, w)

2. 推理阶段的网络实现

在推理阶段,网络实现就显得复杂一些了。大致的网络结构没变,但由于内部的TSM替换成了BBB, 因此没办法严格进行整体网络的加载,只能每一层单独加载训练出来的state_dict。并且,网络推理变成了流式推理,整个网络的定义显得比较凌乱,结构如下:

class BSVD(nn.Module):"""Bidirection-buffer based framework with pipeline-style inference"""def __init__(self, chns=[32, 64, 128], mid_ch=3, shift_input=False, in_ch=4, out_ch=3, norm='bn', act='relu', interm_ch=30, blind=False, pretrain_ckpt='./experiments/pretrained_ckpt/bsvd-64.pth'):super(BSVD, self).__init__()self.temp1 = DenBlock(chns=chns, out_ch=mid_ch, in_ch=in_ch,  shift_input=shift_input, norm=norm, act=act, blind=blind, interm_ch=interm_ch)self.temp2 = DenBlock(chns=chns, out_ch=out_ch, in_ch=mid_ch, shift_input=shift_input, norm=norm, act=act, blind=blind, interm_ch=interm_ch)self.shift_num = self.count_shift()# Init weightsself.reset_params()if pretrain_ckpt is not None:self.load(pretrain_ckpt)def reset(self):self.temp1.reset()self.temp2.reset()def load(self, path):ckpt = torch.load(path)print("load from %s"%path)ckpt_state = ckpt['params']# split the dict hereif 'module' in list(ckpt_state.keys())[0]:base_name = 'module.base_model.'else:base_name = 'base_model.'ckpt_state_1 = extract_dict(ckpt_state, string_name=base_name+'nets_list.0.')ckpt_state_2 = extract_dict(ckpt_state, string_name=base_name+'nets_list.1.')self.temp1.load_from(ckpt_state_1)self.temp2.load_from(ckpt_state_2)@staticmethoddef weight_init(m):if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, nonlinearity='relu')def reset_params(self):for _, m in enumerate(self.modules()):self.weight_init(m)def feedin_one_element(self, x):x   = self.temp1(x)x   = self.temp2(x)return xdef forward(self, input, noise_map=None):# N, F, C, H, W -> (N*F, C, H, W)if noise_map != None:input = torch.cat([input, noise_map], dim=2)N, F, C, H, W = input.shapeinput = input.reshape(N*F, C, H, W)base_out = self.streaming_forward(input)NF, C, H, W = base_out.shapebase_out = base_out.reshape(N, F, C, H, W)return base_outdef streaming_forward(self, input_seq):"""pipeline-style inferenceArgs:Noisy video streamReturns:Denoised video stream"""out_seq = []if isinstance(input_seq, torch.Tensor):n,c,h,w = input_seq.shapeinput_seq = [input_seq[i:i+1, ...] for i in np.arange(n)]assert type(input_seq) == list, "convert the input into a sequence"_,c,h,w = input_seq[0].shapewith torch.no_grad():for i, x in enumerate(input_seq):x_cuda = x.cuda()x_cuda = self.feedin_one_element(x_cuda)# if x_cuda is not None: x_cuda = x_cuda.cpu()if isinstance(x_cuda, torch.Tensor):out_seq.append(x_cuda)else:out_seq.append(x_cuda)end_out = self.feedin_one_element(None)out_seq.append(end_out)# end stagewhile 1:end_out = self.feedin_one_element(None)if len(out_seq) == (self.shift_num+len(input_seq)):breakout_seq.append(end_out)# number of temporal shift is 2, last element is 0# TODO fix init and end framesout_seq_clip = out_seq[self.shift_num:]self.reset()return torch.cat(out_seq_clip, dim=0)def count_shift(self):count = 0for name, module in self.named_modules():# print(type(module))if "BiBufferConv" in str(type(module)):count+=1return count

两个UNet的定义(DenBlock)大体上没发生变化,但下采样模块和上采样模块的定义发生了改变。

下采样层如下,原来带有TSM的CvBlock换成了MemCvBlock:

上采样模块也类似:

 

而MemCvBlock则调用了BBB模块,BBB模块的实现如下,这是整个算法的核心:

class BiBufferConv(nn.Module):def __init__(self,in_channels,out_channels,kernel_size,stride=1,padding=0,bias=True) -> None:super(BiBufferConv, self).__init__()self.op = ShiftConv(in_channels,out_channels,kernel_size,stride,padding,bias)self.out_channels = out_channelsself.left_fold_2fold = None# self.zero_tensor = Noneself.center = Nonedef reset(self):self.left_fold_2fold = Noneself.center = Nonedef forward(self, input_right, verbose=False):fold_div = 8if input_right is not None:self.n, self.c, self.h, self.w = input_right.size()self.fold = self.c//fold_div# Case1: In the start or end stage, the memory is emptyif self.center is None:self.center = input_right# if verbose:if input_right is not None:if self.left_fold_2fold is None:# In the start stage, the memory and left tensor is emptyself.left_fold_2fold = torch.zeros((self.n, self.fold, self.h, self.w), device=torch.device('cuda'))if verbose: print("%f+none+%f = none"%(torch.mean(self.left_fold_2fold), torch.mean(input_right)))else:# in the end stage, both feed in and memory are emptyif verbose: print("%f+none+none = none"%(torch.mean(self.left_fold_2fold)))# print("self.center is None")return None# Case2: Center is not None, but input_right is Noneelif input_right is None:# In the last procesing stage, center is 0output =  self.op(self.left_fold_2fold, self.center, torch.zeros((self.n, self.fold, self.h, self.w), device=torch.device('cuda')))if verbose: print("%f+%f+none = %f"%(torch.mean(self.left_fold_2fold), torch.mean(self.center), torch.mean(output)))else:output =  self.op(self.left_fold_2fold, self.center, input_right)if verbose: print("%f+%f+%f = %f"%(torch.mean(self.left_fold_2fold), torch.mean(self.center), torch.mean(input_right), torch.mean(output)))# if output == 57:# a = 1self.left_fold_2fold = self.center[:, self.fold:2*self.fold, :, :]self.center = input_rightreturn output

这样,通过BBB模块,就实现了16个双向Buffer的填充、更新和清空。

限于篇幅,先梳理出个大体的思路,实际上还有很多细节需要特别关注,留待下一篇来写吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/169456.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

IP地址SSL证书 IP证书

在许多企业用例中,公司需要SSL证书作为IP地址。公司使用IP地址通过Internet访问各种类型的应用程序。 公网IP地址的SSL证书: 内部IP(也称为私有IP)是IANA设置为保存的IPv4或IPv6地址,例如: RFC 1918范围内…

【JavaEE】CAS -- 多线程篇(7)

CAS 1. 什么是 CAS2. CAS 伪代码3. CAS 是怎么实现的4. CAS的应用4.1 实现原子类4.2 实现自旋锁 5. CAS 的 ABA 问题 1. 什么是 CAS CAS: 全称Compare and swap,字面意思:”比较并交换“能够比较和交换 某个寄存器中的值和内存中的值, 看是否相等, 如果相等, 则把另…

Java面试(JVM篇)——JVM 面试题合集 深入理解JVM虚拟机

关于什么是JVM? 作用: 运⾏并管理Java 源码⽂件所⽣成的Class⽂件,在不同的操作系统上安装不同的JVM ,从⽽实现了跨平台的保证。 ⼀般情况下,对于开发者⽽⾔,即使不熟悉JVM 的运⾏机制并不影响业务代码的…

lvs+keepalived: 高可用集群

lvskeepalived: 高可用集群 keepalived为lvs应运而生的高可用服务。lvs的调度器无法做高可用,于是keepalived软件。实现的是调度器的高可用。 但是:keepalived不是专门为集群服务的,也可以做其他服务器的高可用。 lvs的高可用集群&#xf…

防止消息丢失与消息重复——Kafka可靠性分析及优化实践

系列文章目录 上手第一关,手把手教你安装kafka与可视化工具kafka-eagle Kafka是什么,以及如何使用SpringBoot对接Kafka 架构必备能力——kafka的选型对比及应用场景 Kafka存取原理与实现分析,打破面试难关 防止消息丢失与消息重复——Kafka可…

Hadoop3教程(三十五):(生产调优篇)HDFS小文件优化与MR集群简单压测

文章目录 (168)HDFS小文件优化方法(169)MapReduce集群压测参考文献 (168)HDFS小文件优化方法 小文件的弊端,之前也讲过,一是大量占用NameNode的空间,二是会使得寻址速度…

Redis数据类型——list类型数据的扩展操作

1.list阻塞式数据获取 2.list类型数据业务场景

电脑软件:推荐一款非常强大的pdf阅读编辑软件

目录 一、软件简介 二、功能介绍 1、界面美观,打开速度快 2、可直接编辑pdf 3、非常强大好用的注释功能 4、很好用的页面组织和提取功能 5、PDF转word效果非常棒 6、强大的OCR功能 三、软件特色 四、软件下载 pdf是日常办公非常常见的文档格式,…

基于RM编译码的协作MIMO系统误码率matlab仿真,对比不同RM编译码参数

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 MATLAB2013b 3.部分核心程序 ...................................................................... [V1,N1,K1,I1] f…

05 MIT线性代数-转置,置换,向量空间Transposes, permutations, spaces

1. Permutations P: execute row exchanges becomes PA LU for any invertible A Permutations P identity matrix with reordered rows mn (n-1) ... (3) (2) (1) counts recordings, counts all nxn permuations 对于nxn矩阵存在着n!个置换矩阵 , 2. Transpose: 2.…

前端性能优化 - 虚拟滚动

一 需求背景 需求:在一个表格里面一次性渲染全部数据,不采用分页形式,每行数据都有Echart图插入。 问题:图表渲染卡顿 技术栈:Vue、Element UI 卡顿原因:页面渲染时大量的元素参与到了重排的动作中&#x…

代码随想录 Day26贪心算法01-上

目录 前言:贪心无套路 本质: 两个极端 贪心的小例子 贪心无套路!!! LeetCode T455 分发饼干 题目思路: 1.优先考虑胃口:大饼干喂饱大胃口 2.优先考虑饼干:小饼干先喂饱小胃口 前言:贪心无套路 本质: 局部最优去推导全局最优 两个极端 贪心算法的难度一般要么特别简单,要…

Mac Intellij Idea get/set方法快捷键

Control Retrun(回车键) Command n 参考: Mac Intellij Idea get/set方法快捷键-CSDN博客

2018年亚太杯APMCM数学建模大赛A题老年人平衡能力的实时训练模型求解全过程文档及程序

2018年亚太杯APMCM数学建模大赛 A题 老年人平衡能力的实时训练模型 原题再现 跌倒在老年人中很常见。跌倒可能会导致老年人出现许多并发症,因为他们的康复能力通常较差,因此副作用可能会使人衰弱,从而加速身体衰竭。此外,对跌倒…

ESP32C3 LuatOS TM1650①驱动测试

合宙TM1650驱动资料 TM1650.lua源码 引脚连接 TM1650ESP32C3SCLGPIO5SDAGPIO4 下载TM1650.lua源码,并以文件形式保存在项目文件夹中 驱动测试源码 --注意:因使用了sys.wait()所有api需要在协程中使用 -- 用法实例 PROJECT "ESP32C3_TM1650" VERSION …

数据结构:选择题+编程题(每日一练)

目录 选择题: 题一: 题二: 题三: 题四: 题五: 编程题: 题一:单值二叉树 思路一: 题二:二叉树的最大深度 思路一: 本人实力有限可能对…

KekeBlog项目实战后台模块(二)(已完结)

十一、后台模块-菜单列表 菜单指的是权限菜单,也就是一堆权限字符串 1. 查询菜单 1.1 接口分析 需要展示菜单列表,不需要分页。可以针对菜单名进行模糊查询。也可以针对菜单的状态进行查询。菜单要按照父菜单id和orderNum进行排序 请求方式 请求路径…

【QT开发(10)】QT 进程

文章目录 1.1 运行一个新进程1.2 QProcess 还可以对一些信号进行关联2 进程间通信2.1 使用共享内存实现进程通信2.2 演示 代码仓库参考 1.1 运行一个新进程 使用类 QProcess,允许将一个进程堪称一个顺序IO设备。 在Qt中,QProcess类是用于启动外部进程的…

Vue的MVVM实现原理

目录 前言 用法 代码和效果图 效果图 理解 高质量的使用 前言 MVVM是Model-View-ViewModel的缩写,是一种软件架构设计模式。Vue.js实现了这种设计模式,通过双向数据绑定和虚拟DOM技术,使得数据和视图能够快速响应彼此的变化。了解Vue的…

unity中方向的两种表示:欧拉角和四元数

欧拉角:简单来说就是你可以选择 0度~360度 的范围 四元数:在计算机图像学中,四元数用于物体的旋转,是一种复杂,但效率较高的旋转方式 Quaternion结构体代表一个四元数,包含一个标量和一个三维向量&#x…