【Block总结】掩码窗口自注意力 (M-WSA)

在这里插入图片描述

摘要

论文链接:https://arxiv.org/pdf/2404.07846
论文标题:Transformer-Based Blind-Spot Network for Self-Supervised Image Denoising
Masked Window-Based Self-Attention (M-WSA) 是一种新颖的自注意力机制,旨在解决传统自注意力方法在处理图像时的局限性,特别是在图像去噪和恢复任务中。M-WSA 通过引入掩码机制,确保在计算注意力时遵循盲点要求,从而避免信息泄露。

设计原理

  1. 窗口自注意力:M-WSA 基于窗口自注意力(Window Self-Attention, WSA)的概念,将输入图像划分为多个不重叠的窗口。在每个窗口内,计算自注意力以捕捉局部特征。这种方法的计算复杂度相对较低,适合处理高分辨率图像。

  2. 掩码机制:为了满足盲点要求,M-WSA 在计算注意力时应用了掩码。具体而言,掩码限制了每个像素只能关注其窗口内的特定像素,从而避免了对盲点信息的访问。这一设计确保了网络在去噪时不会泄露噪声信息。

  3. 扩张卷积模拟:M-WSA 的掩码设计模仿了扩张卷积的感受野,使得网络能够在保持计算效率的同时,捕捉到更大范围的上下文信息。这种方法有效地扩展了网络的感受野,增强了特征提取能力。
    在这里插入图片描述

优势

  • 高效性:通过限制注意力计算在窗口内,M-WSA 显著降低了计算复杂度,使其适用于大规模图像处理任务。

  • 信息保护:掩码机制确保了盲点信息不被泄露,从而提高了去噪效果,特别是在处理具有空间相关噪声的图像时。

  • 灵活性:M-WSA 可以与其他网络架构结合使用,增强其在各种视觉任务中的表现,尤其是在自我监督学习和图像恢复领域。

实验结果

在多个真实世界的图像去噪数据集上进行的实验表明,M-WSA 显著提高了去噪性能,超越了传统的卷积网络和其他自注意力机制。这一结果表明,M-WSA 在处理复杂噪声模式时具有良好的适应性和有效性。

代码

Masked Window-Based Self-Attention (M-WSA) 通过结合窗口自注意力和掩码机制,为图像去噪和恢复任务提供了一种有效的解决方案。其设计不仅提高了计算效率,还确保了信息的安全性,展示了在自我监督学习中的广泛应用潜力。代码:

