扩散模型 DDPM 核心代码梳理

参考内容:

大白话AI | 图像生成模型DDPM | 扩散模型 | 生成模型 | 概率扩散去噪生成模型
AIGC 基础,从VAE到DDPM 原理、代码详解
全网最简单的扩散模型DDPM教程
The Annotated Diffusion Model
LaTeX公式编辑器

备注: 具体公式的推导请查看参考链接,本文只记录核心步骤的几个核心公式。

什么是扩散模型?

与Normalizing Flows、GAN或VAEs等生成模型一样,它们都将噪声从一些简单分布转换为数据样本。这也是使用神经网络学习从纯噪声开始逐渐去噪进行内容生成的过程。扩散模型主要包括以下两个过程:

  • 前向加噪: 前向加噪过程是一个固定的、预定义的过程,通过逐步的往一张真实图像上添加高斯噪声,最终得到一个完全的高斯噪声图像
  • 反向去噪: 反向去噪过程通过训练学习一个神经网络模型,模型的输入是一张带有噪声的图像,模型的输出是预测得到的噪声,逐步减去预测的噪声,最终得到一张真实的图像
    在这里插入图片描述

加噪、去噪、训练、推理阶段相关的数学公式

  • 前向加噪

在前向加噪过程中,逐步的往真实图片上添加高斯噪声,每一步添加高斯噪声的公式表示如下:
x t = 1 − β t x t − 1 + β t ϵ t \begin{equation}x_{t} = \sqrt{1-\beta_{t}}x_{t-1} + \sqrt{\beta_{t}}\epsilon_{t}\end{equation} xt=1βt xt1+βt ϵt
其中, 0 < β 1 < β 2 < ⋯ < β T < 1 0 < \beta_{1} < \beta_{2} < \dots < \beta_{T} < 1 0<β1<β2<<βT<1 ϵ ∼ N ( 0 , 1 ) \epsilon \sim N(0,1) ϵN(0,1) β t \beta_{t} βt的取值可以想神经网络的学习率衰减那样,使用线性的、余弦变化的。由于正态分布的均值和方差具有可加性,从[1, T]时刻逐步添加噪声的过程可以通过一步得到:
x t = α t ˉ x 0 + 1 − α t ˉ ϵ \begin{equation}x_{t} = \sqrt{\bar{\alpha_{t}}}x_{0} + \sqrt{1 - \bar{\alpha_{t}}}\epsilon\end{equation} xt=αtˉ x0+1αtˉ ϵ
其中, α t = 1 − β t \alpha_{t} = 1 - \beta_{t} αt=1βt α t ˉ = α t α t − 1 … α 1 \bar{\alpha_{t}} = \alpha_{t}\alpha_{t-1}\dots\alpha_{1} αtˉ=αtαt1α1

  • 模型训练

在模型训练阶段,对于一个真实的图像数据,随机生成[1, T]之前的整数,表示往真实图片数据中添加噪声的次数,然后将添加噪声后的图片输入到神经网络模型中,预测添加的噪声,基于神经网络预测的噪声和真实添加的噪声,计算损失:
L o s s = ∣ ∣ ϵ − ϵ θ ( α t ˉ x 0 + 1 − α t ˉ ∗ ϵ , t ) ∣ ∣ 2 \begin{equation}Loss = ||\epsilon -\epsilon_{\theta}(\sqrt{\bar{\alpha_{t}}}x_{0} + \sqrt{1 - \bar{\alpha_{t}}}*\epsilon, t)||^{2}\end{equation} Loss=∣∣ϵϵθ(αtˉ x0+1αtˉ ϵ,t)2
其中, ϵ \epsilon ϵ表示在前向加噪过程中,使用公式(2)往真实图片中添加的随机噪声, ϵ θ \epsilon_{\theta} ϵθ表示一个神经网络模型,输入一个带有噪声的图像,以及对应添加噪声的时间步数,输出预测的噪声, x 0 x_{0} x0表示原始的真实图像, t t t表示时间步数。
在这里插入图片描述

  • 反向去噪

