自监督去噪:Noise2Self原理分析及实现 (Pytorch)

在这里插入图片描述

文章地址:https://arxiv.org/abs/1901.11365

代码地址: https://github.com/czbiohub-sf/noise2self

要点
  Noise2Self方法不需要信号先验信息、噪声估计信息和干净的训练数据。唯一的假设就是噪声在测量的不同维度上表现出的统计独立性,而真实信号表现出一定的相关性。Noiser2Self根据J-invariant提出了一种噪声校正的方案,可以应用到一系列的去噪方法之中,提高这些去噪方法的效果。

文章目录

      • 1. 方法原理
      • 2. 实验结果
        • 2.1 传统校正方法
        • 2.2 高斯噪声
        • 2.3 不同网络结构对比
      • 3. 代码实现
        • 3.1 J-invariant + 传统方法
        • 3.1 J-invariant + 神经网络
      • 4. 总结


1. 方法原理

如果所研究对象的空间的“潜在维度”远低于测量的维度,则可以隐式地学习该结构,对测量进行降噪,并在没有任何先验知识的情况下恢复信号,信号或噪声。

传统方法问题

  • 需要对噪声模式进行估计(如高斯噪声、结构性噪声),那么这些方法的效果就受限于对噪声模式的估计。
  • 需要对信号数据的结构有先验估计,但是这会限制去噪方法迁移到其他数据集。
  • 需要校准,因为平滑度、自相似性或矩阵的秩等超参数对去噪方法也会有影响

J-invariant 定义

假设 j ∈ J j \in J jJ, J J J是m维空间, 存在一个函数变换 f ( x ) J : R m ⇒ R m f(x)_J: R^m \Rightarrow R^m f(x)J:RmRm。如果这个变换过程不依赖于输入的 x J x_J xJ,那么称这个函数是具有J不变性质。

换个能看懂的说法:信号本身是相关的,假设噪声是互不相关的(条件独立的),那么我们用一个方法对这个噪声图片的部分数据进行处理,这个处理结果应该是和处理全部数据效果相同的,也就是使用部分维度信息达到恢复全局的效果。(需要强调的是我自己这里也没有理解特别透彻,如果有错误可以提出大家讨论)

假设 x x x(噪声图片) 是 y y y(干净图片)的无偏估计( E [ x ∣ y ] = y E[x|y] = y E[xy]=y), 噪声是整个域内是条件独立的,那么有:
E ∣ ∣ f ( x ) − x ∣ ∣ 2 2 = E ∣ ∣ f ( x ) − y ∣ ∣ 2 2 + E ∣ ∣ x − y ∣ ∣ 2 2 E||f(x) - x||_2^2 = E||f(x) - y||_2^2 + E||x - y||_2^2 E∣∣f(x)x22=E∣∣f(x)y22+E∣∣xy22

可以看到这里的无监督学习的损失等于 传统的监督学习的损失 加上噪声带来的偏差。

用J不变性描述一下 Noise2Noise就变为
如果现在有两个观测的噪声数据 x 1 = y + n 1 x_1 = y + n_1 x1=y+n1 , x 2 = y + n 2 x_2 = y + n_2 x2=y+n2
观测组合: x = ( x 1 , x 2 ) x = (x_1,x_2) x=(x1,x2)
信号组合 y = ( y , y ) ∈ R 2 m y = (y,y) \in R^{2m} y=(y,y)R2m
如果存在 J = { J 1 , J 2 } = { { 1 , . . . , m } , { m + 1 , . . . , 2 m } } J = \{J_1,J_2\} = \{\{1,...,m\},\{m+1,...,2m\}\} J={J1,J2}={{1,...,m},{m+1,...,2m}},那么有
f J ∗ ( x ) J 2 = E [ y ∣ x 1 ] f_{J}^*(x)_{J2} = E[y|x_1] fJ(x)J2=E[yx1]

就个人理解:J-不变性就是一个假设:如果噪声是条件独立的,那么监督去噪等价于无监督去噪加上一个噪声的偏差影响。

