【笔记】扩散模型(九):Imagen 理论与实现

论文链接:Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding

非官方实现:lucidrains/imagen-pytorch

Imagen 是 Google Research 的文生图工作,这个工作并没有沿用 Stable Diffusion 的架构,而是级联了一系列普通的 DDPM 模型。其主要的贡献有以下几个方面:

  1. 使用比较大的文本模型进行文本嵌入,可以获得比使用 CLIP 更好的文本理解能力;
  2. 在采样阶段引入了一种动态阈值的方法,可以利用更高的 guidance scale 来生成更真实、细节更丰富的图像(这里的阈值是控制 x \mathbf{x} x 的范围);
  3. 改良了 UNet,提出 Efficient UNet,使模型更简单、收敛更快、内存消耗更少。

该模型的架构如下图所示,可以看到使用了一个条件生成的 diffusion 模型以及两个超分辨率模型,每个模型都以文本模型的 embedding 作为条件,先生成一个 64 分辨率的图像,然后逐步超分辨率到 1024 大小。

Imagen 模型结构

Imagen

预训练文本模型

现在的文生图模型主流使用的文本嵌入方法是使用 CLIP 文本编码器,在直观上感觉是比较合理的,因为 CLIP 的文本特征和图像特征共享同一个空间,用来控制图像的生成过程是比较合理的。不过 CLIP 的缺点是对文本的表达能力比较有限,处理复杂文本比较困难。

这里选择的不是使用 CLIP,而是使用规模比较大、且在大规模文本语料上训练的文本模型,具体来说使用的模型有 BERT、T5 和 CLIP。经过实验(具体结果可以看原论文 Figure 4 的 a 和 b,以及 Figure A.5),主要有以下发现:

  • 缩放文本编码器对提升生成质量的作用很明显;
  • 相比增大 UNet 的尺寸,增大文本编码器的尺寸更重要;
  • 相比于 CLIP,人类更偏好 T5-XXL 的结果。

高 Guidance Scale 的改善

提高 classifier-free guidance 的 guidance scale 可以提升文本-图像的匹配程度,但是会破坏图像的质量。这个现象是因为高 guidance scale 会导致训练阶段和测试阶段出现 mismatch。具体来说,在训练时,所有的 x \mathbf{x} x 都分布在 [ − 1 , 1 ] [-1,1] [1,1] 的范围里,然而当使用比较大的 guidance scale 时,得到的 x \mathbf{x} x 会超出这个范围。这样会导致 x \mathbf{x} x 落在已经学习过的范围以外,为了解决这个问题,作者研究了静态阈值(static thresholding)和动态阈值(dynamic thresholding)两种方案,具体算法如下图所示:

静态阈值和动态阈值算法

静态阈值

这种方法就是在预测噪声后,先计算出 x 0 \mathbf{x}_0 x0,然后将其取值范围直接裁剪到 [ − 1 , 1 ] [-1,1] [1,1] 之间,然后再进行去噪。这种方法已经很多方法都使用了,例如 openai/guided-diffusion 中的这段代码就是为了进行这种处理:

def process_xstart(x):if denoised_fn is not None:x = denoised_fn(x)if clip_denoised:return x.clamp(-1, 1) # 裁剪到 [-1,1]return xif self.model_mean_type == ModelMeanType.EPSILON:pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) # 得到 x_0)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t
)

动态阈值

这个方法不是很好理解,我们可以从一个例子出发,我们平时进行 classifier-free guidance 时使用的 guidance scale 通常都是 7.5,那么一个原本分布在 [ − 1 , 1 ] [-1,1] [1,1] 之间的变量乘以这个系数之后就会变到 [ − 7.5 , 7.5 ] [-7.5,7.5] [7.5,7.5] 的范围内。如果某处的几个数分别是 { 0.2 , 0.4 , 0.6 , 0.8 } \{0.2, 0.4, 0.6, 0.8\} {0.2,0.4,0.6,0.8},乘以 7.5 后就变成了 { 1.5 , 3.0 , 4.5 , 6.0 } \{1.5,3.0,4.5,6.0\} {1.5,3.0,4.5,6.0}。如果此时直接将这些数裁剪到 [ − 1 , 1 ] [-1,1] [1,1],那么所有的数都会变成 1,原本这些数之间是有比较大的差别的,裁剪后都变成了相同的数,这样很明显是不合理的,动态阈值就是为了寻找一个比较合理的裁剪范围。

这里的做法是寻找一个 x 0 \mathbf{x}_0 x0 的 p-分位数 s s s,也就是找到大多数的数字落在什么范围内,然后先裁剪到 [ − s , s ] [-s,s] [s,s] 范围内,再全部除以 s s s 以缩放到 [ − 1 , 1 ] [-1,1] [1,1] 的范围内。实验发现这种方法能比较好地改善图像的质量,这部分的代码如下所示(摘自非官方实现):