在反向去噪过程中,使用神经网络预测输出一个和输入图像一样大小的噪声数据,从输入图像中减去噪声数据,实现去噪。
x t − 1 = 1 α t ( x t − β t β t ˉ ∗ ϵ θ ( x t , t ) ) + δ t ∗ z \begin{equation}x_{t-1} = \frac{1}{\sqrt{\alpha_{t}}}(x_{t} - \frac{\beta_{t}}{\sqrt{\bar{\beta_{t}}}}*\epsilon _{\theta }(x_{t},t)) + \delta_{t}*z\end{equation} xt1=αt 1(xtβtˉ βtϵθ(xt,t))+δtz
其中, ϵ θ \epsilon _{\theta} ϵθ是一个神经网络模型, ϵ θ ( x t , t ) \epsilon _{\theta }(x_{t},t) ϵθ(xt,t)是神经网络模型预测输出的噪声, β t ˉ = 1 − α t ˉ \bar{\beta_{t}} = 1 - \bar{\alpha_{t}} βtˉ=1αtˉ

  • 模型推理

在模型推理阶段,也就是模型训练完之后进行图像的生成阶段,设置好迭代生成的时间步数 t t t,通过一个随机噪声 x t x_{t} xt,不断执行下面的步骤,直到公式(5)中的 t = 1 t = 1 t=1,实现图像的生成:
x t − 1 = 1 α t ( x t − β t β t ˉ ∗ ϵ θ ( x t , t ) ) + δ t ∗ z \begin{equation}x_{t-1} = \frac{1}{\sqrt{\alpha_{t}}}(x_{t} - \frac{\beta_{t}}{\sqrt{\bar{\beta_{t}}}}*\epsilon _{\theta }(x_{t},t)) + \delta_{t}*z\end{equation} xt1=αt 1(xtβtˉ βtϵθ(xt,t))+δtz
x t = x t − 1 \begin{equation}x_{t} = x_{t-1}\end{equation} xt=xt1
t = t − 1 \begin{equation}t = t-1\end{equation} t=t1

当公式(5)中的 t = 1 t = 1 t=1时,也就是最后一轮去噪,不加 δ t ∗ z \delta_{t}*z δtz,最后得到的 x 0 x_{0} x0就是生成的图像内容。
在这里插入图片描述

UNet网络结构

UNet神经网络在特定的时间步 t t t 接收噪声图像并返回预测的噪声。预测的噪声是一个与输入图像具有相同的大小/分辨率的张量。从技术上讲,网络输入和输出相同形状的张量。在DDPM中采用UNet架构的神经网络,UNet网络中主要包括以下部分:
在这里插入图片描述

  • 下采样:使用卷积 + 池化的方式实现图像分辨率的下采样
  • 上采样:使用转置卷积或者线性插值的方式,提升特征图的分辨率
  • Short-cut连接:将下采样和上采样得到的分辨率相同额特征图在通道维度上进行融合,有利于捕捉细粒度的图像特征
  • 注意力机制:使用注意力机制计算特征图上每个位置之间的注意力关系
  • time-embedding:由于DDPM是逐步生成图像的,所以需要一个特征能够标记当前执行到哪个时间步了

DDPM核心代码解释

  1. 基础代码:构造 α , β , α ˉ , β ˉ \alpha,\beta,\bar{\alpha},\bar{\beta} α,β,αˉ,βˉ等参数
  • 使用不同的策略构建 β \beta β 序列
def linear_beta_schedule(timesteps):"""在0.0001到0.02之间,均匀采样timesteps个数值,构造成beta序列"""beta_start = 0.0001beta_end = 0.02return torch.linspace(beta_start, beta_end, timesteps)def cosine_beta_schedule(timesteps, s=0.008):"""cosine schedule as proposed in https://arxiv.org/abs/2102.09672"""steps = timesteps + 1x = torch.linspace(0, timesteps, steps)alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2alphas_cumprod = alphas_cumprod / alphas_cumprod[0]betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])return torch.clip(betas, 0.0001, 0.9999)def quadratic_beta_schedule(timesteps):beta_start = 0.0001beta_end = 0.02return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2def sigmoid_beta_schedule(timesteps):beta_start = 0.0001beta_end = 0.02betas = torch.linspace(-6, 6, timesteps)return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
  • 根据生成的 β \beta β 序列,生成 α , α ˉ , β ˉ \alpha,\bar{\alpha},\bar{\beta} α,αˉ,βˉ等, α , β , α ˉ , β ˉ \alpha,\beta,\bar{\alpha},\bar{\beta} α,β,αˉ,βˉ等参数的序列长度等于最大的迭代步长timesteps
