diffusion model 简单demo

参考自:
Probabilistic Diffusion Model概率扩散模型理论与完整PyTorch代码详细解读
diffusion 简单demo
扩散模型之DDPM

核心公式和逻辑

在这里插入图片描述
在这里插入图片描述

q_x 计算公式,后面会用到:
在这里插入图片描述
推理:
在这里插入图片描述

代码

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve, make_swiss_roll
from PIL import Image
import torch
import io# get data
# s_curve, _ = make_s_curve(10**4 , noise=0.1)
# s_curve = s_curve[:, [0, 2]] / 10.0swiss_roll, _ = make_swiss_roll(10**4,noise=0.1)
s_curve = swiss_roll[:, [0, 2]]/10.0print('shape of moons: ', np.shape(s_curve))data = s_curve.T
fix, ax = plt.subplots()
ax.scatter(*data, color='red', edgecolors='white', alpha=0.5)ax.axis('off')# plt.show()
plt.savefig('./s_curve.png')dataset = torch.Tensor(s_curve).float()# set params
num_steps = 100betas = torch.linspace(-6, 6, num_steps)    # # 逐渐递增
betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5    # β0,β1,...,βtprint('beta: ', betas)alphas = 1 - betas
alphas_pro = torch.cumprod(alphas, 0)   # αt^ = αt的累乘# αt^往右平移一位, 原第t步的值维第t-1步的值, 第0步补1
alphas_pro_p = torch.cat([torch.tensor([1]).float(), alphas_pro[:-1]], 0)   # p表示previous, 即 αt-1^alphas_bar_sqrt = torch.sqrt(alphas_pro)    # αt^ 开根号
one_minus_alphas_bar_log = torch.log(1 - alphas_pro)    # log (1 - αt^)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_pro)  # 根号下(1-αt^)assert alphas.shape == alphas_pro.shape == alphas_pro_p.shape == alphas_bar_sqrt.shape == one_minus_alphas_bar_log.shape == one_minus_alphas_bar_sqrt.shapeprint('beta: shape ', betas.shape)# diffusion processdef q_x(x_0, t):''' get q_x_{\t}作用: 可以基于x[0]得到任意时刻t的x[t]输入: x_0:初始干净图像; t:采样步输出: x_t:第t步时的x_0的样子'''noise = torch.randn_like(x_0) # 正态分布的随机噪声alphas_t = alphas_bar_sqrt[t]alphas_l_m_t = one_minus_alphas_bar_sqrt[t]return (alphas_t * x_0 + alphas_l_m_t * noise)# test add noise
num_shows = 20
fig, axs = plt.subplots(2, 10, figsize=(28, 3))
plt.rc('text', color='blue')# 测试一下加噪下过
## 共有10000个点,每个点包含两个坐标
## 生成100步以内,每个5步加噪后图像for i in range(num_shows):j = i // 10k = i % 10q_i = q_x(dataset, torch.tensor(i * num_steps // num_shows))    # 生成t时刻的采样数据axs[j, k].scatter(q_i[:, 0], q_i[:, 1], color='red', edgecolor='white')axs[j, k].set_axis_off()axs[j, k].set_title('$q(\mathbf{x}_{' + str(i*num_steps // num_shows) + '})$')# plt.show()
plt.savefig('diffusion_process.png')# diffusion reverse process# --------------------- diffusion model -----------------import torch
import torch.nn as nnclass MLPDiffusion(nn.Module):def __init__(self, n_steps, num_units=32):super(MLPDiffusion, self).__init__()self.linears = nn.ModuleList([nn.Linear(2, num_units),nn.ReLU(),nn.Linear(num_units, num_units),nn.ReLU(),nn.Linear(num_units, num_units),nn.ReLU(),nn.Linear(num_units, 2)])self.step_embeddings = nn.ModuleList([nn.Embedding(n_steps, num_units),nn.Embedding(n_steps, num_units),nn.Embedding(n_steps, num_units),])def forward(self, x, t):"""模型的输入是加噪后的图片x和加噪step-> t, 输出是噪声"""for idx, embedding_layer in enumerate(self.step_embeddings):t_embedding = embedding_layer(t)x = self.linears[2 * idx](x)x += t_embeddingx = self.linears[2 * idx + 1](x)x = self.linears[-1](x) # shape: [10000, 2]return x# loss function
def diffusion_loss_fn(model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps, use_cuda=False):"""作用: 对任意时刻t进行采样计算loss参数:model: 模型x_0: 干净的图alphas_bar_sqrt: 根号下αt^one_minus_alphas_bar_sqrt: 根号下(1-αt^)n_steps: 采样步"""batch_size = x_0.shape[0]# 对一个batchsize样本生成随机的时刻t, 覆盖到更多不同的tt = torch.randint(0, n_steps, size=(batch_size//2,))  # 在0~99内生成整数采样步t = torch.cat([t, n_steps-1-t], dim=0)  # 一个batch的采样步, 尽量让生成的t不重复t = t.unsqueeze(-1)  # 扩展维度 -> [batchsize, 1]if use_cuda:t = t.cuda()# x0的系数a = alphas_bar_sqrt[t]  # 根号下αt^# eps的系数aml = one_minus_alphas_bar_sqrt[t]  # 根号下(1-αt^)# 生成随机噪音epse = torch.randn_like(x_0)if use_cuda:e = e.cuda()# 构造模型的输入x = x_0 * a + e * aml  # 前向过程:根号下αt^ * x0 + 根号下(1-αt^) * eps# 送入模型,得到t时刻的随机噪声预测值output = model(x, t.squeeze(-1))  # 模型预测的是噪声, 噪声维度与x0一样大, [10000,2]# 与真实噪声一起计算误差,求平均值return (e - output).square().mean()# --------------- reverse process ---------------
def p_sample_loop(model, shape, n_steps, betas, one_minus_alphas_bar_sqrt, use_cuda=False):"""作用: 从x[T]恢复x[T-1]、x[T-2]、...x[0]输入:model:模型shape:数据大小,用于生成随机噪声n_steps:逆扩散总步长betas: βtone_minus_alphas_bar_sqrt: 根号下(1-αt^)输出:x_seq: 一个序列的x, 即 x[T]、x[T-1]、x[T-2]、...x[0]"""if use_cuda:cur_x = torch.randn(shape).cuda()else:cur_x = torch.randn(shape)  # 随机噪声, 对应xtx_seq = [cur_x]for i in reversed(range(n_steps)):cur_x = p_sample(model, cur_x, i, betas, one_minus_alphas_bar_sqrt, use_cuda=use_cuda)x_seq.append(cur_x)return x_seqdef p_sample(model, x, t, betas, one_minus_alphas_bar_sqrt, use_cuda=False):"""作用: 从x[T]采样t时刻的重构值输入:model:模型x: 采样的随机噪声x[T]t: 采样步betas: βtone_minus_alphas_bar_sqrt: 根号下(1-αt^)输出:sample: 样本"""if use_cuda:t = torch.tensor([t]).cuda()else:t = torch.tensor([t])coeff = betas[t] / one_minus_alphas_bar_sqrt[t]  # 模型输出的系数:βt/根号下(1-αt^) = 1-αt/根号下(1-αt^)eps_theta = model(x, t)  # 模型的输出: εθ(xt, t)# (1/根号下αt) * (xt - (1-αt/根号下(1-αt^))*εθ(xt, t))mean = (1/(1-betas[t]).sqrt())*(x-(coeff*eps_theta))  if use_cuda:z = torch.randn_like(x).cuda()  # 对应公式中的 zelse:z = torch.randn_like(x)  # 对应公式中的 zsigma_t = betas[t].sqrt()  # 对应公式中的 σtsample = mean + sigma_t * zreturn (sample)# ----------- trainning ------------print('Training model...')
if_use_cuda = True
batch_size = 1024
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, prefetch_factor=2)
num_epoch = 4000
plt.rc('text',color='blue')model = MLPDiffusion(num_steps)  # 输出维度是2,输入是x和step
if if_use_cuda:model = model.cuda()optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)iteration = 0
for t in range(num_epoch):for idx, batch_x in enumerate(dataloader):# 损失计算if if_use_cuda:loss = diffusion_loss_fn(model, batch_x.cuda(), alphas_bar_sqrt.cuda(), one_minus_alphas_bar_sqrt.cuda(), num_steps, use_cuda=if_use_cuda)else:loss = diffusion_loss_fn(model, batch_x, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_steps)optimizer.zero_grad()  # 梯度清零loss.backward()  # 损失回传torch.nn.utils.clip_grad_norm_(model.parameters(),1.)  # 梯度裁剪optimizer.step()iteration += 1# if iteration % 100 == 0:if(t % 100 == 0):print(f'epoch: {t} , loss: ', loss.item())if if_use_cuda:x_seq = p_sample_loop(model, dataset.shape, num_steps, betas.cuda(), one_minus_alphas_bar_sqrt.cuda(), use_cuda=True)else:x_seq = p_sample_loop(model, dataset.shape, num_steps, betas, one_minus_alphas_bar_sqrt, if_use_cuda)fig, axs = plt.subplots(1, 10, figsize=(28,3))for i in range(1, 11):cur_x = x_seq[i*10].cpu().detach()axs[i-1].scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white');axs[i-1].set_axis_off();axs[i-1].set_title('$q(\mathbf{x}_{'+str(i*10)+'})$')plt.savefig('./diffusion_train_tmp.png')### ----------------动画演示扩散过程和逆扩散过程-------------------------
# 前向过程
imgs = []
for i in range(100):plt.clf()q_i = q_x(dataset,torch.tensor([i]))plt.scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white',s=5);plt.axis('off');img_buf = io.BytesIO()plt.savefig(img_buf,format='png')img = Image.open(img_buf)imgs.append(img)# 逆向过程
reverse = []
for i in range(100):plt.clf()cur_x = x_seq[i].cpu().detach()plt.scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white',s=5);plt.axis('off')img_buf = io.BytesIO()plt.savefig(img_buf,format='png')img = Image.open(img_buf)reverse.append(img)print('save gif...')
imgs = imgs
imgs[0].save("diffusion_forward.gif", format='GIF', append_images=imgs, save_all=True, duration=100, loop=0)imgs = reverse
imgs[0].save("diffusion_denoise.gif", format='GIF', append_images=imgs, save_all=True, duration=100, loop=0)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/313417.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【devops】 阿里云挂载云盘 | 扩展系统硬盘 | 不重启服务器增加硬盘容量

