Perturbed-Attention Guidance(PAG) 笔记

Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance
Github

摘要

近期研究表明,扩散模型能够生成高质量样本,但其质量在很大程度上依赖于采样引导技术,如分类器引导(CG)和无分类器引导(CFG)。这些技术在无条件生成或诸如图像恢复等各种下游任务中往往并不适用。在本文中,我们提出了一种新颖的采样引导方法,称为Perturbed-Attention Guidance(PAG),它能在无条件和条件设置下提高扩散样本的质量,并且无需额外的训练或集成外部模块。PAG 旨在通过去噪过程逐步增强样本的结构。它通过用单位矩阵替换 UNet 中的self-attention map来生成结构退化的中间样本,这是考虑到自注意力机制捕捉结构信息的能力,并引导去噪过程远离这些退化样本。在 ADM 和 Stable Diffusion 中,PAG 在条件甚至无条件场景下都显著提高了样本质量。此外,在诸如空提示的 ControlNet 以及图像修复(如修补和去模糊)等现有引导(如 CG 或 CFG)无法充分利用的各种下游任务中,PAG 也显著提高了基线性能。
在这里插入图片描述
研究表明,在diffusion U-Net的self-attention 模块中,query-key 主要影响structure ,values主要影响appearance。
在这里插入图片描述
如果直接扰动Vt 的话,会导致 out-of-distribution (OOD),因此选择使用单位矩阵替换query-key 部分。
在这里插入图片描述
在这里插入图片描述
那么具体扰动Unet的哪一部分呢?作者使用了5k个样本,在PAG guidance scale s = 2.5 and DDIM 25 step的条件下,表现最好的是mid-block “m0”
在这里插入图片描述

代码

Diffusers 已经支持PAG用在多种任务中,并且可以和ControlNet、 IP-Adapter 一起使用。

from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torchpipeline = AutoPipelineForText2Image.from_pretrained("~/.cache/modelscope/hub/AI-ModelScope/stable-diffusion-xl-base-1___0",enable_pag=True,  ##addpag_applied_layers=["mid"], ##addtorch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()prompt = "an insect robot preparing a delicious meal, anime style"
generator = torch.Generator(device="cpu").manual_seed(0)
images = pipeline(prompt=prompt,num_inference_steps=25,guidance_scale=7.0,generator=generator,pag_scale=2.5,
).imagesimages[0].save("pag.jpg")

PAG代码细节

如果同时使用PAG和CFG,那么输入到Unet中prompt_embeds定义如下,也就是[uncond,cond,cond]

    def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):cond = torch.cat([cond] * 2, dim=0)if do_classifier_free_guidance:cond = torch.cat([uncond, cond], dim=0)return cond

PAGCFGIdentitySelfAttnProcessor2_0计算,其中[uncond,cond]正常计算SA,第二个cond则计算PSA。

class PAGCFGIdentitySelfAttnProcessor2_0:r"""Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).PAG reference: https://arxiv.org/abs/2403.17377"""def __init__(self):if not hasattr(F, "scaled_dot_product_attention"):raise ImportError("PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")def __call__(self,attn: Attention,hidden_states: torch.FloatTensor,encoder_hidden_states: Optional[torch.FloatTensor] = None,attention_mask: Optional[torch.FloatTensor] = None,temb: Optional[torch.FloatTensor] = None,) -> torch.Tensor:residual = hidden_statesif attn.spatial_norm is not None:hidden_states = attn.spatial_norm(hidden_states, temb)input_ndim = hidden_states.ndimif input_ndim == 4:batch_size, channel, height, width = hidden_states.shapehidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)# chunkhidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])# original pathbatch_size, sequence_length, _ = hidden_states_org.shapeif attention_mask is not None:attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)# scaled_dot_product_attention expects attention_mask shape to be# (batch, heads, source_length, target_length)attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])if attn.group_norm is not None:hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)query = attn.to_q(hidden_states_org)key = attn.to_k(hidden_states_org)value = attn.to_v(hidden_states_org)inner_dim = key.shape[-1]head_dim = inner_dim // attn.headsquery = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)# the output of sdp = (batch, num_heads, seq_len, head_dim)# TODO: add support for attn.scale when we move to Torch 2.1hidden_states_org = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)hidden_states_org = hidden_states_org.to(query.dtype)# linear projhidden_states_org = attn.to_out[0](hidden_states_org)# dropouthidden_states_org = attn.to_out[1](hidden_states_org)if input_ndim == 4:hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)# perturbed path (identity attention)batch_size, sequence_length, _ = hidden_states_ptb.shapeif attn.group_norm is not None:hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)value = attn.to_v(hidden_states_ptb)hidden_states_ptb = valuehidden_states_ptb = hidden_states_ptb.to(query.dtype)# linear projhidden_states_ptb = attn.to_out[0](hidden_states_ptb)# dropouthidden_states_ptb = attn.to_out[1](hidden_states_ptb)if input_ndim == 4:hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)# cathidden_states = torch.cat([hidden_states_org, hidden_states_ptb])if attn.residual_connection:hidden_states = hidden_states + residualhidden_states = hidden_states / attn.rescale_output_factorreturn hidden_states

