结合代码详细讲解DDPM的训练和采样过程

本篇文章结合代码讲解Denoising Diffusion Probabilistic Models(DDPM),首先我们先不关注推导过程,而是结合代码来看一下训练和推理过程是如何实现的,推导过程会在别的文章中讲解;首先我们来看一下论文中的算法描述。DDPM分为扩散过程和反向扩散过程,也就是训练过程和采样过程;
代码来自https://github.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-

请添加图片描述

1. 训练(扩散)过程

首先我们来逐个看一下训练过程中的所有符号的含义:

x 0 x_0 x0是真实图像;

t 是扩散的步数,取值范围从1到T;

ϵ \epsilon ϵ是从标准正态分布中采样的噪声;

ϵ θ \epsilon_\theta ϵθ是模型,用于预测噪声,其输入是 x t x_t xt和 t;

x t x_t xt的表达式如下:

在这里插入图片描述

x t x_t xt x 0 x_0 x0加噪获得,其中 α t ‾ \overline{\alpha_{t}} αt是常数
因此训练过程总结成一句话就是,向真实图像 x 0 x_0 x0中加噪,获得加噪后的图像 x t x_t xt;然后将 x t x_t xt和t输入到网络中,得到预测的噪声,通过使得网络预测的噪声和真实加入的噪声更接近,完成网络的训练。
从另一个角度,我们也可以这么理解:向 x 0 x_0 x0中加噪的过程,可以理解成是编码的过程,加噪之后获取到了图像的中间表示 x t x_t xt;而预测噪声的过程则是从 x t x_t xt解码的过程,只是并没有选择直接解码出 x 0 x_0 x0,而是解码出加入的噪声,也就是残差。请添加图片描述

下面来看一下代码,跟上面讲解的过程是一一对应的,首先在初始化函数中我们需要准备好每个时刻t所需要的常数量 α t ‾ \sqrt{\overline{\alpha_{t}}} αt 1 − α t ‾ \sqrt{1-\overline{\alpha_{t}}} 1αt 。这些参数最原始来源于一个超参数 β t \beta_t βt,这个参数为加入噪声的方差。他们的关系如下:

[图片]

所以很容易理解代码中的sqrt_alphas_bar就是 α t ‾ \sqrt{\overline{\alpha_{t}}} αt ,sqrt_one_minus_alphas_bar 就是 1 − α t ‾ \sqrt{1-\overline{\alpha_{t}}} 1αt
接着在forward函数中,首先从[0,T]中随机选取一个时刻t,然后从标准正态分布中采样一个噪声,shape和 x 0 x_0 x0一致,接着获取 x t x_t xt

x_t = (
extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)

然后将然后将 x t x_t xt和t输入到网络中,得到预测的噪声:

self.model(x_t, t)

计算Loss函数:

loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')

训练过程的完整代码:

class GaussianDiffusionTrainer(nn.Module):def __init__(self, model, beta_1, beta_T, T):super().__init__()self.model = modelself.T = Tself.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())alphas = 1. - self.betasalphas_bar = torch.cumprod(alphas, dim=0)# calculations for diffusion q(x_t | x_{t-1}) and othersself.register_buffer('sqrt_alphas_bar', torch.sqrt(alphas_bar))self.register_buffer('sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))# 每次forward时,给每个样本随机取一个t,并采样一个高斯噪声,然后根据t从sqrt_alphas_bar和sqrt_one_minus_alphas_bar中取出对应的系数,然后根据x_0和采样的高斯噪声生成x_t。然后将x_t和t输入到噪声预测网络中,得到预测的噪声。预测出的噪声输入到网络中,计算loss,从而实现model的训练。def forward(self, x_0):"""Algorithm 1."""t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device) # 给batch中每个样本取一个t,取值范围是[0, 1000]noise = torch.randn_like(x_0) # 采样高斯噪声,shape与x_0一致x_t = (extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')return loss

2. 推理(反向)过程

首先我们来明确一下,反向过程的目标是什么。反向过程的目标是逐步从一张噪声图像 x T x_T xT中恢复出一张图像,表示成 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt),我们没法推导出 p ( x t − 1 ∣ x t ) p(x_{t-1}|x_t) p(xt1xt),但是 p ( x t − 1 ∣ x t , x 0 ) p(x_{t-1}|x_t, x_0) p(xt1xt,x0)是可以用贝叶斯公式推导出来的,其也是一个高斯分布,并且可以把 x 0 x_0 x0化简掉。最终 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt)分布的均值为:
请添加图片描述

方差为 β t \beta_t βt
因此我们可以从 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt)分布中采样出一个 x t − 1 x_{t-1} xt1
请添加图片描述
这种采样方式叫做重参数技巧,如果不了解可以看如下介绍:
在这里插入图片描述
注意:是标准差与标准正态分布相乘,而不是方差;

因为DDPM的方差固定为 β t \beta_t βt,所以反向过程的重点就是学习出这个分布的方差,从上面的表达式可以看出分布的均值与 x t x_t xt和当前时刻加入的噪声 ϵ t \epsilon_t ϵt有关,而我们的模型可以完成对 ϵ t \epsilon_t ϵt的预测,只要将 x t x_t xt和 t 输入进去模型中即可。代码中描述的过程与此一一对应。

