扩散模型条件生成——Classifier Guidance和Classifier-free Guidance原理解析

1、前言

从讲扩散模型到现在。我们很少讲过条件生成(Stable DIffusion曾提到过一点),所以本篇内容。我们就来具体讲一下条件生成。这一部分的内容我就不给原论文了,因为那些论文并不只讲了条件生成,还有一些调参什么的。并且推导过程也相对复杂。我们从一个比较简单的角度出发。

参考论文:Understanding Diffusion Models: A Unified Perspective (arxiv.org)

参考代码:

classifier guidance:GitHub - openai/guided-diffusion

classifier-free guidance:GitHub - coderpiaobozhe/classifier-free-diffusion-guidance-Pytorch: a simple unofficial implementation of classifier-free diffusion guidance

视频:[扩散模型条件生成——Classifier Guidance和Classifier-free Guidance原理解析-哔哩哔哩]

2、常用的条件生成方法

在diffusion里面,如何进行条件生成呢?我们不妨回忆一下在Stable Diffusion里面的一个常用做法。即在训练的时候。给神经网络输入一个条件。
L = ∣ ∣ ϵ − ϵ θ ( x t , t , y ) ∣ ∣ 2 L=||\epsilon-\epsilon_{\theta}(x_t,t,y)||^2 L=∣∣ϵϵθ(xt,t,y)2
里面的y就是条件。至于为什么有效,请看我之前写过的Stable DIffusion那篇文章。在此不过多赘述了。我们来讲这种方法所存在的问题。

很显然的,这种训练的方式,会有一个问题,那就是神经网络或许会学会忽略或者淡化掉我们输入的条件信息。因为就算我们不输入信息,他也照样能够生成。

接下来我们来讲两种更为流行的方法——分类指导器(Classifier Guidance) 和无分类指导器( Classifier-Free Guidance)

3、Classifier Guidance

为了简单起见。我们从分数模型的角度出发。

回忆一下在SDE里面的结论。其反向过程为
d x = [ f ( x , t ) − g ( t ) 2 ∇ x log ⁡ p t ( x ) ] d t + g ( t ) d w ˉ (1) \mathbb{dx}=\left[\mathbb{f(x,t)}-g(t)^2\nabla_x\log p_t(x)\right]\mathbb{dt}+g(t)\mathbb{d\bar w}\tag{1} dx=[f(x,t)g(t)2xlogpt(x)]dt+g(t)dwˉ(1)
如果施加条件的话,还是根据Reverse-time diffusion equation models - ScienceDirect这篇论文,可得条件生成时的反向SDE为
d x = [ f ( x , t ) − g ( t ) 2 ∇ x log ⁡ p t ( x ∣ y ) ] d t + g ( t ) d w ˉ (2) \mathbb{dx}=\left[\mathbb{f(x,t)}-g(t)^2\nabla_x\log p_t(x|y)\right]\mathbb{dt}+g(t)\mathbb{d\bar w}\tag{2} dx=[f(x,t)g(t)2xlogpt(xy)]dt+g(t)dwˉ(2)
我们利用贝叶斯公式,对 ∇ x log ⁡ p t ( x ∣ y ) \nabla x \log p_t(x|y) xlogpt(xy)进行处理
∇ x log ⁡ p t ( x ∣ y ) = ∇ x log ⁡ p t ( y ∣ x ) p t ( x ) p t ( y ) = ∇ x ( log ⁡ p t ( y ∣ x ) + log ⁡ p t ( x ) − log ⁡ p t ( y ) ) = ∇ x log ⁡ p t ( x ) + ∇ x log ⁡ p t ( y ∣ x ) \begin{aligned}\nabla_x \log p_t(x|y)=&\nabla_x\log\frac{p_t(y|x)p_t(x)}{p_t(y)}\\=&\nabla_x\left(\log p_t(y|x)+\log p_t(x)-\log p_t(y)\right)\\=&\nabla_x \log p_t(x)+\nabla_x\log p_t(y|x)\end{aligned}\nonumber xlogpt(xy)===xlogpt(y)pt(yx)pt(x)x(logpt(yx)+logpt(x)logpt(y))xlogpt(x)+xlogpt(yx)
第二个等号到第三个等号是因为对 log ⁡ p t ( y ) \log p_t(y) logpt(y)关于x求梯度等于0( log ⁡ p t ( y ) \log p_t(y) logpt(y)与x无关)