扩容分区和文件系统(Linux) 文档地址 https://help.aliyun.com/zh/ecs/user-guide/extend-the-partitions-and-file-systems-of-disks-on-a-linux-instance?spm5176.smartservice_service_robot_chat_new.help.dexternal.4ac4f625Ol66kL#50541782adxmp…

C++ UML 类图介绍与设计

1 类图概述 UML(Unified Modeling Language),即统一建模语言,是用来设计软件的可视化建模语言。它的特点是简单、统一、图形化、能表达软件设计中的动态与静态信息。UML从目标系统的不同角度出发,定义了用例图、类图、对象图、状态图、活动图…

高效率改写文章,一键智能改写工具有妙招

如今,写作已经成为人们日常生活中不可或缺的一部分。无论是职场人士撰写工作报告,还是专业的作者创作文章,都离不开对文字的润色和改写。然而,随着工作量与时间压力的增加,如何在保证质量的前提下提高文章改写的效率成…

关于GDAL计算图像坐标的几个问题

关于GDAL计算图像坐标的几个问题_gdal读取菱形四角点坐标-CSDN博客 这篇文章写的很好,讲清楚了图像行列号与图像点坐标(x,y)对应关系,以及图像行列号如何转为地理坐标的,转载一下做个备份。 1.关于GDAL计算图像坐标的…