import torch
import torch.nn as nn
from einops import rearrange
from torch import einsumdef to(x):return {'device': x.device, 'dtype': x.dtype}def expand_dim(t, dim, k):t = t.unsqueeze(dim=dim)expand_shape = [-1] * len(t.shape)expand_shape[dim] = kreturn t.expand(*expand_shape)def rel_to_abs(x):b, l, m = x.shaper = (m + 1) // 2col_pad = torch.zeros((b, l, 1), **to(x))x = torch.cat((x, col_pad), dim=2)flat_x = rearrange(x, 'b l c -> b (l c)')flat_pad = torch.zeros((b, m - l), **to(x))flat_x_padded = torch.cat((flat_x, flat_pad), dim=1)final_x = flat_x_padded.reshape(b, l + 1, m)final_x = final_x[:, :l, -r:]return final_xdef relative_logits_1d(q, rel_k):b, h, w, _ = q.shaper = (rel_k.shape[0] + 1) // 2logits = einsum('b x y d, r d -> b x y r', q, rel_k)logits = rearrange(logits, 'b x y r -> (b x) y r')logits = rel_to_abs(logits)logits = logits.reshape(b, h, w, r)logits = expand_dim(logits, dim=2, k=r)return logitsclass RelPosEmb(nn.Module):def __init__(self,block_size,rel_size,dim_head):super().__init__()height = width = rel_sizescale = dim_head ** -0.5self.block_size = block_sizeself.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)def forward(self, q):block = self.block_sizeq = rearrange(q, 'b (x y) c -> b x y c', x=block)rel_logits_w = relative_logits_1d(q, self.rel_width)rel_logits_w = rearrange(rel_logits_w, 'b x i y j-> b (x y) (i j)')q = rearrange(q, 'b x y d -> b y x d')rel_logits_h = relative_logits_1d(q, self.rel_height)rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')return rel_logits_w + rel_logits_hclass FixedPosEmb(nn.Module):def __init__(self, window_size, overlap_window_size):super().__init__()self.window_size = window_sizeself.overlap_window_size = overlap_window_sizeattention_mask_table = torch.zeros((window_size + overlap_window_size - 1),(window_size + overlap_window_size - 1))attention_mask_table[0::2, :] = float('-inf')attention_mask_table[:, 0::2] = float('-inf')attention_mask_table = attention_mask_table.view((window_size + overlap_window_size - 1) * (window_size + overlap_window_size - 1))# get pair-wise relative position index for each token inside the windowcoords_h = torch.arange(self.window_size)coords_w = torch.arange(self.window_size)coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Wwcoords_flatten_1 = torch.flatten(coords, 1)  # 2, Wh*Wwcoords_h = torch.arange(self.overlap_window_size)coords_w = torch.arange(self.overlap_window_size)coords = torch.stack(torch.meshgrid([coords_h, coords_w]))coords_flatten_2 = torch.flatten(coords, 1)relative_coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :]  # 2, Wh*Ww, Wh*Wwrelative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2relative_coords[:, :, 0] += self.overlap_window_size - 1  # shift to start from 0relative_coords[:, :, 1] += self.overlap_window_size - 1relative_coords[:, :, 0] *= self.window_size + self.overlap_window_size - 1relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Wwself.attention_mask = nn.Parameter(attention_mask_table[relative_position_index.view(-1)].view(1, self.window_size ** 2, self.overlap_window_size ** 2), requires_grad=False)def forward(self):return self.attention_maskclass DilatedOCA(nn.Module):def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias):super(DilatedOCA, self).__init__()self.num_spatial_heads = num_headsself.dim = dimself.window_size = window_sizeself.overlap_win_size = int(window_size * overlap_ratio) + window_sizeself.dim_head = dim_headself.inner_dim = self.dim_head * self.num_spatial_headsself.scale = self.dim_head ** -0.5self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size,padding=(self.overlap_win_size - window_size) // 2)self.qkv = nn.Conv2d(self.dim, self.inner_dim * 3, kernel_size=1, bias=bias)self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)self.rel_pos_emb = RelPosEmb(block_size=window_size,rel_size=window_size + (self.overlap_win_size - window_size),dim_head=self.dim_head)self.fixed_pos_emb = FixedPosEmb(window_size, self.overlap_win_size)def forward(self, x):b, c, h, w = x.shapeqkv = self.qkv(x)qs, ks, vs = qkv.chunk(3, dim=1)# spatial attentionqs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=self.window_size, p2=self.window_size)ks, vs = map(lambda t: self.unfold(t), (ks, vs))ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=self.inner_dim), (ks, vs))# print(f'qs.shape:{qs.shape}, ks.shape:{ks.shape}, vs.shape:{vs.shape}')# split headsqs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head=self.num_spatial_heads),(qs, ks, vs))# attentionqs = qs * self.scalespatial_attn = (qs @ ks.transpose(-2, -1))spatial_attn += self.rel_pos_emb(qs)spatial_attn += self.fixed_pos_emb()spatial_attn = spatial_attn.softmax(dim=-1)out = (spatial_attn @ vs)out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head=self.num_spatial_heads,h=h // self.window_size, w=w // self.window_size, p1=self.window_size, p2=self.window_size)# merge spatial and channelout = self.project_out(out)return outif __name__ == "__main__":dim = 64window_size = 8overlap_ratio = 0.5num_heads = 2dim_head = 16# 初始化 DilatedOCA 模块oca_attention = DilatedOCA(dim=dim,window_size=window_size,overlap_ratio=overlap_ratio,num_heads=num_heads,dim_head=dim_head,bias=True)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")oca_attention = oca_attention.to(device)print(oca_attention)x = torch.randn(1, 32, 640, 480).to(device)# 前向传播output = oca_attention(x)print("input张量形状:", x.shape)print("output张量形状:", output.shape)

DilatedOCA模块详解

代码结构

import torch
import torch.nn as nn
from einops import rearrange
  • 导入库:首先导入 PyTorch 和 einops 库。einops 用于简化张量的重排操作。

模块定义

class DilatedOCA(nn.Module):def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias):super(DilatedOCA, self).__init__()self.num_spatial_heads = num_headsself.dim = dimself.window_size = window_sizeself.overlap_win_size = int(window_size * overlap_ratio) + window_sizeself.dim_head = dim_headself.inner_dim = self.dim_head * self.num_spatial_headsself.scale = self.dim_head ** -0.5self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size,padding=(self.overlap_win_size - window_size) // 2)self.qkv = nn.Conv2d(self.dim, self.inner_dim * 3, kernel_size=1, bias=bias)self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)self.rel_pos_emb = RelPosEmb(block_size=window_size,rel_size=window_size + (self.overlap_win_size - window_size),dim_head=self.dim_head)self.fixed_pos_emb = FixedPosEmb(window_size, self.overlap_win_size)
  • 初始化方法__init__ 方法定义了模块的结构。

    • dim:输入特征的通道数。

    • window_size:窗口的大小,用于空间注意力计算。

    • overlap_ratio:重叠窗口的比例,决定了窗口之间的重叠程度。

    • num_heads:空间注意力的头数。

    • dim_head:每个头的维度。

  • 层的定义

    • self.unfold:用于将输入张量展开为重叠窗口的操作。

    • self.qkv:一个 1x1 的卷积层,用于生成查询(Q)、键(K)和值(V)三个特征图。

    • self.project_out:一个 1x1 的卷积层,用于将输出特征映射回原始通道数。

    • self.rel_pos_embself.fixed_pos_emb:用于位置编码的模块,增强模型对空间位置的感知。