把它代入Eq.(2)可得
d x = [ f ( x , t ) − g ( t ) 2 ( ∇ x log ⁡ p t ( x ) + ∇ x log ⁡ p t ( y ∣ x ) ) ] d t + g ( t ) d w ˉ (3) \mathbb{dx}=\left[\mathbb{f(x,t)}-g(t)^2\left(\nabla_x\log p_t(x)+\nabla_x\log p_t(y|x)\right)\right]\mathbb{dt}+g(t)\mathbb{d\bar w}\tag{3} dx=[f(x,t)g(t)2(xlogpt(x)+xlogpt(yx))]dt+g(t)dwˉ(3)
对比Eq.(1)和Eq.(3)。我们不难发现,它们的差别,居然是只多了一个 ∇ x log ⁡ p t ( y ∣ x ) \nabla_x\log p_t(y|x) xlogpt(yx)

p t ( y ∣ x ) p_t(y|x) pt(yx)是什么?是以 x x x作为条件,时间为t对应条件y的概率。我们可以怎么求呢?该怎么求出来呢?

当然是使用神经网络了。也就是说,我们可以额外设定一个神经网络,该神经网络输入是 x t x_t xt,输出是条件为y的概率

所以,实际上我们现在需要训练两部分,一部分是 ∇ x log ⁡ p t ( x ) \nabla_x\log p_t(x) xlogpt(x),这我们在SDE中已经讲过该如何训练了。

另一个就是 ∇ x log ⁡ p t ( y ∣ x ) \nabla_x\log p_t(y|x) xlogpt(yx),他就是一个分类神经网络网络。训练好之后,我们就可以使用Eq.(3)通过不同的数值求解器,进行优化了。

作者在此基础上,又引入了一个控制参数 λ \lambda λ
∇ x log ⁡ p t ( x ∣ y ) = ∇ x log ⁡ p t ( x ) + λ ∇ x log ⁡ p t ( y ∣ x ) (4) \nabla_x \log p_t(x|y)=\nabla_x\log p_t(x)+\lambda\nabla_x\log p_t(y|x)\tag{4} xlogpt(xy)=xlogpt(x)+λxlogpt(yx)(4)
λ = 0 \lambda=0 λ=0,表示不加入任何条件。当 λ \lambda λ很大时,模型会产生大量附带条件信息的样本。

这种方法的一个缺点就是,需要额外学习一个分类器 p t ( y ∣ x ) p_t(y|x) pt(yx)

4、Classifier-Free Guidance