经过Unet后,noise_pred的计算方法。

    def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False):r"""Apply perturbed attention guidance to the noise prediction.Args:noise_pred (torch.Tensor): The noise prediction tensor.do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.guidance_scale (float): The scale factor for the guidance term.t (int): The current time step.return_pred_text (bool): Whether to return the text noise prediction.Returns:Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applyingperturbed attention guidance and the text noise prediction."""pag_scale = self._get_pag_scale(t)if do_classifier_free_guidance:noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)noise_pred = (noise_pred_uncond+ guidance_scale * (noise_pred_text - noise_pred_uncond)+ pag_scale * (noise_pred_text - noise_pred_perturb))else:noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)if return_pred_text:return noise_pred, noise_pred_textreturn noise_pred

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/504344.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(概率论)无偏估计

参考文章:(15 封私信 / 51 条消息) 什么是无偏估计? - 知乎 (zhihu.com) 首先,第一个回答中,马同学图解数学讲解得很形象, 我的概括是:“注意,有一个总体的均值u。然后,如果抽样n个&…

USB 驱动开发 --- Gadget 设备连接 Windows 免驱

环境信息 测试使用 DuoS(Arm CA53, Linux 5.10) 搭建方案验证环境,使用 USB sniff Wirekshark 抓包分析,合照如下: 注:左侧图中设备:1. 蓝色,USB sniff 非侵入工 USB 抓包工具;2. …

OSPF - 2、3类LSA(Network-LSA、NetWork-Sunmmary-LSA)

前篇博客有对常用LSA的总结 2类LSA(Network-LSA) DR产生泛洪范围为本区域 作用:  描述MA网络拓扑信息和网络信息,拓扑信息主要描述当前MA网络中伪节点连接着哪几台路由。网络信息描述当前网络的 掩码和DR接口IP地址。 影响邻居建立中说到…

开放词汇检测新晋SOTA:地瓜机器人开源DOSOD实时检测算法