数据库服务的运行与登录

打开数据库服务 数据库服务: SQL Server(MSSQLServer) 运行在服务器端的应用程序, 提供数据的存储 / 处理和事务等在使用DBMS的客户端之前必须首先打开该服务 客户端连接到服务器 关于客户端 / 服务器端的说明 客户端 : 数据库管理系统(DBMS), 应用程序服务器端 : 安装的数据…

通过本机电脑远程访问路由器loopback的ip

实验拓扑图 本机电脑增加路由信息 正常设置telnet用户,然后通过本地电脑telnet 软件ensp中的设备,尝试是否可以正常访问即可 测试通过本地电脑可以正常访问ensp里面设备的loopback的ip地址了 最重要的一点是本机需要增加一条路由route add ip mask 下…

户外旅行摄影手册,旅游摄影完全攻略

一、资料前言 本套旅游摄影资料,大小295.47M,共有9个文件。 二、资料目录 《川藏线旅游摄影》杨桦.彩印版.pdf 《户外摄影指南》(Essential.Guide.to.Outdoor.photography.amateur)影印版.pdf 《旅行摄影大师班》(英)科尼什.扫描版.PDF 《旅行摄影…

迈向智能工厂:工业互联网时代的生产革命-亿发

随着各行各业数字化转型的加速推进,企业对于数据资产高效流动的需求日益增长,工业互联网网络也在数字经济发展中扮演着愈发重要的角色。什么是智能工厂简单来说就是基于万物互联的技术基础,实现企业内部数据的自由流动,从而促进数…