之前推出
∇ x log ⁡ p t ( x ∣ y ) = ∇ x log ⁡ p t ( x ) + ∇ x log ⁡ p t ( y ∣ x ) (5) \nabla_x \log p_t(x|y)=\nabla_x \log p_t(x)+\nabla_x\log p_t(y|x)\tag{5} xlogpt(xy)=xlogpt(x)+xlogpt(yx)(5)
把该式子代入Eq.(4)可得
∇ x log ⁡ p t ( x ∣ y ) = ∇ x log ⁡ p t ( x ) + λ ( ∇ x log ⁡ p t ( x ∣ y ) − ∇ x log ⁡ p t ( x ) ) = ∇ x log ⁡ p t ( x ) + λ ∇ x log ⁡ p t ( x ∣ y ) − λ ∇ x log ⁡ p t ( x ) = ( 1 − λ ) ∇ x log ⁡ p t ( x ) + λ ∇ x log ⁡ p t ( x ∣ y ) \begin{aligned}\nabla_x \log p_t(x|y)=&\nabla_x\log p_t(x)+\lambda\left(\nabla_x\log p_t(x|y)-\nabla_x\log p_t(x)\right)\\=&\nabla_x\log p_t(x)+\lambda\nabla_x\log p_t(x|y)-\lambda\nabla_x\log p_t(x)\\=&\left(1-\lambda\right)\nabla_x\log p_t(x)+\lambda\nabla_x\log p_t(x|y)\end{aligned}\nonumber xlogpt(xy)===xlogpt(x)+λ(xlogpt(xy)xlogpt(x))xlogpt(x)+λxlogpt(xy)λxlogpt(x)(1λ)xlogpt(x)+λxlogpt(xy)
此时我们注意到,当 λ = 0 \lambda=0 λ=0是,第二项完全为0,会忽略掉条件;当 λ = 1 \lambda=1 λ=1时,使用第二项,第二项就是附带有条件情况下的分布分数网络;而当 λ > 1 \lambda> 1 λ>1,模型会优化考虑条件生成样本,并且远离第一项的无条件分数网络的方向,换句话说,它降低了生成不使用条件信息的样本的概率,而有利于生成明确使用条件信息的样本。

事实上,如果你看了free-Classifier Guidance这篇论文,会发现我们的结论不一样。

其实论文里面的控制参数是 w w w,也就是说,Eq.(4)就变成了这样
∇ x log ⁡ p t ( x ∣ y ) = ∇ x log ⁡ p t ( x ) + w ∇ x log ⁡ p t ( y ∣ x ) \nabla_x \log p_t(x|y)=\nabla_x\log p_t(x)+w\nabla_x\log p_t(y|x) xlogpt(xy)=xlogpt(x)+wxlogpt(yx)
我们把控制参数改成 1 + w 1+w 1+w不会有任何影响
∇ x log ⁡ p t ( x ∣ y ) = ∇ x log ⁡ p t ( x ) + ( 1 + w ) ∇ x log ⁡ p t ( y ∣ x ) \nabla_x \log p_t(x|y)=\nabla_x\log p_t(x)+(1+w)\nabla_x\log p_t(y|x) xlogpt(xy)=xlogpt(x)+(1+w)xlogpt(yx)
把Eq.(5)代入该式子
∇ x log ⁡ p t ( x ∣ y ) = ∇ x log ⁡ p t ( x ) + ( 1 + w ) ( ∇ x log ⁡ p t ( x ∣ y ) − ∇ x log ⁡ p t ( x ) ) = ∇ x log ⁡ p t ( x ) + ( 1 + w ) ∇ x log ⁡ p t ( x ∣ y ) − ( 1 + w ) ∇ x log ⁡ p t ( x ) = ( 1 + w ) ∇ x log ⁡ p t ( x ∣ y ) − w ∇ x log ⁡ p t ( x ) (6) \begin{aligned}\nabla_x \log p_t(x|y)=&\nabla_x\log p_t(x)+(1+w)\left(\nabla_x\log p_t(x|y)-\nabla_x\log p_t(x)\right)\\=&\nabla_x\log p_t(x)+(1+w)\nabla_x\log p_t(x|y)-(1+w)\nabla_x\log p_t(x)\\=&(1+w)\nabla_x\log p_t(x|y)-w\nabla_x\log p_t(x)\end{aligned}\tag{6} xlogpt(xy)===xlogpt(x)+(1+w)(xlogpt(xy)xlogpt(x))xlogpt(x)+(1+w)xlogpt(xy)(1+w)xlogpt(x)(1+w)xlogpt(xy)wxlogpt(x)(6)
这就是原论文里面的结论。

那么接下来,我们来探讨一下该如何去训练。