前向传播

def forward(self, x):b, c, h, w = x.shapeqkv = self.qkv(x)qs, ks, vs = qkv.chunk(3, dim=1)# spatial attentionqs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=self.window_size, p2=self.window_size)ks, vs = map(lambda t: self.unfold(t), (ks, vs))ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=self.inner_dim), (ks, vs))# split headsqs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head=self.num_spatial_heads),(qs, ks, vs))# attentionqs = qs * self.scalespatial_attn = (qs @ ks.transpose(-2, -1))spatial_attn += self.rel_pos_emb(qs)spatial_attn += self.fixed_pos_emb()spatial_attn = spatial_attn.softmax(dim=-1)out = (spatial_attn @ vs)out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head=self.num_spatial_heads,h=h // self.window_size, w=w // self.window_size, p1=self.window_size, p2=self.window_size)# merge spatial and channelout = self.project_out(out)return out
  • 输入形状x 的形状为 (batch_size, channels, height, width),其中 b 是批量大小,c 是通道数,hw 是图像的高度和宽度。

  • 特征提取

    • qkv = self.qkv(x):通过 qkv 层生成 Q、K、V 特征图。

    • qs, ks, vs = qkv.chunk(3, dim=1):将 Q、K、V 特征图沿通道维度分离。

  • 空间注意力计算

    • qs 被重排为适合空间注意力计算的格式。

    • ksvs 通过 unfold 操作展开为重叠窗口。

  • 分头处理

    • 使用 einops.rearrange 将 Q、K、V 的形状调整为适合多头自注意力计算的格式。
  • 计算注意力

    • qs = qs * self.scale:对 Q 进行缩放以提高稳定性。

    • spatial_attn = (qs @ ks.transpose(-2, -1)):计算注意力分数。

    • spatial_attn += self.rel_pos_emb(qs)spatial_attn += self.fixed_pos_emb():添加位置编码以增强空间感知。

    • spatial_attn = spatial_attn.softmax(dim=-1):对注意力分数进行 softmax 归一化。

  • 输出计算

    • out = (spatial_attn @ vs):使用注意力权重对 V 进行加权求和,得到最终输出。
  • 重排输出

    • out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', ...):将输出重排回原始形状。
  • 最终投影

    • out = self.project_out(out):通过投影层将输出映射回原始通道数。

总结

DilatedOCA 模块结合了扩张卷积和空间注意力机制,通过重叠窗口的设计增强了对图像局部特征的捕捉能力。该模块在图像处理任务中具有广泛的应用潜力,尤其是在需要精细特征提取的场景中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/3019.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

卷积神经05-GAN对抗神经网络

卷积神经05-GAN对抗神经网络 使用Python3.9CUDA11.8Pytorch实现一个CNN优化版的对抗神经网络 简单的GAN图片生成 CNN优化后的图片生成 优化模型代码对比 0-核心逻辑脉络 1)Anacanda使用CUDAPytorch2)使用本地MNIST进行手写图片训练3)…

怎么在iPhone手机上使用便签进行记录?

宝子们,在这个快节奏的时代,灵感的火花总是一闪而过,待办事项也常常让人应接不暇。好在咱们的 iPhone手机便签超给力,能满足各种记录需求!今天就来给大家分享一下,如何在 iPhone 手机上巧用便签&#xff0c…

基于微信小程序的摄影竞赛系统设计与实现(LW+源码+讲解)

专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…

【从零开始使用系列】StyleGAN2:开源图像生成网络——环境搭建与基础使用篇(附大量测试图)

StyleGAN2 是英伟达团队 NVIDIA 提出的生成对抗网络(GAN)的一种改进版本。 它通过创新的网络架构,能够生成细节丰富、逼真的图像,特别在高频细节(如皮肤纹理、光照等)的表现上表现卓越。与传统 GAN 相比&am…

redis(2:数据结构)

1.String 2.key的层级格式 3.Hash 4.List 5.Set 6.SortedSet

LabVIEW 程序中的 R6025 错误

R6025错误 通常是 运行时库 错误,特别是与 C 运行时库 相关。这种错误通常会在程序运行时出现,尤其是在使用 C 编译的程序或依赖 C 运行时库的程序时。 ​ 可能的原因: 内存访问冲突: R6025 错误通常是由于程序在运行时访问无效内…