timesteps = 300# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)# define alphas 
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
  • 备注
    • betas对应 β \beta β
    • alphas对应 α = 1 − β \alpha = 1 - \beta α=1β
    • alphas_cumprod对应 α ˉ \bar{\alpha} αˉ
    • sqrt_recip_alphas对应 1 α \frac{1}{\sqrt{\alpha}} α 1
    • sqrt_alphas_cumprod对应 1 α ˉ \frac{1}{\sqrt{\bar{\alpha}}} αˉ 1
    • sqrt_one_minus_alphas_cumprod对应 1 − α ˉ \sqrt{1 - \bar{\alpha}} 1αˉ
  • 在训练阶段对于batch中的每个样本,加噪的迭代次数是从[0, T]中进行随机采样的,所以训练阶段每个样本的加噪次数 t ∈ [ 0 , T ] t \in [0, T] t[0,T] 是不同的,使用gather函数获取到每个样本的t对应的 α , β , α ˉ , β ˉ \alpha,\beta,\bar{\alpha},\bar{\beta} α,β,αˉ,βˉ等参数,对应的代码如下:
def extract(a, t, x_shape):batch_size = t.shape[0]out = a.gather(-1, t.cpu())return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
  1. 前向加噪:根据上一步计算得到的 α , β , α ˉ , β ˉ \alpha,\beta,\bar{\alpha},\bar{\beta} α,β,αˉ,βˉ等参数,将一张真实图像 x 0 x_{0} x0 使用公式(2)进行多次加噪,得到加噪后的图像,对应代码如下:
def q_sample(x_start, t, noise=None):if noise is None:noise = torch.randn_like(x_start)# x_start就是前面讲的最原始图像 x_0,根据 t 获取到对应的alpha,beta等参数sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)# 使用公式(2)对图像进行前向加噪return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
  1. UNet模型:将加噪后的样本以及每个样本对应的加噪次数 t 输入到UNet网络模型中,UNet模型预测输出加入的噪声,将UNet的输出结果与加入到图像中的噪声使用公式(3)计算损失,训练UNet网络模型。
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):if noise is None:noise = torch.randn_like(x_start)# x_start就是前面讲的最原始图像 x_0,这一步就是往 x_0 中加入t次的噪声x_noisy = q_sample(x_start=x_start, t=t, noise=noise)# 将加入噪声的图像以及对应的时间步数 t 输入到UNet模型predicted_noise = denoise_model(x_noisy, t)# 将UNet预测的结果与加入的噪声计算损失if loss_type == 'l1':loss = F.l1_loss(noise, predicted_noise)elif loss_type == 'l2':loss = F.mse_loss(noise, predicted_noise)elif loss_type == "huber":loss = F.smooth_l1_loss(noise, predicted_noise)else:raise NotImplementedError()return loss
  1. 模型推理:当训练完UNet之后,在模型推理也就是图像生成阶段执行反向去噪过程。首先生成一张纯噪声的图像,初始时间步设置为timesteps,将噪声图像和时间步数值 t 输入到UNet模型中,预测得到输出结果,然后使用公式(4)计算得到经过去噪之后 t-1时间步的输出,如此迭代,直到 t=0为止。
def p_sample(model, x, t, t_index):betas_t = extract(betas, t, x.shape)sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)# Equation 11 in the paper# Use our model (noise predictor) to predict the meanmodel_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)if t_index == 0:return model_meanelse:posterior_variance_t = extract(posterior_variance, t, x.shape)noise = torch.randn_like(x)# Algorithm 2 line 4:return model_mean + torch.sqrt(posterior_variance_t) * noise # Algorithm 2 (including returning all images)def p_sample_loop(model, shape):device = next(model.parameters()).deviceb = shape[0]# start from pure noise (for each example in the batch)img = torch.randn(shape, device=device)imgs = []for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)imgs.append(img.cpu().numpy())return imgsdef sample(model, image_size, batch_size=16, channels=3):return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))

