论文解读:DiAD之SG网络

目录

  • 一、SG网络功能介绍
  • 二、SG网络代码实现

一、SG网络功能介绍

DiAD论文最主要的创新点就是使用SG网络解决多类别异常检测中的语义信息丢失问题,那么它是怎么实现的保留原始图像语义信息的同时重建异常区域?

与稳定扩散去噪网络的连接: SG网络被设计为与稳定扩散(Stable Diffusion, SD)去噪网络相连接。SD去噪网络本身具有强大的图像生成能力,但可能无法在多类异常检测任务中保持图像的语义信息一致性。SG网络通过引入语义引导机制,使得在重构异常区域时能够参考并保留原始图像的语义上下文。整个框架图中,SG网络与去噪网络的连接如下图所示。
在这里插入图片描述在这里插入图片描述
这是论文给出的最终输出,我认为图中圈出来的地方有问题,应该改为SG网络的编码器才对。

语义一致性保持: SG网络在重构过程中,通过在不同尺度下处理噪声,并利用空间感知特征融合(Spatial-aware Feature Fusion, SFF)块融合特征,确保重建过程中保留语义信息。这样,即使在重构异常区域时,也能使修复后的区域与原始图像的语义上下文保持一致。
多尺度特征融合: SFF块将高尺度的语义信息集成到低尺度中,使得在保留原始正常样本信息的同时,能够处理大规模异常区域的重建。这种机制有助于在处理需要广泛重构的区域时,最大化重构的准确性,同时保持图像的语义一致性。从下图中可以看到,特征融合模块还是很好理解的。

在这里插入图片描述

与预训练特征提取器的结合: SG网络还与特征空间中的预训练特征提取器相结合。预训练特征提取器能够处理输入图像和重建图像,并在不同尺度上提取特征。通过比较这些特征,系统能够生成异常图(anomaly maps),这些图显示了图像中可能存在的异常区域,并给出了异常得分或置信度。这一步骤进一步验证了SG网络在保留语义信息方面的有效性。
避免类别错误: 相比于传统的扩散模型(如DDPM),SG网络通过引入类别条件解决了在多类异常检测任务中可能出现的类别错误问题。LDM虽然通过交叉注意力引入了条件约束,但在随机高斯噪声下去噪时仍可能丢失语义信息。SG网络则通过其语义引导机制,有效地避免了这一问题。

二、SG网络代码实现

这部分代码大概有300行