前端【2】html添加样式、CSS选择器

一、为html添加样式的三种方法 1、内部样式 2、外部样式 3、行内样式 二、css的使用--css选择器 1、css基本选择器 元素选择器 属性选择器 id选择器 class/类选择器 通配符选择器 2、群组选择器-多方面筛选 3、关系选择器 后代选择器【包含选择器】 子元素选择器…

【Elasticsearch】全文搜索与相关性排序

🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/?__c1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编…

【算法】枚举

枚举 普通枚举1.铺地毯2.回文日期3.扫雷 二进制枚举1.子集2.费解的开关3.Even Parity 顾名思义,就是把所有情况全都罗列出来,然后找出符合题目要求的那一个。因此,枚举是一种纯暴力的算法。一般情况下,枚举策略都是会超时的。此时…

51单片机——DS18B20温度传感器

由于DS18B20数字温度传感器是单总线接口,所以需要使用51单片机的一个IO口模拟单总线时序与DS18B20通信,将检测的环境温度读取出来 1、DS18B20模块电路 传感器接口的单总线管脚接至单片机P3.7IO口上 2、DS18B20介绍 2.1 DS18B20外观实物图 管脚1为GN…

云手机技术怎么实现的?

前言 随着亚矩阵云手机在跨境电商、海外社媒矩阵搭建、出海运营、海外广告投放、国内新媒体矩阵运营、品牌应用矩阵运营等领域内的普及和使用,云手机的理念已经被越来越多人所接受和认同。今天我们就一起来浅析一下,到底云手机的技术是怎么实现的&#…

HTML中link的用法

一点寒芒先到,随后,抢出如龙! 对于本人而言,这篇笔记内容有些扩展了,有些还未学到的也用上了,但是大概可以使用的明白,坚持下去,相信一定可以建设一个稳固的根基。 该文章为个人成…

闪豆多平台视频批量下载器

1. 视频链接获取与解析 首先,在哔哩哔哩网页中随意点击一个视频,比如你最近迷上了一个UP主的美食制作视频,想要下载下来慢慢学。点击视频后,复制视频页面的链接。复制完成后,不要急着关闭浏览器,因为接下来…

Vulnhub DC-8靶机攻击实战(一)

导语   Vulnhub DC-8靶机教程来了,好久没有更新打靶的教程了,这次我们在来更新一期关于Vulnhub DC-8的打靶训练,如下所示。 安装并且启动靶机 安装并且启动靶机,如下所示。 开始信息采集 进入到Kali中,通过如下的命令来查找到靶机的IP地址。 arp-scan -l根据上面的结…

JWT在线解密/解码 - 加菲工具

JWT在线解密/解码 首先进入加菲工具 选择 “JWT 在线解密/解码” https://www.orcc.online 或者直接进入JWT 在线解密/解码 https://www.orcc.online/tools/jwt 进入功能页面 使用 输入对应的jwt内容,点击解码按钮即可

换了城市ip属地会变吗?为什么换了城市IP属地不变

当我们跨越城市的界限,从一个地方迁移到另一个地方时,许多日常使用的网络服务和应用程序都会感知到这种变化,其中一个显著的现象就是IP属地的变化。IP属地,即IP地址所在的地理位置信息,它通常与互联网服务提供商&#…

如何在谷歌浏览器中设置自定义安全警告

随着网络环境的日益复杂,浏览器的安全问题也愈发引人关注。谷歌浏览器作为一款广泛使用的浏览器,其自定义安全警告功能为用户提供了更加个性化和安全的浏览体验。本文将详细介绍如何在谷歌浏览器中设置自定义安全警告,帮助用户更好地保护自己…

深度学习中的卷积和反卷积(四)——卷积和反卷积的梯度

本系列已完结,全部文章地址为: 深度学习中的卷积和反卷积(一)——卷积的介绍 深度学习中的卷积和反卷积(二)——反卷积的介绍 深度学习中的卷积和反卷积(三)——卷积和反卷积的计算 …

Mongodb相关内容

Mongodb相关内容 1、Windows平台安装2、Linux平台安装3、基本常用命令文档更新删除文档分页查询索引 pymongo操作 客户端下载:https://download.csdn.net/download/guoqingru0311/90273435 1、Windows平台安装 方式一: 方式2: 方式3&#…

RabbitMQ前置概念

文章目录 1.AMQP协议是什么?2.rabbitmq端口介绍3.消息队列的作用和使用场景4.rabbitmq工作原理5.整体架构核心概念6.使用7.消费者消息推送限制(work模型)8.fanout交换机9.Direct交换机10.Topic交换机(推荐)11.声明队列…