对于 ∇ x log ⁡ p t ( x ) \nabla_x\log p_t(x) xlogpt(x),这个不用说了,之前我们训练的就是这个;如何计算 ∇ x log ⁡ p t ( x ∣ y ) \nabla_x\log p_t(x|y) xlogpt(xy)呢,它实际上就是在给定y的情况下,求出 p t ( x ∣ y ) p_t(x|y) pt(xy)。那我们可以怎么做呢?

在NCSN,我们是使用一个加噪分布 q ( x ~ ∣ x ) q(\tilde x|x) q(x~x)取代 p ( x ) p(x) p(x),而从让它是可解的。

对于 p t ( x ∣ y ) p_t(x|y) pt(xy),即便是加多了一个条件之后,我们仍然建模为 q ( x ~ ∣ x ) q(\tilde x|x) q(x~x),也就是说,我们仍然把它建模成一个正向加噪过程。因此,无论是否增加条件。最终的损失函数结果都是
L = ∣ ∣ s θ − ∇ x log ⁡ q ( x ~ ∣ x ) ∣ ∣ 2 = ∣ ∣ s θ − ∇ x log ⁡ q ( x t ∣ x 0 ) ∣ ∣ 2 L=||s_\theta-\nabla_x\log q(\tilde x|x)||^2=||s_\theta-\nabla_x\log q(x_t|x_0)||^2 L=∣∣sθxlogq(x~x)2=∣∣sθxlogq(xtx0)2
后者是通过SDE统一的结果(我在SDE那一节讲过)

那该如何体现条件y呢?其实我们在第二节的时候已经说过了,就是在里面神经网络的输出加入一个条件y。
L = ∣ ∣ s θ ( x t , t , y ) − ∇ x log ⁡ q ( x t ∣ x 0 ) ∣ ∣ 2 (7) L=||s_\theta(x_t,t,y)-\nabla_x\log q(x_t|x_0)||^2\tag{7} L=∣∣sθ(xt,t,y)xlogq(xtx0)2(7)
而不施加条件的时候,长这样
L = ∣ ∣ s θ ( x t , t ) − ∇ x log ⁡ q ( x t ∣ x 0 ) ∣ ∣ 2 (8) L=||s_\theta(x_t,t)-\nabla_x\log q(x_t|x_0)||^2\tag{8} L=∣∣sθ(xt,t)xlogq(xtx0)2(8)
由Eq.(5)可知,我们需要训练两种情况,一种是有条件的,对应Eq.(7);另外一种是无条件的,对应Eq.(8)。

理论上,我们其实也是要训练两个神经网络。但实际上,我们可以把他们结合成一种神经网络。

具体操作就是把无条件的情况作为一种特例。

当我们训练有条件的神经网络的时候,会照样把条件输入进网络里面。而训练无条件的时候,我们构造一个无条件的标识符,把它作为条件输入给神经网络,比如对于所有无条件的情况,我都构造一个0作为条件输入到神经网络里面。通过这种方式,我们就可以把两个网络变成一个网络了,

对于损失函数,直接使用Eq.(7)。我们在SDE里面讲过 ∇ x log ⁡ p ( x ) = − 1 σ ϵ \nabla_x \log p(x)=-\frac{1}{\sigma}\epsilon xlogp(x)=σ1ϵ。所以我们最终我们把预测噪声,变成了预测分数。我们同样可以把它变回来,变成预测分数
L = ∣ ∣ ϵ − ϵ θ ( x t , t , y ) ∣ ∣ 2 L=||\epsilon-\epsilon_{\theta}(x_t,t,y)||^2 L=∣∣ϵϵθ(xt,t,y)2
所以损失函数就变成了这样。在训练的时候,作者设定一个大于等于0,小于等于1的超参数 p u n c o n d p_{uncond} puncond,它的作用就是判断是否需要输入条件(从0-1分布采样一个值,大于 p u n c o n d p_{uncond} puncond则使用条件,反之则不使用)。也就是说,这相当于dropout一样,随机舍弃掉一些条件,把他们作为无条件的情况(因为我们既要学习有条件的,又要学习无条件的)。所以,最终的训练过程就是这样