class SemanticGuidedNetwork(nn.Module):def __init__(self,image_size,in_channels,model_channels,hint_channels,num_res_blocks,attention_resolutions,dropout=0,channel_mult=(1, 2, 4, 8),conv_resample=True,dims=2,use_checkpoint=False,use_fp16=False,num_heads=-1,num_head_channels=-1,num_heads_upsample=-1,use_scale_shift_norm=False,resblock_updown=False,use_new_attention_order=False,use_spatial_transformer=False,  # custom transformer supporttransformer_depth=1,  # custom transformer supportcontext_dim=None,  # custom transformer supportn_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq modellegacy=True,disable_self_attentions=None,num_attention_blocks=None,disable_middle_self_attn=False,use_linear_in_transformer=False,):super().__init__()if use_spatial_transformer:assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'if context_dim is not None:assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'from omegaconf.listconfig import ListConfigif type(context_dim) == ListConfig:context_dim = list(context_dim)if num_heads_upsample == -1:num_heads_upsample = num_headsif num_heads == -1:assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'if num_head_channels == -1:assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'self.dims = dimsself.image_size = image_sizeself.in_channels = in_channelsself.model_channels = model_channelsif isinstance(num_res_blocks, int):self.num_res_blocks = len(channel_mult) * [num_res_blocks]else:if len(num_res_blocks) != len(channel_mult):raise ValueError("provide num_res_blocks either as an int (globally constant) or ""as a list/tuple (per-level) with the same length as channel_mult")self.num_res_blocks = num_res_blocksif disable_self_attentions is not None:# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or notassert len(disable_self_attentions) == len(channel_mult)if num_attention_blocks is not None:assert len(num_attention_blocks) == len(self.num_res_blocks)assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "f"This option has LESS priority than attention_resolutions {attention_resolutions}, "f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "f"attention will still not be set.")self.attention_resolutions = attention_resolutionsself.dropout = dropoutself.channel_mult = channel_multself.conv_resample = conv_resampleself.use_checkpoint = use_checkpointself.dtype = th.float16 if use_fp16 else th.float32self.num_heads = num_headsself.num_head_channels = num_head_channelsself.num_heads_upsample = num_heads_upsampleself.predict_codebook_ids = n_embed is not Nonetime_embed_dim = model_channels * 4self.time_embed = nn.Sequential(linear(model_channels, time_embed_dim),nn.SiLU(),linear(time_embed_dim, time_embed_dim),)self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))])self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])self.input_hint_block = TimestepEmbedSequential(conv_nd(dims, hint_channels, 16, 3, padding=1),nn.SiLU(),conv_nd(dims, 16, 16, 3, padding=1),nn.SiLU(),conv_nd(dims, 16, 32, 3, padding=1, stride=2),nn.SiLU(),conv_nd(dims, 32, 32, 3, padding=1),nn.SiLU(),conv_nd(dims, 32, 96, 3, padding=1, stride=2),nn.SiLU(),conv_nd(dims, 96, 96, 3, padding=1),nn.SiLU(),conv_nd(dims, 96, 256, 3, padding=1, stride=2),nn.SiLU(),zero_module(conv_nd(dims, 256, model_channels, 3, padding=1)))self._feature_size = model_channelsinput_block_chans = [model_channels]ch = model_channelsds = 1for level, mult in enumerate(channel_mult):for nr in range(self.num_res_blocks[level]):layers = [ResBlock(ch,time_embed_dim,dropout,out_channels=mult * model_channels,dims=dims,use_checkpoint=use_checkpoint,use_scale_shift_norm=use_scale_shift_norm,)]ch = mult * model_channelsif ds in attention_resolutions:if num_head_channels == -1:dim_head = ch // num_headselse:num_heads = ch // num_head_channelsdim_head = num_head_channelsif legacy:# num_heads = 1dim_head = ch // num_heads if use_spatial_transformer else num_head_channelsif exists(disable_self_attentions):disabled_sa = disable_self_attentions[level]else:disabled_sa = Falseif not exists(num_attention_blocks) or nr < num_attention_blocks[level]:layers.append(AttentionBlock(ch,use_checkpoint=use_checkpoint,num_heads=num_heads,num_head_channels=dim_head,use_new_attention_order=use_new_attention_order,) if not use_spatial_transformer else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,use_checkpoint=use_checkpoint))self.input_blocks.append(TimestepEmbedSequential(*layers))self.zero_convs.append(self.make_zero_conv(ch))self._feature_size += chinput_block_chans.append(ch)if level != len(channel_mult) - 1:out_ch = chself.input_blocks.append(TimestepEmbedSequential(ResBlock(ch,time_embed_dim,dropout,out_channels=out_ch,dims=dims,use_checkpoint=use_checkpoint,use_scale_shift_norm=use_scale_shift_norm,down=True,)if resblock_updownelse Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)))ch = out_chinput_block_chans.append(ch)self.zero_convs.append(self.make_zero_conv(ch))ds *= 2self._feature_size += chif num_head_channels == -1:dim_head = ch // num_headselse:num_heads = ch // num_head_channelsdim_head = num_head_channelsif legacy:# num_heads = 1dim_head = ch // num_heads if use_spatial_transformer else num_head_channelsself.middle_block = TimestepEmbedSequential(ResBlock(ch,time_embed_dim,dropout,dims=dims,use_checkpoint=use_checkpoint,use_scale_shift_norm=use_scale_shift_norm,),AttentionBlock(ch,use_checkpoint=use_checkpoint,num_heads=num_heads,num_head_channels=dim_head,use_new_attention_order=use_new_attention_order,) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attnch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,use_checkpoint=use_checkpoint),ResBlock(ch,time_embed_dim,dropout,dims=dims,use_checkpoint=use_checkpoint,use_scale_shift_norm=use_scale_shift_norm,),)self.middle_block_out = self.make_zero_conv(ch)self._feature_size += ch#SFF Blockself.down11 = nn.Sequential(zero_module(nn.Conv2d(640, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down12 = nn.Sequential(zero_module(nn.Conv2d(640, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down13 = nn.Sequential(zero_module(nn.Conv2d(640, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down21 = nn.Sequential(zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down22 = nn.Sequential(zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down23 = nn.Sequential(zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down31 = nn.Sequential(zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down32 = nn.Sequential(zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down33 = nn.Sequential(zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.silu = nn.SiLU()def make_zero_conv(self, channels):return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))def forward(self, x, hint, timesteps, context, **kwargs):t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)emb = self.time_embed(t_emb)guided_hint = self.input_hint_block(hint, emb, context)outs = []h = x.type(self.dtype)for module, zero_conv in zip(self.input_blocks, self.zero_convs):if guided_hint is not None:h = module(h, emb, context)h += guided_hintguided_hint = Noneelse:h = module(h, emb, context)outs.append(zero_conv(h, emb, context))#SFF Block Implementationouts[9] = self.silu(outs[9]+self.down11(outs[6])+self.down21(outs[7])+self.down31(outs[8]))outs[10] = self.silu(outs[10]+self.down12(outs[6])+self.down22(outs[7])+self.down32(outs[8]))outs[11] = self.silu(outs[11]+self.down13(outs[6])+self.down23(outs[7])+self.down33(outs[8]))h = self.middle_block(h, emb, context)outs.append(self.middle_block_out(h, emb, context))return outs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/387053.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

将 magma example 改写成 cusolver example eqrf

1&#xff0c;简单安装Magma 1.1 下载编译 OpenBLAS $ git clone https://github.com/OpenMathLib/OpenBLAS.git $ cd OpenBLAS/ $ make -j DEBUG1 $ make install PREFIX/home/hipper/ex_magma/local_d/OpenBLAS/1.2 下载编译 magma $ git clone https://bitbucket.org/icl…

【Kubernetes】二进制部署k8s集群(中)之cni网络插件flannel和calico

&#xff01;&#xff01;&#xff01;继续上一篇实验部署&#xff01;&#xff01;&#xff01; 目录 一.k8s的三种网络模式 1.Pod 内容器与容器之间的通信 2.同一个 Node 内 Pod 之间的通信 3.不同 Node 上 Pod 之间的通信 二.k8s的三种接口 三.Flannel 网络插件 1.U…

美摄科技企业级视频拍摄与编辑SDK解决方案

在数字化浪潮汹涌的今天&#xff0c;视频已成为企业传递信息、塑造品牌、连接用户不可或缺的强大媒介。为了帮助企业轻松驾驭这一视觉盛宴的制作过程&#xff0c;美摄科技凭借其在影视级非编技术领域的深厚积累&#xff0c;推出了面向企业的专业视频拍摄与编辑SDK解决方案&…

每日OJ_牛客CM26 二进制插入

目录 牛客CM26 二进制插入 解析代码 牛客CM26 二进制插入 二进制插入_牛客题霸_牛客网 解析代码 m:1024&#xff1a;100000000 00 n:19 &#xff1a; 10011 要把n的二进制值插入m的第j位到第i位&#xff0c;只需要把n先左移j位&#xff0c;然后再进行或运算&#xff08;|&am…

高品质定制线缆知名智造品牌推荐-精工电联:高压线缆行业定制服务的领航者

定制线缆源头厂家推荐-精工电联&#xff1a;高压线缆行业定制服务的领航者 在当今这个高度信息化的社会&#xff0c;电力传输与分配系统的稳定运行至关重要。作为连接各个电力设备的纽带&#xff0c;高压线缆的质量直接关系到电力系统的安全性和稳定性。在定制高压线缆行业中&a…

android(安卓)最简单明了解释版本控制之MinSdkVersion、CompileSdkVersion、TargetSdkVersion

1、先明白几个概念 &#xff08;1&#xff09;平台版本&#xff08;Android SDK版本号&#xff09; 平台版本也就是我们平时说的安卓8、安卓9、安卓10 &#xff08;2&#xff09;API级别&#xff08;API Level&#xff09; Android 平台提供的框架 API 被称作“API 级别” …

Hugo 部署与自动更新(Git)

文章目录 Nginx部署Hugonginx.confhugo.conf Hugo自动更新Hugo自动更新流程添加访问令牌添加web hookrust实现自动更新接口 Nginx部署Hugo nginx.conf user nginx; worker_processes auto;error_log /var/log/nginx/error.log notice; pid /var/run/nginx.pid;even…

C++STL简介(三)

目录 1.vector的模拟实现 1.1begin&#xff08;&#xff09; 1.2end&#xff08;&#xff09; 1.3打印信息 1.4 reserve&#xff08;&#xff09; 1.5 size&#xff08;&#xff09; 1.6 capacity&#xff08;&#xff09; 1.7 push_back() 1.8[ ] 1.9 pop_back() 1.10 insert&…

合并K个有序链表

题目 给你一个链表数组&#xff0c;每个链表都已经按升序排列。 请你将所有链表合并到一个升序链表中&#xff0c;返回合并后的链表。 示例1&#xff1a; 输入&#xff1a; 输出&#xff1a; 示例2&#xff1a; 输入&#xff1a; 输出&#xff1a; 示例3&#xff1a; 输入&…

【音视频之SDL2】Windows配置SDL2项目模板

文章目录 前言 SDL2 简介核心功能 Windows配置SDL2项目模板下载SDL2编译好的文件VS配置SDL2 测试代码效果展示 总结 前言 在开发跨平台的音视频应用程序时&#xff0c;SDL2&#xff08;Simple DirectMedia Layer 2&#xff09;是一个备受欢迎的选择。SDL2 是一个开源库&#x…

自研Vue3开源Tree组件:节点拖拽bug修复

当dropType为after&#xff0c;且dropNode为父节点时&#xff0c;bug出现了&#xff1a; bug原因&#xff1a;插入扁平化列表的位置insertIndex计算的不对&#xff1a; 正确的逻辑&#xff0c;同inner要算上子孙节点所占的位置&#xff1a; bug修复&#xff01;

「数组」实现动态数组的功能(C++)

概述 动态数组&#xff0c;顾名思议即可变长度的数组。数组这种数据结构的实现是在栈空间或堆空间申请一段连续的可操作区域。 实现可变长度的动态数组结构&#xff0c;应该有以下操作&#xff1a;申请一段足够长的空间&#xff0c;如果数据的存入导致空间已满&#xff0c;则…

PackagesNotFoundError 错误表明 conda 在当前使用的镜像源中找不到 contourpy 版本 1.2.1。以下是可能的解决方法:

PackagesNotFoundError 错误表明 conda 在当前使用的镜像源中找不到 contourpy 版本 1.2.1。以下是可能的解决方法&#xff1a; PackagesNotFoundError 错误表明 conda 在当前使用的镜像源中找不到 contourpy 版本 1.2.1。以下是可能的解决方法&#xff1a; 1. 更换镜像源 虽…

rust 桌面 sip 软电话(基于tauri 、pjsip库)

本文尝试下rust 的tauri 桌面运用 原因在于体积小 1、pjsip 提供了rust 接口官方的 rust demo 没编译出来 在git找了个sip-phone-rs-master https://github.com/Charles-Schleich/sip-phone-rs 可以自己编译下pjsip lib库替换该项目的lib 2、创建一个tauri demo 引用 [depe…

计算机毕业设计选题推荐-某炼油厂盲板管理系统-Java/Python项目实战

✨作者主页&#xff1a;IT研究室✨ 个人简介&#xff1a;曾从事计算机专业培训教学&#xff0c;擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Python…

实战:Zookeeper 简介和单点部署ZooKeeper

Zookeeper 简介 ZooKeeper是一个开源的分布式协调服务&#xff0c;它是Apache软件基金会下的一个项目&#xff0c;旨在解决分布式系统中的协调和管理问题。以下是ZooKeeper的详细简介&#xff1a; 一、基本定义 ZooKeeper是一个分布式的、开放源码的分布式应用程序协调服务&a…

SpringBoot的基础配置

目录 SpringBoot快速搭建web程序 第一步&#xff1a;导包 第二步&#xff1a;配置SpringBoot引导类 第三步&#xff1a;编写controller类 第四步&#xff1a;在SpirngBoot引导类中启动项目 起步依赖 SpringBoot基础配置 配置文件格式 yaml语法规则 读取yml配置文件的方…

UE5+OpenCV配置(Windows11系统)

一、概述 因为需要在UE5中使用OpenCV这些工具进行配置&#xff0c;所以在网络上参考借鉴一些资料进行配置。查询到不少的资料&#xff0c;最后将其配置成功。在这里顺便记录一下自己的配置成功的过程。 二、具体过程 &#xff08;一&#xff09;版本 使用Windows11系统、UE5.…

ONLYOFFICE 协作空间 2.6 已发布:表单填写房间、LDAP、优化房间和文件管理等

更新后的 ONLYOFFICE 协作空间带来了超过 20 项新功能和优化&#xff0c;让工作更加高效和舒适。阅读本文了解详情。 表单填写房间 这次更新增加了一种新的房间类型&#xff0c;可在 ONLYOFFICE 协作空间中组织简单的表单填写流程。 通过表单填写房间&#xff0c;目前可以完成…

将控制台内容输出到文本文件

示例代码&#xff1a; Imports System.IO Module Module1Sub Main()Dim fs As New FileStream("D:\Desktop\test\输出结果.txt", FileMode.Create, FileAccess.Write, FileShare.None)Dim sw As New StreamWriter(fs)Console.SetOut(sw)Console.SetError(sw)For i …