在计算机视觉领域,目标检测是一项关键技术,旨在识别图像或视频中感兴趣物体的位置与类别。传统的闭集检测长期占据主导地位,但近年来,开放词汇检测(Open-Vocabulary Object Detection-OVOD 或者 Open-Set Object Detec…

【Ubuntu】 Ubuntu22.04搭建NFS服务

安装NFS服务端 sudo apt install nfs-kernel-server 安装NFS客户端 sudo apt install nfs-common 配置/etc/exports sudo vim /etc/exports 第一个字段:/home/lm/code/nfswork共享的目录 第二个字段:指定哪些用户可以访问 ​ * 表示所有用户都可以访…

第四、五章补充:线代本质合集(B站:小崔说数)

视频1:线性空间 原视频:【线性代数的本质】向量空间、基向量的几何解释_哔哩哔哩_bilibili 很多同学在学习线性代数的时候,会遇到一个困扰,就是不知道什么是线性空间。因为中文的教材往往对线性空间的定义是非常偏数学的&#x…

JS进阶--JS听到了不灭的回响

作用域 作用域(scope)规定了变量能够被访问的“范围”,离开了这个“范围”变量便不能被访问 作用域分为局部和全局 局部作用域 局部作用域分为函数和块 那 什么是块作用域呢? 在 JavaScript 中使用 { } 包裹的代码称为代码块…

MFC读写文件实例

程序功能:点击写入文件按钮将输入编辑框中内容写入以系统时间命名的文件中,点击读取文件按钮将选中的文件内容显示到静态文本控件中。 相关代码如下: void CWR_FILEDlg::OnButton1() {CString str;GetDlgItem(IDC_EDIT1)->GetWindowText…

IWOA-GRU和GRU时间序列预测(改进的鲸鱼算法优化门控循环单元)

时序预测 | MATLAB实现IWOA-GRU和GRU时间序列预测(改进的鲸鱼算法优化门控循环单元) 目录 时序预测 | MATLAB实现IWOA-GRU和GRU时间序列预测(改进的鲸鱼算法优化门控循环单元)预测效果基本介绍模型描述程序设计参考资料 预测效果 基本介绍 MATLAB实现IWOA-GRU和GRU时间序列预测…

详细全面讲解C++中重载、隐藏、覆盖的区别

文章目录 总结1、重载示例代码特点1. 模板函数和非模板函数重载2. 重载示例与调用规则示例代码调用规则解释3. 特殊情况与注意事项二义性问题 函数特化与重载的交互 2. 函数隐藏(Function Hiding)概念示例代码特点 3. 函数覆盖(重写&#xff…

DAY15 神经网络的参数和变量

DAY15 神经网络的参数和变量 一、参数和变量 在神经网络中,参数和变量是两个关键概念,它们分别指代不同类型的数据和设置。 参数(Parameters) 定义:参数是指在训练过程中学习到的模型内部变量,这些变量…

git的rebase和merge的区别?

B分支从A分支拉出 1.git merge 处于A分支执行,git merge B分支:相当于将commit X、commit Y两次提交,作为了新的commit Z提交到了A分支上。能溯源它真正提交的信息。 2.git rebase 处于B分支,执行git rebase A分支,B分支那边复…

2、蓝牙打印机点灯-GPIO输出控制

1、硬件 1.1、看原理图 初始状态位高电平. 需要驱动PA1输出高低电平控制PA1. 1.2、看手册 a、系统架构图 GPIOA在APB2总线上。 b、RCC使能 GPIOA在第2位。 c、GPIO寄存器配置 端口:PA1 模式:通用推挽输出模式 -- 输出0、1即可 速度:5…

使用强化学习训练神经网络玩俄罗斯方块

一、说明 在 2024 年暑假假期期间,Tim学习并应用了Q-Learning (一种强化学习形式)来训练神经网络玩简化版的俄罗斯方块游戏。在本文中,我将详细介绍我是如何做到这一点的。我希望这对任何有兴趣将强化学习应用于新领域的人有所帮助…

基于springboot的网上商城购物系统

作者:学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等 文末获取“源码数据库万字文档PPT”,支持远程部署调试、运行安装。 目录 项目包含: 开发说明: 系统功能: 项目截图…

API架构风格的深度解析与选择策略:SOAP、REST、GraphQL与RPC

❃博主首页 &#xff1a; 「码到三十五」 &#xff0c;同名公众号 :「码到三十五」&#xff0c;wx号 : 「liwu0213」 ☠博主专栏 &#xff1a; <mysql高手> <elasticsearch高手> <源码解读> <java核心> <面试攻关> ♝博主的话 &#xff1a…

【网络协议】开放式最短路径优先协议OSPF详解(四)

前言 在本章的第一部分和第二部分中&#xff0c;我们探讨了OSPF的基本配置&#xff0c;并进一步学习了更多OSPF的概念&#xff0c;例如静态路由的重分发及其度量值。在第三部分中&#xff0c;我们讨论了多区域OSPF。在第四部分中&#xff0c;我们将关注OSPF与多访问网络&#…

上门按摩系统架构与功能分析

一、系统架构 服务端&#xff1a;Java&#xff08;最低JDK1.8&#xff0c;支持JDK11以及JDK17&#xff09;数据库&#xff1a;MySQL数据库&#xff08;标配5.7版本&#xff0c;支持MySQL8&#xff09;ORM框架&#xff1a;Mybatis&#xff08;集成通用tk-mapper&#xff0c;支持…

攻防世界 ics-07

点击之后发现有个项目管理能进&#xff0c;点进去&#xff0c;点击看到源码&#xff0c;如下三段 <?php session_start(); if (!isset($_GET[page])) { show_source(__FILE__); die(); } if (isset($_GET[page]) && $_GET[page] ! index.php) { include(flag.php);…

Spring Boot教程之四十九:Spring Boot – MongoRepository 示例

Spring Boot – MongoRepository 示例 Spring Boot 建立在 Spring 之上&#xff0c;包含 Spring 的所有功能。由于其快速的生产就绪环境&#xff0c;使开发人员能够直接专注于逻辑&#xff0c;而不必费力配置和设置&#xff0c;因此如今它正成为开发人员的最爱。Spring Boot 是…