在这里插入图片描述

其中里面的 λ \lambda λ你就当作是时刻t吧(其实不是,其实是时刻t的噪声(噪声的初始化不一样,不是传统的等差数列,是用三角函数初始化的)。由于与本篇内容无关,故而忽略),c是条件。

同样的,采用过程使用Eq.(6)的结构进行采样

在这里插入图片描述

5、结束

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/344578.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

学习笔记——路由网络基础——静态路由(static)

三、静态路由(static) 1、静态路由 (1)定义 静态路由(Static):由管理员手动配置和维护的路由。静态路由配置简单,被广泛应用于网络中。此外还可以实现负载均衡和路由备份。 静态路由默认优先级为60,如果想在多条静态路由中让某条路由优选…

Unity基础实践小项目

项目流程: 需求分析 开始界面 选择角色面板 排行榜面板 设置面板 游戏面板 确定退出面板 死亡面板 UML类图 准备工作 1.导入资源 2.创建需要的文件夹 3.创建好面板基类 开始场景 开始界面 1.拼面板 2.写脚本 注意事项:注意先设置NGUI的分辨率大小&…

Scala 练习一 将Mysql表数据导入HBase

Scala 练习一 将Mysql表数据导入HBase 续第一篇:Java代码将Mysql表数据导入HBase表 源码仓库地址:https://gitee.com/leaf-domain/data-to-hbase 一、整体介绍二、依赖三、测试结果四、源码 一、整体介绍 HBase特质 连接HBase, 创建HBase执行对象 初始化…

从0到1:企业办公审批小程序开发笔记

可行性分析 企业办公审批小程序,适合各大公司,企业,机关部门办公审批流程,适用于请假审批,报销审批,外出审批,合同审批,采购审批,入职审批,其他审批等规划化…

使用 stress 命令进行Linux CPU 压力测试

大家好,在现代计算机系统中,对系统性能和稳定性的评估是至关重要的。特别是在服务器环境中,我们需要确保系统能够在高负载情况下稳定运行,以满足用户的需求。而 CPU 是系统中最关键的组件之一,其性能直接影响着整个系统…

用 DataGridView 控件显示数据

使用DataGridView,可以很方便显示数据。 (1)Visual Studio版本:Visual Studio 2022 (2)应用程序类型:windows form (3)编程语言:C# 一、目标框架 .NET Fra…

【NI国产替代】高速数据采集模块,最大采样率为 125 Msps,支持 FPGA 定制化

• 双通道高精度数据采集 • 支持 FPGA 定制化 • 双通道高精度采样率 最大采样率为 125 Msps12 位 ADC 分辨率 最大输入电压为 0.9 V -3 dB 带宽为 30 MHz 支持 FPGA 定制化 根据需求编程实现特定功能和性能通过定制 FPGA 实现硬件加速,提高系统的运算速度FPGA…

Docker中搭建likeadmin

一、使用Docker中的docker-compose搭建likeadmin 1.去网址:https://gitee.com/likeadmin/likeadmin_php中下载likeadmin 注册一个giee账号后 点那个克隆下载 按照序号在终端复制粘贴进去。 接着,输入ls 可以发现有一个这个: 里面有一个like…

服务器数据恢复—服务器raid5上层zfs文件系统数据恢复案例

服务器数据恢复环境&故障: 一台某品牌X3650M3服务器,服务器中有一组raid5磁盘阵列,上层采用zfs文件系统。 服务器未知原因崩溃,工作人员排查故障后发现服务器的raid5阵列中有两块硬盘离线导致该阵列不可用,服务器内…

Cell-在十字花科植物中年生和多次开花多年生开花行为的互相转化-文献精读21