2. 实验结果

2.1 传统校正方法

首先将J不变性应用到 传统方法中:
传统的 “median filter”是将半径范围内所有像素的点都替换为中值
这里对比的是一种“donut filter”中值滤波方法:用中值替换除了中心像素的所有位置

那么“median filter”和“donut”甜甜圈模式的滤波器,其自监督的损失分别为
∣ ∣ g r ( x ) − x ∣ ∣ 2 ||g_r(x) - x||^2 ∣∣gr(x)x2

∣ ∣ f r ( x ) − x ∣ ∣ 2 ||f_r(x) - x||^2 ∣∣fr(x)x2

用图绘制出来:

从上图可以看出:median滤波器监督学习的损失随着半径的增加而线性增加,而donut滤波器在r = 3的时候其损失有一个最佳值。蓝色实线和蓝色虚线的垂直距离其实表征的是噪声带来的偏差,那么我们就发现了对于传统的滤波器,我们只能够更改输入来进行调整滤波效果,但是对于donut这类具有J-invariant性质的滤波器,我们可以通过一些原则来调整滤波效果(比如这里的距离r)

那么就可以给定一个比较通用的新滤波器形式了
f θ ( x ) J : = g θ ( 1 J . s ( x ) + 1 J c . x ) J f_{\theta}(x)_J := g_{\theta}(1_J . s(x) + 1_{Jc} . x)_J fθ(x)J:=gθ(1J.s(x)+1Jc.x)J

这里的 g θ g_{\theta} gθ表示传统的滤波其, s ( x ) s(x) s(x)表示将一些像素替换为周围其他像素的值/均值的一个操作。

个人理解:和Noise2Void那种盲点去噪的感觉相同,都是将输入的某些值进行替换,然后恢复那个点的信息。如果将这种方法应用到传统方法中可以帮我们找到最佳的滤波参数。

2.2 高斯噪声

2.3 不同网络结构对比

3. 代码实现

相关代码参考: https://github.com/czbiohub-sf/noise2self

3.1 J-invariant + 传统方法

这里以使用 J-invariant 到 中值滤波为例


加载相关库和数据

import sys
sys.path.append("..")
import numpy as np
import matplotlib.pyplot as plt
from skimage.morphology import disk
from skimage.filters import gaussian, median
from skimage import data, img_as_float, img_as_ubyte
from skimage.color import gray2rgb
from skimage.util import random_noise
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import mean_squared_error as mse
from util import plot_grid, plot_images, expand# 加载原始数据
plt.rc('figure', figsize = (5,5))
show = lambda x: plt.imshow(x, cmap=plt.cm.gray)
image = data.camera()
show(image)
plt.show()# 加噪原始数据
np.random.seed(3)
noisy_image = img_as_ubyte(random_noise(image, mode = 'gaussian', var=0.01))
show(noisy_image)
plt.show()

在这里插入图片描述

定义中值滤波和donut中值滤波方法(引入J-invariant)