if pred_objective == 'noise':x_start = noise_scheduler.predict_start_from_noise(x, t=t, noise=pred)
elif pred_objective == 'x_start':x_start = pred
elif pred_objective == 'v':x_start = noise_scheduler.predict_start_from_v(x, t=t, v=pred)if dynamic_threshold: # 动态阈值# 找到 p-分位数s = torch.quantile(rearrange(x_start, 'b ... -> b (...)').abs(),self.dynamic_thresholding_percentile,dim = -1)s.clamp_(min=1.)s = right_pad_dims_to(x_start, s)# 进行归一化x_start = x_start.clamp(-s, s) / s
else: # 静态阈值,直接截断x_start.clamp_(-1., 1.)
mean_and_variance = noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t, t_next=t_next)

级联扩散模型

为了生成高分辨率图像,模型级联了三个扩散模型,一个用来生成低分辨率图像,两个用来将低分辨率图像逐步超分到高分辨率。在训练阶段,作者发现使用带有噪声条件增强的超分模型可以生成更高质量的模型。具体来说,每次生成噪声时,还从 [ 0 , 1 ] [0,1] [0,1] 范围内随机采样一个 aug level,然后基于这个 level 进行增强。在预测噪声时,不仅输入带噪声的图像、低分辨率图像、时间步,还输入一个 aug level。在推理阶段,使用一系列 aug level 进行增强,然后分别进行推理,从中选取一个最佳样本,这样可以提升采样效果。具体的算法如下所示:

超分模型的训练和采样过程

总结

除了上述的一些贡献,Imagen 还做了一些工程上的改进,例如使用了不同的 text condition 注入方式,以及对基础的 UNet 模型进行了改进,提出了 Efficient UNet 模型等。相比同期的其他方法,Imagen 应该是为数不多可以直接生成 1024 分辨率图像的 diffusion 模型,虽然和主流的 Stable Diffusion 架构不同,但其中的一些改进思路还是值得学习一下的。

本文原文以 CC BY-NC-SA 4.0 许可协议发布于 笔记|扩散模型(九):Imagen 理论与实现,转载请注明出处。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/467678.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

css:基础

前言 我们之前其实也可以写出一个看起来算是一个网页的网页,为什么我们还要学css? CSS(Cascading Style Sheets)也叫层叠样式表,是负责美化的,我们之前说html就是一个骨架,css就可以用来美化网…

[全网最细数据结构完整版]第七篇:3分钟带你吃透队列

目录 1->队列的概念及结构 2->队列的实现 2.1定义队列基本结构 struct QueueNode 和 struct Queue 2.2队列初始化函数 QueueInit 函数 2.3队列销毁函数 QueueDestroy 函数 2.4队列插入数据函数 QueuePush 函数 2.5判断队列是否为空,空返回true,非空返回false 2.6队列删…

Android笔记(三十五):用责任链模式封装一个App首页Dialog管理工具

背景 项目需要在首页弹一系列弹窗,每个弹窗是否弹出都有自己的策略,以及哪个优先弹出,哪个在上一个关闭后再弹出,为了更好管理,于是封装了一个Dialog管理工具 效果 整体采用责任链模块设计,控制优先级及弹…

掌握软件组件/单元测试中的这些术语,你就算正式入门了

上篇干货,和大家分享了软件测试的几个级别,在【组件/单元测试】当中,涉及不少名词术语。从之前的学员学习过程来看,这里比较容易出现概念混乱,进而导致面试过程中频频翻车,所以有必要在这里单独拎出来和大家…

html的week控件 获取周(星期)的第一天(周一)和最后一天(周日)