注意代码中存在三个噪声,其中eps是模型预测出来的,其和分布的均值计算相关;forward函数中的noise也是噪声,但是它是从标准正态分布中采样的,用于从 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt)采样;forward函数中的 x T x_T xT是整个反向过程的输入,也是从标准正态分布中采样的。

# 反向过程是从纯噪声x_T开始逐步去噪以生成样本,此过程也是一个高斯分布,均值和x_t以及预测出的噪声相关,方差在ddpm中没有进行学习,直接使用的是后验分布q(x_t-1|x_t,x_0)的方差。
class GaussianDiffusionSampler(nn.Module):def __init__(self, model, beta_1, beta_T, T):super().__init__()self.model = modelself.T = Tself.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())alphas = 1. - self.betasalphas_bar = torch.cumprod(alphas, dim=0)alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]self.register_buffer('coeff1', torch.sqrt(1. / alphas))self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))def predict_xt_prev_mean_from_eps(self, x_t, t, eps):assert x_t.shape == eps.shapereturn (extract(self.coeff1, t, x_t.shape) * x_t -extract(self.coeff2, t, x_t.shape) * eps)def p_mean_variance(self, x_t, t):# below: only log_variance is used in the KL computationsvar = torch.cat([self.posterior_var[1:2], self.betas[1:]])var = extract(var, t, x_t.shape)eps = self.model(x_t, t)xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)return xt_prev_mean, vardef forward(self, x_T):"""Algorithm 2."""x_t = x_T # 输入是一个标准正态分布噪声# 从T到1进行reverse过程for time_step in reversed(range(self.T)):print(time_step)t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_stepmean, var= self.p_mean_variance(x_t=x_t, t=t) # no noise when t == 0if time_step > 0:noise = torch.randn_like(x_t)else:noise = 0x_t = mean + torch.sqrt(var) * noise # 从q(x_t-1|x_t)中采样assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."x_0 = x_treturn torch.clip(x_0, -1, 1)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/412654.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

<C++> AVLTree

目录 1. AVL概念 2. AVL树节点的定义 3. AVL树的插入 4. AVL树的旋转 5. AVL树的验证 6. AVL树的删除 7. AVL树的性能 暴力搜索、二分搜索、二叉搜索树、二叉平衡搜索树(AVL、红黑树)、多叉平衡搜索树(B树)、哈希表 1. AVL概念 二…

【C++ Primer Plus习题】7.2

问题: 解答: #include <iostream> using namespace std;#define MAX 10int input(float* grade, int len) {int i 0;for (i 0; i < len; i){cout << "请输入第" << i 1 << "个高尔夫成绩(按0结束):";cin >> grade[i]…

更改了ip地址怎么改回来

在日常的网络使用中&#xff0c;‌我们有时会因为特定的需求更改设备的IP地址&#xff0c;‌比如解决IP冲突、‌访问特定网络资源或进行网络测试等。‌然而&#xff0c;‌更改IP地址后&#xff0c;‌我们可能又因为某些原因需要将IP地址改回原来的设置。‌本文将详细介绍如何改…

视频号单场直播GMV超500万!开学季助力品牌高效转化

开学在即&#xff0c;友望数据发现&#xff0c;不少学习机、学练机、智能机器人、词典笔等学习相关的电子教育产品开始畅销 ▲ 图片来源&#xff1a;友望数据-商品排行榜 新学年开始&#xff0c;家长们又要为孩子新的学业操碎心&#xff0c;而教育培训商家也在开学季迎来了他们…

PS如何抠人像图--5步实现完美抠图

1、菜单栏--选择--选择主体 2、菜单栏--选择--选择并遮住 3、选择原图--右下角添加纯色背景 4、文件--导出--导出为png图片 5、原图与抠图效果对比 相关参考视频&#xff1a; 【ps教程】揭秘PS抠头发&#xff0c;这才是真正的教学&#xff0c;快收藏吧_哔哩哔哩_bilibili 一分…

挂载5T大容量外接硬盘到ubuntu

挂载5T大容量外接硬盘到ubuntu S1&#xff1a;查看硬盘 使用 $ sudo fdisk -l找到对应盘&#xff0c;例如下图所示 /dev/sdc S2: 创建分区 使用 $ sudo fdisk /dev/sdc对上硬盘进行创建分区&#xff1b;可以依次使用以下指令 m &#xff1a;查看命令&#xff1b; g &…

从开题到答辩:ChatGPT超全提示词分享!(下)【建议收藏】

数据收集 1. "请帮我找出关于如何收集【研究领域】社交媒体数据进行消费者行为研究的五篇指导性文章&#xff0c;并概述它们的主要方法论摘要。" 2. "我需要对【特定领域】市场的消费者偏好进行调查。能否提供一份包含调查问卷设计原则和示例的草稿&#xff1f;…

cola_os学习笔记(下)