def mask_center(x):x[len(x)//2,len(x)//2] = 0return xplot_images([1-disk(4), 1-mask_center(disk(4))])

在这里插入图片描述
滤波并进行对比

radii = range(1, 7)
mask_med = np.array([median(noisy_image, mask_center(disk(i))) for i in radii])
med = np.array([median(noisy_image, disk(i)) for i in radii])plt.figure(figsize=(18,6))
for i in range(1,7):plt.subplot(2,6,i)show(mask_med[i-1])plt.title("r={}".format(radii[i-1]))if i ==1:plt.ylabel("donut")for i in range(1,7):plt.subplot(2,6,6+i)show(med[i-1])if i ==1:plt.ylabel("median filter")plt.show()

在这里插入图片描述

统计损失及相关参考指标

def stats(im_list, noisy_img, img):img = img_as_float(img)noisy_img = img_as_float(noisy_img)im_list = [img_as_float(x) for x in im_list]loss = [mse(x, noisy_img) for x in im_list]mse_gt = [mse(x, img) for x in im_list]psnr_gt = [psnr(x, img) for x in im_list]return loss, mse_gt, psnr_gtloss_med, mse_med, psnr_med = stats(med, noisy_image, image)
loss_mask_med, mse_mask_med, psnr_mask_med = stats(mask_med, noisy_image, image)
opt = radii[np.argmin(loss_mask_med)]plt.figure(figsize=(7,5))plt.plot(radii, loss_mask_med, label = 'self-supervised, donut median', color = 'C0')
plt.plot(radii, loss_med, label = 'self-supervised, ordinary median', color = 'C1')plt.axvline(radii[np.argmin(loss_mask_med)], color='k', linestyle='--')
plt.title('Calibrating a Median Filter')plt.plot(radii, mse_mask_med, label = 'reconstruction error, donut median', color = 'C0', linestyle='--')
plt.plot(radii, mse_med, label = 'reconstruction error, ordinary median', color = 'C1', linestyle='--')
plt.ylabel('MSE')
plt.xlabel('Radius of Median Filter')plt.yticks([0.002, 0.012])
plt.ylim(0, 0.0143)
plt.legend(loc='center right')
plt.show()

在这里插入图片描述

加入J-invariant之后可以帮助我们找到最佳的滤波参数(此处r = 3)

3.1 J-invariant + 神经网络

加载库及数据

from util import show, plot_images, plot_tensors
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import Datasetmnist_train = MNIST(root='/data/mnist/', download = True,transform = transforms.Compose([transforms.ToTensor(),]), train = True)mnist_test = MNIST('/data/mnist/', download = True,transform = transforms.Compose([transforms.ToTensor(),]), train = False)

定义加噪方法

from torch import randn
def add_noise(img):return img + randn(img.size())*0.4class SyntheticNoiseDataset(Dataset):def __init__(self, data, mode='train'):self.mode = modeself.data = datadef __len__(self):return len(self.data)def __getitem__(self, index):img = self.data[index][0]return add_noise(img), img
noisy_mnist_train = SyntheticNoiseDataset(mnist_train, 'train')
noisy_mnist_test = SyntheticNoiseDataset(mnist_test, 'test')
noisy, clean = noisy_mnist_train[0]
plot_tensors([noisy[0], clean[0]], ['Noisy Image', 'Clean Image'])

在这里插入图片描述

加mask也就是加盲点,需要恢复的也是这些盲点的信息

class Masker():"""Object for masking and demasking"""def __init__(self, width=3, mode='zero', infer_single_pass=False, include_mask_as_input=False):self.grid_size = widthself.n_masks = width ** 2self.mode = modeself.infer_single_pass = infer_single_passself.include_mask_as_input = include_mask_as_inputdef mask(self, X, i):phasex = i % self.grid_sizephasey = (i // self.grid_size) % self.grid_sizemask = pixel_grid_mask(X[0, 0].shape, self.grid_size, phasex, phasey)mask = mask.to(X.device)mask_inv = torch.ones(mask.shape).to(X.device) - maskif self.mode == 'interpolate':masked = interpolate_mask(X, mask, mask_inv)elif self.mode == 'zero':masked = X * mask_invelse:raise NotImplementedErrorif self.include_mask_as_input:net_input = torch.cat((masked, mask.repeat(X.shape[0], 1, 1, 1)), dim=1)else:net_input = maskedreturn net_input, maskdef __len__(self):return self.n_masksdef infer_full_image(self, X, model):if self.infer_single_pass:if self.include_mask_as_input:net_input = torch.cat((X, torch.zeros(X[:, 0:1].shape).to(X.device)), dim=1)else:net_input = Xnet_output = model(net_input)return net_outputelse:net_input, mask = self.mask(X, 0)net_output = model(net_input)acc_tensor = torch.zeros(net_output.shape).cpu()for i in range(self.n_masks):net_input, mask = self.mask(X, i)net_output = model(net_input)acc_tensor = acc_tensor + (net_output * mask).cpu()return acc_tensordef pixel_grid_mask(shape, patch_size, phase_x, phase_y):A = torch.zeros(shape[-2:])for i in range(shape[-2]):for j in range(shape[-1]):if (i % patch_size == phase_x and j % patch_size == phase_y):A[i, j] = 1return torch.Tensor(A)def interpolate_mask(tensor, mask, mask_inv):device = tensor.devicemask = mask.to(device)kernel = np.array([[0.5, 1.0, 0.5], [1.0, 0.0, 1.0], (0.5, 1.0, 0.5)])kernel = kernel[np.newaxis, np.newaxis, :, :]kernel = torch.Tensor(kernel).to(device)kernel = kernel / kernel.sum()filtered_tensor = torch.nn.functional.conv2d(tensor, kernel, stride=1, padding=1)return filtered_tensor * mask + tensor * mask_invmasker = Masker(width = 4, mode='interpolate')
net_input, mask = masker.mask(noisy.unsqueeze(0), 0)
plot_tensors([mask, noisy[0], net_input[0], net_input[0] - noisy[0]],["Mask", "Noisy Image", "Neural Net Input", "Difference"])

在这里插入图片描述

加载网络模型和进行训练

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import MSELoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm
from models.modules import ConvBlockclass BabyUnet(nn.Module):def __init__(self, n_channel_in=1, n_channel_out=1, width=16):super(BabyUnet, self).__init__()self.pool1 = nn.MaxPool2d(kernel_size=2)self.pool2 = nn.MaxPool2d(kernel_size=2)self.up1 = lambda x: F.interpolate(x, mode='bilinear', scale_factor=2, align_corners=False)self.up2 = lambda x: F.interpolate(x, mode='bilinear', scale_factor=2, align_corners=False)self.conv1 = ConvBlock(n_channel_in, width)self.conv2 = ConvBlock(width, 2*width)self.conv3 = ConvBlock(2*width, 2*width)self.conv4 = ConvBlock(4*width, 2*width)self.conv5 = ConvBlock(3*width, width)self.conv6 = nn.Conv2d(width, n_channel_out, 1)def forward(self, x):c1 = self.conv1(x)x = self.pool1(c1)c2 = self.conv2(x)x = self.pool2(c2)x = self.conv3(x)x = self.up1(x)x = torch.cat([x, c2], 1)x = self.conv4(x)x = self.up2(x)x = torch.cat([x, c1], 1)x = self.conv5(x)x = self.conv6(x)return x
model = BabyUnet()
loss_function = MSELoss()
optimizer = Adam(model.parameters(), lr=0.001)data_loader = DataLoader(noisy_mnist_train, batch_size=32, shuffle=True)pbar = tqdm(data_loader)for i, batch in enumerate(pbar):noisy_images, clean_images = batchnet_input, mask = masker.mask(noisy_images, i)net_output = model(net_input)loss = loss_function(net_output*mask, noisy_images*mask)optimizer.zero_grad()loss.backward()optimizer.step()pbar.set_description("Iter:{},loss:{}".format(i,loss.item()))# if i % 10 == 0:#     print("Loss (", i, "): \t", round(loss.item(), 4))if i == 100:break

测试训练效果

test_data_loader = DataLoader(noisy_mnist_test,batch_size=32,shuffle=False,num_workers=3)
i, test_batch = next(enumerate(test_data_loader))
noisy, clean = test_batch
simple_output = model(noisy)
invariant_output = masker.infer_full_image(noisy, model)
idx = 3
plot_tensors([clean[idx], noisy[idx], simple_output[idx], invariant_output[idx]],["Ground Truth", "Noisy Image", "Single Pass Inference", "J-Invariant Inference"])

在这里插入图片描述

盲点网络训练后使用不同的输入(加盲点或者不加)得到的效果有些许差别,但是整体的去噪效果还可以。


4. 总结

  1. 引入J-invariant的概念到去噪工作之中,通过测试对比发现这种方法的自监督比传统方法有更好的效果,可以帮助传统方法寻找最佳的调整参数
  2. J-invariant的思路可以应用到传统去噪方法中或者先前的无监督、自监督学习工作之中,提高效果。(对比了Noise2Noiser和Noiser2Void方法)
  3. 和Noise2Void有异曲同工之妙,分析原理都是使用盲点网络的思想对输入数据进行mask,然后使用网络恢复这些盲点位置的信息。所以也存在和盲点网络相同的问题
    • 损失了盲点位置的信息
    • 盲点网络的假设:噪声是条件不相关的,信号是相关的;对于结构性的噪声的效果会较差。
    • 噪声零均值假设等假设限制了该方法应用到实际数据之中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/75356.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MATLAB /Simulink 快速开发STM32(使用st官方工具 STM32-MAT/TARGET),以及开发过程

配置好环境以后就是开发: stm32cube配置芯片,打开matlab添加ioc文件,写处理逻辑,生成代码,下载到板子中去。 配置需要注意事项: STM32CUBEMAX6.5.0 MABLAB2022BkeilV5.2 Matlab生成的代码CTRLB 其中关键的…

ClickHouse的安装启动

安装步骤 1.关闭防火墙 2.修改资源限制配置文件 2.1 路径:/etc/security/limits.conf 在末尾添加: * soft nofile 65536 #任何用户可以打开的最大的文件描述符数量,默认1024 这里的设置会限制tcp连接数 * hard nofile 65536 * soft nproc…

逃离城市热浪,寻觅25℃的夏天

“入伏”后,夏日里的热浪撩动着我们那颗躁动自由的心,趁着暑假走出巨大的城市“蒸笼”吧,甩掉高温和闷热,寻找避暑纳凉的最佳旅行地,感受不一样的夏日清凉感~ 在酷暑中,隐藏着很多不为人知的清凉打卡胜地&…

信必优行业服务能力-中国头部综合性证券公司

近期召开的国家高层会议提出 “要活跃资本市场,提振投资者信心”,明确了下一阶段资本市场发展新任务、新要求,资本市场有望呈现新气象、新风貌。各证券公司积极响应,全力推进资本市场回暖;同时各公司也借此东风修炼内功…

(7.28-8.3)【大数据新闻速递】《数字孪生工业软件白皮书》、《中国绿色算力发展研究报告》发布;华为ChatGPT要来了

【数字孪生工业软件白皮书(2023)】 近日,第七届数字孪生与智能制造服务学术会议成功举行,2023《数字孪生工业软件白皮书》在会上正式发布。《白皮书》在《Digital Twin》国际期刊专家顾问委员会指导下,由国家重点研发计…

<C语言> 预处理和宏

1.预定义符号 __FILE__ //进行编译的源文件 __LINE__ //文件当前的行号 __DATE__ //文件被编译的日期 __TIME__ //文件被编译的时间 __STDC__ //如果编译器遵循ANSI C,其值为1,否则未定义这些预定义符号都是C语言内置的。 举个例子&…

selenium 截屏

当前环境: Windows 10 Python 3.7 selenium 3.141.0 Google Chrome 115.0.5790.110 (64 位) from selenium import webdriver import base64if __name__ __main__:#driver webdriver.Chrome()driver.get(https://www.baidu.com/)# 1.…

M5ATOMS3基础04给ROS2发一个问候(micro-ROS)

参考以往部分历程: 1. esp32与ros2的欢乐启程 2021 2. micro-ROS之esp32与ros2资料(freertos) 2021 3. esp32发布机器人电池电压到ros2(micro-rosCoCube) 2022 4. CoCube和Micro-ROS简单案例演示 2022 不需要僵化的…

JavaScript数据结构与算法-初始栈结构

文章目录 一、初始栈结构1.1 特性1.2 注意事项 二、栈结构的封装2.1 封装简单栈结构2.2 利用栈将十进制转二进制 一、初始栈结构 1.1 特性 类似于汉诺塔,后进先出,每次只能操作栈顶的元素。关键词:压栈、退栈 简单示意图: 1.…

windows下tomcat无故宕机,检测http或https服务,并自动重启Tomcat服务

一、问题描述及解决原理 把项目发布到windows服务器中,如tomcat工程不稳定,会有无故宕机的问题。如果通过程序无法解决,并且重启tomcat服务能够生效的话,可以做一个自动检测并重启的脚本。 脚本通过检测tomcat对应的工程链接&…

Flask学习笔记_异步论坛(四)

Flask学习笔记_异步论坛(四) 1.配置和数据库链接1.exts.py里面实例化sqlalchemy数据库2.config.py配置app和数据库信息3.app.py导入exts和config并初始化到app上 2.创建用户模型并映射到数据库1.models/auth.py创建用户模型2.app.py导入模型并用flask-mi…

RabbitMQ 教程 | 第4章 RabbitMQ 进阶

👨🏻‍💻 热爱摄影的程序员 👨🏻‍🎨 喜欢编码的设计师 🧕🏻 擅长设计的剪辑师 🧑🏻‍🏫 一位高冷无情的编码爱好者 大家好,我是 DevO…

HCIA云计算 V5.0题库

云计算,这是近几年听得最多词了,云计算对于网络的发展帮助非常大,它自身所产生的价值是不可估量的!所以云计算的岗位对于很多IT公司来说,都是有一定地位的。华为认证云计算面向的对象很简单就是对云计算技术感兴趣的人…

【Spring】(四)Bean 的作用域和生命周期

文章目录 前言一、Bean 的作用域1.1 被修改的 Bean 案例1.2 作用域的定义1.3 Bean 的六种作用域1.4 Bean 作用域的设置 二、Spring 的执行流程 和 Bean 的生命周期2.1 Spring 的执行流程2.2 Bean 的生命周期2.3 Bean 生命周期的演示 前言 Bean 是 Spring 框架中的一个核心概念…

iphone内存不足导致白苹果?可以使用这2种办法解决!

因为iPhone内存不足没及时清理导致打开任何软件闪退,这时很多小伙伴会重启手机来解决闪退问题,但就会出现白苹果问题,无法正常进入手机系统、实现任何操作的一种状态。 内存不足导致iPhone白苹果的问题很常见,可以说是苹果最常见…

排序进行曲-v4.0

文章目录 小程一言快速排序步骤详细解释具体步骤 举例总结 复杂度分析时间复杂度分析:空间复杂度分析:注意 应用场景总结 实际举例结果总结 代码实现结果解释 小程一言 这篇文章是在排序进行曲3.0之后的续讲, 这篇文章主要是对快速排序进行细…

nodejs中的path.json和path.resolve的区别

nodejs中的path.json和path.resolve的区别 我们有多少次在 Node.js 项目中遇到过path.join()和path.resolve()却没有真正理解它们之间的区别?本文就讲解一下这两者的区别。 重要术语 首先我们先来看看几个术语,便于后续我们掌握这两者的差异。 字符串…

libcurl开源的、跨平台的网络传输库,用于在程序中实现数据传输功能的编译

文章目录 前言1、libcurl关键特点和功能2、没有使用openssl以及libssh2编译libcurl的文件和使用openssl和libssh2编译3、libcurl网络库的下载4、libcurl网络库的编译4.1、直接使用cmake编译,不使用 OpenSSL 和 libssh2库编译的出来的libcurl库4.2、使用 OpenSSL 和 …

peerDependency到底是什么

peerDependency到底是什么 正常开发中,我们经常接触到的是 package.json 中的 dependencies 和 devDependencies, 本文不对上面两个进行细节分析,让我们来看看 peerDependencies 是什么? 在 NPM v7 中,默认安装 peerDependencies…

虹科案例|如何分析设备故障时间和次数,打破生产瓶颈?

虹科设备绩效管理系统 保障生产设备的稳定性和可靠性 生产设备的稳定性和可靠性是保证企业正常生产的重要条件之一,设备故障的频发严重影响企业的正常生产,那么如何分析设备故障时间和次数,查找设备故障原因,协助企业打破生产瓶…