html的week控件 获取周(星期)的第一天(周一)和最后一天(周日) <input type"week" id"week" class"my-css" value"ViewBag.DefaultWeek" /><script> function PageList() { var dateStrin…

【主机游戏】艾尔登法环游戏攻略

艾尔登法环&#xff0c;作为一款备受好评但优化问题频发的游戏&#xff0c;就连马斯克都夸过 今天介绍一下这款游戏 https://pan.quark.cn/s/24760186ac0b 角色升级 在《艾尔登法环》中&#xff0c;角色升级需要找到梅琳娜。你可以在关卡前废墟的营地附近&#xff0c;风暴关…

CSS 中三角形的绘制方法详解

在网页设计领域&#xff0c;特殊形状常常能为页面增添独特的视觉效果&#xff0c;三角形便是其中之一。本文将详细介绍如何利用 CSS 绘制三角形。 一、原理阐述 CSS 中一个元素的边框分为上边框、右边框、下边框和左边框。当把一个元素的宽度和高度设为 0&#xff0c;且只让其…

虚拟机linux7.9下安装mysql

1.MySQL官网下载安装包&#xff1a; MySQL :: Download MySQL Community Server https://cdn.mysql.com/archives/mysql-5.7/mysql-5.7.39-linux-glibc2.12-x86_64.tar.gz 2.解压文件&#xff1a; #tar xvzf mysql-5.7.39-linux-glibc2.12-x86_64.tar.gz 3.移动文件&#…

负载均衡式在线oj项目开发文档(个人项目)

项目目标 需要使用的技术栈&#xff1a; 这个项目共分成三个模块第一个模块为公共的模块&#xff0c;用于解决字符串处理&#xff0c;文件操作&#xff0c;网络连接等等的问题。 第二个模块是一个编译运行的模块&#xff0c;这个模块的主要功能就是将用户的代码收集上来之后要…

MySQL数据库专栏(五)连接MySQL数据库C API篇

摘要 本篇文章主要介绍通过C语言API接口链接MySQL数据库&#xff0c;各接口功能及使用方式&#xff0c;辅助类的封装及调用实例&#xff0c;可以直接移植到项目里面使用。 目录 1、环境配置 1.1、添加头文件 1.2、添加库目录 2、接口介绍 2.1、MySql初始化及数据清理 2.1.…

PH热榜 | 2024-11-08

DevNow 是一个精简的开源技术博客项目模版&#xff0c;支持 Vercel 一键部署&#xff0c;支持评论、搜索等功能&#xff0c;欢迎大家体验。 在线预览 1. Quorini 标语&#xff1a;几分钟内设计并运行无服务器云 API 介绍&#xff1a;Quorini 提供了一套可视化的工具&#xff…

QML:Menu详细使用方法

目录 一.性质 二.作用 三.方法 四.使用 1.改变标签 2.打开本地文件 3.退出程序 4.打开Dialog 五.效果 六.代码 在 QML 中&#xff0c;Menu 是一个用于创建下拉菜单或上下文菜单的控件。它通常由多个 MenuItem 组成&#xff0c;每个 MenuItem 可以包含文本、图标和快捷…

k8s 处理namespace删除一直处于Terminating —— 筑梦之路

问题现象 k8s集群要清理某个名空间&#xff0c;把该名空间下的资源全部删除后&#xff0c;删除名空间&#xff0c;一直处于Terminating状态&#xff0c;无法完全清理掉。 如何处理 为什么要记录下这个处理的步骤&#xff0c;经过查询资料&#xff0c;网上也有各种各样的方法&…

>>,<<,~,,|,∧

‌监视器中的数值在十六进制显示时没有负数&#xff0c;主要是因为十六进制本身不直接表示负数&#xff0c;而是通过补码的形式来表示。

【韩老师零基础30天学会Java 】03章 变量

第三章 变量 1. 变量介绍 为什么需要变量&#xff1f; 变量是程序的基本组成单位 变量有三个基本单位&#xff1a;类型名称值 //1.定义变量int age 20;double score88.6;char gender男;String namejack;变量使用注意事项 变量表示内存中的一个存储区域[不同的变量,类型不同&am…

扭蛋机小程序开发,潮玩扭蛋机市场下新机遇

随着大众对潮玩文化的需求不断增长&#xff0c;市场进行了创新升级&#xff0c;不再局限于传统的销售营销模式&#xff0c;进一步推动行业的发展。目前&#xff0c;扭蛋机的种类越来越丰富&#xff0c;从手办、玩具到各种IP周边等&#xff0c;为市场带来更多新颖的扭蛋商品。销…

Unity 实现数字垂直滚动效果

Unity 实现数字垂直滚动效果 前言项目场景布置Shader代码编写材质球设置代码编写数字图片 前言 遇到一个需要数字垂直滚动模拟老虎机的效果&#xff0c;记录一下。 项目 场景布置 3个Image换上带有RollNumberShader的材质 在RollNumberScript脚本中引用即可 Shader代码编…

记录解决vscode 登录leetcode中遇到的问题

1. 安装完 leetcode 点击sign in to leetcode 点击打开网站登录leetcode&#xff0c;发现网页无法打开。 解决办法&#xff1a;将leetcode.cn.js文件中的leetcode-cn.com路径都改成leetcode.cn 2. 继续点击 sign in to leetcode &#xff0c;选择使用账号登录&#xff0c;始…

设计模式之适配器模式(从多个MQ消息体中,抽取指定字段值场景)

前言 工作到3年左右很大一部分程序员都想提升自己的技术栈&#xff0c;开始尝试去阅读一些源码&#xff0c;例如Spring、Mybaits、Dubbo等&#xff0c;但读着读着发现越来越难懂&#xff0c;一会从这过来一会跑到那去。甚至怀疑自己技术太差&#xff0c;慢慢也就不愿意再触碰这…

万字长文解读深度学习——循环神经网络RNN、LSTM、GRU、Bi-RNN

推荐阅读&#xff1a; 深度学习知识点全面总结 如何从RNN起步&#xff0c;一步一步通俗理解LSTM 深度学习之RNN(循环神经网络) 循环神经网络&#xff08;RNN与LSTM&#xff09; 文章目录 &#x1f33a;深度学习面试八股汇总&#x1f33a;文本特征提取的方法1. 基础方法1.1 词袋…