注意事项:

  • torch.randn生成符合标准正态分布的数据,torch.rand生成符合0-1之间均匀分布的数据
  • UNet有利于细粒度的图像生成

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/124067.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【聚类】K-Means聚类

cluster&#xff1a;簇 原理&#xff1a; 这边暂时没有时间具体介绍kmeans聚类的原理。简单来说&#xff0c;就是首先初始化k个簇心&#xff1b;然后计算所有点到簇心的欧式距离&#xff0c;对一个点来说&#xff0c;距离最短就属于那个簇&#xff1b;然后更新不同簇的簇心&a…

OpenCV(二十八):连通域分割

目录 1.介绍连通域分割 2.像素领域介绍 3.两遍法分割连通域 4.连通域分割函数 1.介绍连通域分割 连通域分割是一种图像处理技术&#xff0c;用于将图像中的相邻像素组成的区域划分为不同的连通域。这些像素具有相似的特性&#xff0c;如相近的灰度值或颜色。连通域分割可以…

C高级第2天

写一个1.sh脚本&#xff0c;将以下内容放到脚本中&#xff1a; 在家目录下创建目录文件&#xff0c;dir 在dir下创建dir1和dir2 把当前目录下的所有文件拷贝到dir1中&#xff0c; 把当前目录下的所有脚本文件拷贝到dir2中 把dir2打包并压缩为dir2.tar.xz 再把dir2.tar.xz…

Android 12 源码分析 —— 应用层 四(SystemUI的基本布局设计及其基本概念)

Android 12 源码分析 —— 应用层 四&#xff08;SystemUI的基本布局设计及其基本概念&#xff09; 在上两篇文章中&#xff0c;我们介绍SystemUI的启动过程&#xff0c;以及基本的组件依赖关系。基本的依赖关系请读者一定要掌握&#xff0c;因为后面的文章&#xff0c;将会时…

2023年9月惠州/深圳CPDA数据分析师认证找弘博创新

CPDA数据分析师认证是大数据方面的认证&#xff0c;助力数据分析人员打下扎实的数据分析基础知识功底&#xff0c;为入门数据分析保驾护航。 帮助数据分析人员掌握系统化的数据分析思维和方法论&#xff0c;提升工作效率和决策能力&#xff0c;遇到问题能够举一反三&#xff0c…

四川玖璨电子商务有限公司:抖店怎么运营爆款

如今&#xff0c;随着网络的普及和电商平台的兴起&#xff0c;越来越多的人开始关注和尝试开设自己的网店。然而&#xff0c;在面对激烈的市场竞争中&#xff0c;如何让自己的抖店脱颖而出&#xff0c;成为爆款产品的运营者&#xff0c;是每个抖店经营者迫切需要解决的问题。 …

elementUi中的el-table表格的内容根据后端返回的数据用不同的颜色展示

效果图如下&#xff1a; 首先 首先&#xff1a;需要在表格行加入 <template slot-scope"{ row }"> </template>标签 <el-table-column prop"usable" align"center" label"状态" width"180" ><templ…

【业务功能篇91】微服务-springcloud-多线程-线程池执行顺序