Reciprocal conversion between annual and polycarpic perennial flowering behavior in the Brassicaceae 在十字花科植物中年生和多次开花多年生开花行为的互相转化 亮点 喜马拉雅须弥芥 和 内华达糖芥 是两个多年生植物模型 MADS-box 基因的剂量效应决定了一年生、二年生…

NodeJs实现脚本:将xlxs文件输出到json文件中

文章目录 前期工作和依赖笔记功能代码输出 最近有一个功能,将json文件里的内容抽取到一个xlxs中,然后维护xlxs文件。当要更新json文件时,就更新xlxs的内容并把它传回json中。这个脚本主要使用NodeJS写。 以下是完成此功能时做的一些笔记。 …

Oracle EBS AP发票创建会计科目错误:子分类帐日记帐分录未按输入币种进行平衡

系统版本 RDBMS : 12.1.0.2.0 Oracle Applications : 12.2.6 问题症状: 提交“创建会计科目”请求提示错误信息如下: 中文报错: 该子分类帐日记帐分录未按输入币种进行平衡。请检查日记帐分录行中输入的金额。 英文报错:The subledger journal entry does not balance i…

11 IP协议 - IP协议头部

什么是 IP 协议 IP(Internet Protocol)是一种网络通信协议,它是互联网的核心协议之一,负责在计算机网络中路由数据包,使数据能够在不同设备之间进行有效的传输。IP协议的主要作用包括寻址、分组、路由和转发数据包&am…

【Python教程】1-注释、变量、标识符与基本操作

在整理自己的笔记的时候发现了当年学习python时候整理的笔记,稍微整理一下,分享出来,方便记录和查看吧。个人觉得如果想简单了解一名语言或者技术,最简单的方式就是通过菜鸟教程去学习一下。今后会从python开始重新更新&#xff0…

使用OpenCV dnn c++加载YOLOv8生成的onnx文件进行实例分割

在网上下载了60多幅包含西瓜和冬瓜的图像组成melon数据集,使用 EISeg 工具进行标注,然后使用 eiseg2yolov8 脚本将.json文件转换成YOLOv8支持的.txt文件,并自动生成YOLOv8支持的目录结构,包括melon.yaml文件,其内容如下…

【UML用户指南】-05-对基本结构建模-类

目录 1、名称(name) 2、属性 (attribute) 3、操作(operation) 4、对属性和操作的组织 4.1、衍型 4.2、职责 (responsibility) 4.3、其他特征 4.4、对简单类型建模 5、结构良…

【Mtk Camera开发学习】06 MTK 和 Qcom 平台支持通过 Camera 标准API 打开 USBCamera

本专栏内容针对 “知识星球”成员免费,欢迎关注公众号:小驰行动派,加入知识星球。 #MTK Camera开发学习系列 #小驰私房菜 Google 官方介绍文档: https://source.android.google.cn/docs/core/camera/external-usb-cameras?hlzh-…

【传知代码】DETR[端到端目标检测](论文复现)

前言:想象一下,当自动驾驶汽车行驶在繁忙的街道上,DETR能够实时识别出道路上的行人、车辆、交通标志等目标,并准确预测出它们的位置和轨迹。这对于提高自动驾驶的安全性、减少交通事故具有重要意义。同样,在安防监控、…

Proxyman 现代直观的 HTTP 调试代理应用程序

Proxyman 是一款现代而直观的 HTTP 调试代理应用程序,它的功能强大,使您可以轻松捕获、检查和操作 HTTP(s) 流量。不再让繁杂的网络调试工具阻碍您的工作,使用 Proxyman,您将轻松应对网络调试的挑战。 下载地址:https…

BeagleBone Black入门总结

文章目录 参考连接重要路径系统镜像下载访问 BeagleBone 参考连接 镜像下载启动系统制作:SD卡烧录工具入门书籍推荐:BeagleBone cookbookBeagleBon cookbook 例程BeagleBone概况?BeagleBone 官方管理仓库(原理图,官方例程。。。)…