【高阶数据结构】哈希表 {哈希函数和哈希冲突;哈希冲突的解决方案:开放地址法,拉链法;红黑树结构 VS 哈希结构}

一、哈希表的概念 顺序结构以及平衡树 顺序结构以及平衡树中,元素关键码与其存储位置之间没有对应的关系。因此在查找一个元素时,必须要经过关键码的多次比较。顺序查找时间复杂度为O(N);平衡树中为树的高度,即O(log_2 N)&#xf…

【第1节】书生·浦语大模型全链路开源开放体系

目录 1 简介2 内容(1)书生浦语大模型发展历程(2)体系(3)亮点(4)全链路体系构建a.数据b 预训练c 微调d 评测e.模型部署f.agent 智能体 3 相关论文解读4 ref 1 简介 书生浦语 InternLM…

深入理解GCC/G++在CentOS上的应用

文章目录 深入理解GCC/G在CentOS上的应用编译C和C源文件C语言编译C语言编译 编译过程的详解预处理编译汇编链接 链接动态库和静态库静态库和动态库安装静态库 结论 深入理解GCC/G在CentOS上的应用 在前文的基础上,我们已经了解了CentOS的基本特性和如何在其上安装及…

Python零基础从小白打怪升级中~~~~~~~多线程

线程安全和锁 一、全局解释器锁 首先需要明确的一点是GIL并不是Python的特性,它是在实现Python解析器(CPython)时所引入的一个概念。 GIL全称global interpreter lock,全局解释器锁。 每个线程在执行的时候都需要先获取GIL,保证同一时刻只…

IDEA plugins 好用的插件集

IDEA plugins RestfulToolkit 1. 安装插件 File–>Settings --> plugins --> RestfulToolkit 2.插件有点: 2.1、帮助把项目中的 RestURL 按照项目汇总出来,找到对应URL直接在IDEA上面进行请求测试。 2.2、开发Java Web页面项目,经…

学习笔记------时序约束之时钟周期约束

本文摘自《VIVADO从此开始》高亚军 主时钟周期约束 主时钟,即从FPGA的全局时钟引脚进入的时钟或者由高速收发器输出的时钟。 对于时钟约束,有三个要素描述:时钟源,占空比和时钟周期。 单端时钟输入 这里我们新建一个工程&#x…

【Proteus】51单片机对直流电机的控制

直流电机:输出或输入为直流电能的旋转电机。能实现直流电能和机械能互相转换的电机。把它作电动机运行时是直流电动机,电能转换为机械能;作发电机运行时是直流发电机,机 械能转换为电能。 直流电机的控制: 1、方向控制…

手撕AVL树(map和set底层结构)(1)

troop主页 今日鸡汤:Action may out always bring happiness;but there is no happiness without action. 行动不一定能带来快乐,但不行动一定不行 C之路还很长 手撕AVL树 一 AVL树概念二 模拟实现AVL树2.1 AVL节点的定义 三 插入更新平衡因子&#xff0…

vim相关指令

vim的各种模式及其转换关系图 vim 默认处于命令模式!!! 模式之间转换的指令 除【命令模式】之外,其它模式要切换到【命令模式】,只需要无脑 ESC 即可!!! [ 命令模式 ] 切换至 [ 插…

联合体共用体--第二十三天

1.结构体元素有各自单独的空间 共用体元素共享空间,空间大小由最大类型确定 2.结构体元素互不影响,共用体赋值会导致覆盖

javaWeb智能医疗管理系统

简介 在当今快节奏的生活中,智能医疗系统的崛起为医疗行业带来了一场革命性的变革。基于JavaWeb技术开发的智能医疗管理系统,不仅为医疗机构提供了高效、精准的管理工具,也为患者提供了更便捷、更个性化的医疗服务。本文将介绍一个基于SSM&a…

一些重新开始面试之后的八股文汇总

一、内存中各项名词说明 1、机器内存概念说明 linux中的free命令可以查看机器的内存使用情况,vmstat命令也可以 其中不容易被理解的是: 内存缓冲/存数(buffer/cached) 1.buffers和cache也是RAM划分出来的一部分地址空间 2.buff…