一、线程的实现方式 1. 线程的实现方式 1.1 继承Thread class ThreadDemo01 extends Thread{Overridepublic void run() {System.out.println("当前线程:" Thread.currentThread().getName());} }1.2 实现Runnable接口 class ThreadDemo02 implements Runnable{…

20个经典巧妙电路合集

1、防反接保护&#xff08;二极管&#xff09; 在实际电子设计中&#xff0c;防反接保护电路非常重要&#xff0c;不要觉得自己肯定不会接错&#xff0c;实际上无论多么小心&#xff0c;还是会犯错误...... 最简单的就是利用二极管了&#xff0c;利用二极管的单向导电性&#…

米贸搜什么是网站排名流量

当谈到数字营销时&#xff0c;你的网站应该作为线上营销的中心枢纽。包括&#xff1a;Ads付费广告、EDM邮件营销、SEO搜索引擎优化等都旨在吸引用户访问你的网站&#xff0c;并在网站上进行深度转化。 被广泛应用且最有效的营销策略之一就是SEO&#xff0c;流量排名是衡量网站受…

管理类联考——数学——汇总篇——知识点突破——数据分析——计数原理——减法原理除法原理

减法原理 正面难则反着做(“ − - −”号) 【思路】当出现“至少、至多”、“否定用语"等正面较难分类的题目&#xff0c;可以采用反面进行求解&#xff0c;注意部分反面的技巧以及“且、或"的反面用法。 除法原理 看到相同&#xff0c;定序用除法消序( “ &quo…

数据结构--5.0.1图的存储结构

目录 一、邻接矩阵&#xff08;无向图&#xff09; 二、邻接矩阵&#xff08;有向图&#xff09; 三、邻接矩阵&#xff08;网&#xff09; 四、邻接表&#xff08;无向图&#xff09; 五、邻接表&#xff08;有向图&#xff09; ——图的存储结构相比较线性表与树来说就复…

ChatGPT 总结数据分析的所有知识点

ChatGPT功能非常多,特别是对某个行业,某个方向,某个技术进行总结那是相当专业的。 如下图。 直接用一个指令便总结出来数据分析当中的所有知识点内容。 AIGC ChatGPT ,BI商业智能, 可视化Tableau, PowerBI, FineReport, 数据库Mysql Oracle, Office, Python ,ETL Ex…

UE4/UE5 动画控制

工程下载​ ​​​​​​​​​​​​​https://mbd.pub/o/bread/ZJ2cm5pu 蓝图控制sequence播放/倒播动画&#xff1a; 设置开启鼠标指针&#xff0c;开启鼠标事件 在场景中进行过场动画制作 设置控制事件

Excel VSTO开发5 -Excel对象结构

版权声明&#xff1a;本文为博主原创文章&#xff0c;转载请在显著位置标明本文出处以及作者网名&#xff0c;未经作者允许不得用于商业目的。 5 Excel对象结构 Excel提供了几个比较重要的对象&#xff1a; Application、Workbooks、Workbook、Worksheets、Worksheet 为了便…

I.MX RT1176笔记(9)-- 程序异常追踪(CmBacktrace 和 segger rtt)

前言 在使用 ARM Cortex-M 系列 MCU时候&#xff0c;有时候会遇到各种异常&#xff08;Hard Fault, Memory Management Fault, Bus Fault, Usage Fault, Debug Fault&#xff09;&#xff0c;这时候我们根据经验查询PC指针&#xff0c;LR寄存器&#xff0c;堆栈数据定位地址然…

2023 年全国大学生数学建模D题目-圈养湖羊的空间利用率

D题目应该是专科题目&#xff1f;&#xff1f;&#xff1f;不确定了 感觉类似一个细胞分裂问题一样&#xff0c;1&#xff0c;2&#xff0c;4&#xff0c;8&#xff0c; 题目1中规中矩 按照前面说的分配方法&#xff0c;一步一步计算进行 缺口的问题考虑反推回去&#xff0c…

Friend.tech热潮未过,在推特刷屏的TipCoin又是个啥?

Web3社交赛道风起云涌&#xff0c;Friend.tech的热潮还没过&#xff0c;最近又有一款名为Tip Coin社交项目在X&#xff08;前Twitter&#xff09;开始刷屏。 TipCoin作为一款社交类区块链项目依托于X平台&#xff0c;用户通过在X平台上发布内容来进行“挖矿”&#xff0c;获得项…

计算机安全学习笔记(IV):基于角色的访问控制 - RBAC

RBAC(Role-Based Access Control)基于用户在系统中设定的角色而不是用户的身份。一般来说&#xff0c;RBAC模型定义角色为组织中的一项工作职责&#xff0c;RBAC系统给角色而不是给单独的用户分配访问权。用户根据他们的职责被静态地或动态地分配给不同的角色。 RBAC模型间的关…

uniapp 路由不要显示#

在Uniapp中&#xff0c;路由默认使用的是hash模式&#xff0c;即在URL中添加#符号。如果你不想在URL中显示#&#xff0c;可以切换为使用history模式。 要在Uniapp中使用history模式&#xff0c;可以按照以下步骤进行操作&#xff1a; 打开manifest.json文件。在"app&qu…