cola_os学习笔记&#xff08;上&#xff09; os文件夹 cola_device.c ​ .h放在.c的同层级。作者采用了字符设备注册的方式&#xff0c;在.h中可以看到设备属性。也就是把LED这些设备抽象&#xff0c;外面传入"LED1"这样的参数&#xff0c;使我联想到java的new一个…

编译错误cc:not found总结

一、错误 cc: not found 系统无法找到名为cc的编译器。 注&#xff1a;在大多数Linux系统中&#xff0c;cc通常是C编译器的链接&#xff08;link&#xff09;或别名&#xff0c;它通常指向gcc&#xff08;GNU Compiler Collection&#xff09;或其他C编译器。 二、可能导致…

「OC」CAlayer——巧用动画实现一个丝滑的折叠cell

「OC」CAlayer——巧用动画实现一个丝滑的折叠cell 前言 在这个暑假集训后的时间&#xff0c;都在家里做着学习笔记的整理&#xff0c;深入学习了CALayer的相关知识&#xff0c;掌握了第三方库Masonry自动布局的用法&#xff0c;以及学习了MVC的相关内容&#xff0c;正好组内…

chapter08-面向对象编程——(Object类详解)——day09

目录 319-运算符 320-查看Jdk源码 321-子类重写equals 322-equals课堂练习1 323-equals重写练习2 324-equals重写练习3 325-hashCode 326-toString 327-finalize 319-运算符 引用的都是同一个地址&#xff0c;所以返回true 320-查看Jdk源码 equals只能判断引用类型是…

艾体宝干货丨Redis与MongoDB的区别

Redis&#xff08;Remote Dictionary Server&#xff0c;远程字典服务器&#xff09;和 MongoDB 是两类知名的 NoSQL数据库&#xff0c;其以非结构化的方式存储数据。与传统关系数据库使用表格、行和列来组织数据不同&#xff0c;NoSQL数据库采用了不同的数据存储模型。Redis是…

探索极速Python:Sanic框架的魔力

文章目录 探索极速Python&#xff1a;Sanic框架的魔力背景&#xff1a;为什么选择Sanic&#xff1f;Sanic是什么&#xff1f;如何安装Sanic&#xff1f;简单的库函数使用方法场景应用示例常见Bug及解决方案总结 探索极速Python&#xff1a;Sanic框架的魔力 背景&#xff1a;为什…

【位置编码】【Positional Encoding】直观理解位置编码!把位置编码想象成秒针!

【位置编码】【Positional Encoding】直观理解位置编码&#xff01;把位置编码想象成秒针&#xff01; 你们有没有好奇过为啥位置编码非得长成这样&#xff1a; P E ( p o s , 2 i ) s i n ( p o s 1000 0 2 i / d m o d e l ) P E ( p o s , 2 i 1 ) c o s ( p o s 1000 …

AcWing895. 最长上升子序列

这个代码不知道怎么说&#xff0c;反正就是对着代码手算一次就懂了&#xff0c;无需多言&#xff0c;就是俩for循环里面的第二层for的循环条件是j<i,j是从下标1往下标i-1遍历的&#xff0c;每次a【j】<a【i】就在答案数组f【i】上面做出更新。基本的输入样例已经可以覆盖…

红黑树刨析(删除部分)

文章目录 红黑树删除节点情景分析情景1&#xff1a;删除节点左右子树都为空情景1.1&#xff1a;删除节点为红色情景1.2&#xff1a;删除节点为黑色情况1.2.1&#xff1a;删除节点的兄弟节点是红色情景1.2.2&#xff1a;删除节点的兄弟节点是黑色情景1.2.2.1&#xff1a;删除节点…

Cpp学习手册-基础学习

首先你要去网上下载对应的运行软件&#xff0c;先把对应的 C 环境配置好&#xff0c;配置好了我们就可以开始我们的C 学习之旅了。希望通过学习我们能够成为一个比较不错的 C 开发工程师。我也会持续更新 C 知识。 1. C语法基础 当我通过 CLion 工具创建了一个新的 Project 。…

Redis中的 大/热 key问题 ,如何解决(面试版)

big key 什么是 big key? big key&#xff1a;就是指一个内存空间占用比较大的键(Key) 造成的问题&#xff1a; 内存分布不均。在集群模式下&#xff0c;不同 slot分配到不同实例中&#xff0c;如果大 key 都映射到一个实例&#xff0c;则分布不均&#xff0c;查询效率也…

自建电商网站整合Refersion教程

前言&#xff1a;   先介绍一下Refersion有啥用&#xff0c;如果你有一个自己的跨境电商独立站点&#xff0c;想找一些网红帮忙推广销售自己的商品&#xff0c;然后按照转化订单比例给网红支付佣金&#xff0c;这件事情对双方来说透明性和实时性很重要&#xff0c;Refersion就…

C++ | Leetcode C++题解之第382题链表随机节点

题目&#xff1a; 题解&#xff1a; class Solution {ListNode *head;public:Solution(ListNode *head) {this->head head;}int getRandom() {int i 1, ans 0;for (auto node head; node; node node->next) {if (rand() % i 0) { // 1/i 的概率选中&#xff08;替…