扩散模型基础

扩散模型发展至今日,早已成为各大机器学习顶会的香饽饽。本文简记扩散模型入门相关代码,主要参阅李忻玮、苏步升等人所编著的《扩散模型从原理到实战》

文章目录

    • 1. 简单去噪模型
      • 1.1 简单噪声可视化
      • 1.2 去噪模型
      • 1.3 小结
    • 2 扩散模型
      • 2.1 采样过程
      • 2.2 上科技
        • 2.2.1 升级模型表征模块
        • 2.2.2 升级加噪过程
        • 2.2.3 改变预测目标
      • 2.3 小结

1. 简单去噪模型

这一小节中,我们将尝试设计一个去噪模型。首先,我们将展示如何给图片加噪。然后,我们将训练一个模型,对加噪图片进行去噪。

1.1 简单噪声可视化

第一步,导入环境

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt

第二步,获取训练集样本。此处使用 MNIST 数据集

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
dataset = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=torchvision.transforms.ToTensor())train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True) # download=True for the first time
x, y = next(iter(train_dataloader))
print(f'Input shape: {x.shape}')
print(f'Label: {y}')

第三步,设计简单噪声函数

def corrupt(x, amount):noise = torch.rand_like(x)amount = amount.view(-1, 1, 1, 1)return x*(1-amount) + noise*amount

该函数以原始图像为输入,生成与该输入同样维度的随机噪声,并根据参数 amount 将噪声和原始图像混合。
最后,我们在 (0, 1) 之间采样 8 个 amount,看看不同加噪程度下的图片。
集合上面所有代码,如下:

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as pltdef corrupt(x, amount):noise = torch.rand_like(x)amount = amount.view(-1, 1, 1, 1)return x*(1-amount) + noise*amountdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
dataset = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=torchvision.transforms.ToTensor())train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
x, y = next(iter(train_dataloader))
print(f'Input shape: {x.shape}')
print(f'Label: {y}')fig, axs = plt.subplots(2, 1, figsize=(9, 4))
axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')amount = torch.linspace(0, 1, x.shape[0])
noised_x = corrupt(x, amount)
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys')plt.savefig('visualize_corrupt.png', dpi=400)

结果如下:
请添加图片描述

1.2 去噪模型

所谓去噪模型,就是给模型加噪的图片,让模型直接预测真实图片。
我们搭建一个最简单的CV模型。

class BasicUNet(nn.Module):def __init__(self, in_channels=1, out_channels=1):super().__init__()self.down_layers = torch.nn.ModuleList([nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),nn.Conv2d(32, 64, kernel_size=5, padding=2),nn.Conv2d(64, 64, kernel_size=5, padding=2),])self.up_layers = torch.nn.ModuleList([nn.Conv2d(64, 64, kernel_size=5, padding=2),nn.Conv2d(64, 32, kernel_size=5, padding=2),nn.Conv2d(32, out_channels, kernel_size=5, padding=2),])self.act = nn.ReLU(inplace=True)self.downscale = nn.MaxPool2d(2)self.upscale = nn.Upsample(scale_factor=2)def forward(self, x):h = []for i, l in enumerate(self.down_layers):x = self.act(l(x))if i < 2:h.append(x)x = self.downscale(x)for i, l in enumerate(self.up_layers):if i > 0:x = self.upscale(x)x += h.pop()x = self.act(l(x))return x

初始化该模型,并查看参数:

net = BasicUNet()
print(sum([p.numel() for p in net.parameters()]))

输出为30w,可见该模型很小。
下面我们训练该模型:

batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)n_epochs = 3
net = BasicUNet()
net.to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []for epoch in range(n_epochs):for x, y in train_dataloader:x = x.to(device)noise_amount = torch.rand(x.shape[0]).to(device)noisy_x = corrupt(x, noise_amount)pred = net(noisy_x)loss = loss_fn(pred, x)opt.zero_grad()loss.backward()opt.step()losses.append(loss.item())avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss: 05f}')
plt.plot(losses)
plt.show()
plt.close()

模型训练完以后,我们跟上面加噪过程保持一致,分别设计 8 个不同程度损坏的照片,并让模型预测真实照片

x, y = next(iter(train_dataloader))
x = x[:8]fig, axs = plt.subplots(3, 1, figsize=(12, 7))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys')amount = torch.linspace(0, 1, x.shape[0])
noised_x = corrupt(x, amount)with torch.no_grad():preds = net(noised_x.to(device)).detach().cpu()axs[2].set_title('Network prediction')
axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap='Greys')axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='Greys')

本节完整代码如下:

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as pltdef corrupt(x, amount):noise = torch.rand_like(x)amount = amount.view(-1, 1, 1, 1)return x*(1-amount) + noise*amountclass BasicUNet(nn.Module):def __init__(self, in_channels=1, out_channels=1):super().__init__()self.down_layers = torch.nn.ModuleList([nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),nn.Conv2d(32, 64, kernel_size=5, padding=2),nn.Conv2d(64, 64, kernel_size=5, padding=2),])self.up_layers = torch.nn.ModuleList([nn.Conv2d(64, 64, kernel_size=5, padding=2),nn.Conv2d(64, 32, kernel_size=5, padding=2),nn.Conv2d(32, out_channels, kernel_size=5, padding=2),])self.act = nn.ReLU(inplace=True)self.downscale = nn.MaxPool2d(2)self.upscale = nn.Upsample(scale_factor=2)def forward(self, x):h = []for i, l in enumerate(self.down_layers):x = self.act(l(x))if i < 2:h.append(x)x = self.downscale(x)for i, l in enumerate(self.up_layers):if i > 0:x = self.upscale(x)x += h.pop()x = self.act(l(x))return xdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
dataset = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=torchvision.transforms.ToTensor())batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)n_epochs = 3
net = BasicUNet()
net.to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []for epoch in range(n_epochs):for x, y in train_dataloader:x = x.to(device)noise_amount = torch.rand(x.shape[0]).to(device)noisy_x = corrupt(x, noise_amount)pred = net(noisy_x)loss = loss_fn(pred, x)opt.zero_grad()loss.backward()opt.step()losses.append(loss.item())avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss: 05f}')
plt.plot(losses)
plt.show()
plt.close()# simple check noise
x, y = next(iter(train_dataloader))
x = x[:8]fig, axs = plt.subplots(3, 1, figsize=(12, 7))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys')amount = torch.linspace(0, 1, x.shape[0])
noised_x = corrupt(x, amount)with torch.no_grad():preds = net(noised_x.to(device)).detach().cpu()axs[2].set_title('Network prediction')
axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap='Greys')axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='Greys')plt.savefig('test_v0.png', dpi=400)

结果如下:
请添加图片描述

1.3 小结

本节中,我们可视化了简单加噪过程,并搭建了简单去噪模型。在经过简单训练后,我们的模型可以成功识别加噪程度较低的图片,令人欣慰。但如何从虚无(完全随机照片)中,生成一张可辨别的图片呢?或许我们可以将预测真实图片的过程设计为迭代过程,一步步去噪(上图从右往左),这就是扩散模型的核心。

2 扩散模型

2.1 采样过程

在得到 1.2 节去噪结果后,一个很自然的想法是,我们可以将去噪过程分成多步,每次预测结果和输入结果进行叠合,如此多步迭代后,期望能变成最左边清晰的图像。该多步去噪的过程叫做采样过程。

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as pltdef corrupt(x, amount):noise = torch.rand_like(x)amount = amount.view(-1, 1, 1, 1)return x*(1-amount) + noise*amountclass BasicUNet(nn.Module):def __init__(self, in_channels=1, out_channels=1):super().__init__()self.down_layers = torch.nn.ModuleList([nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),nn.Conv2d(32, 64, kernel_size=5, padding=2),nn.Conv2d(64, 64, kernel_size=5, padding=2),])self.up_layers = torch.nn.ModuleList([nn.Conv2d(64, 64, kernel_size=5, padding=2),nn.Conv2d(64, 32, kernel_size=5, padding=2),nn.Conv2d(32, out_channels, kernel_size=5, padding=2),])self.act = nn.ReLU(inplace=True)self.downscale = nn.MaxPool2d(2)self.upscale = nn.Upsample(scale_factor=2)def forward(self, x):h = []for i, l in enumerate(self.down_layers):x = self.act(l(x))if i < 2:h.append(x)x = self.downscale(x)for i, l in enumerate(self.up_layers):if i > 0:x = self.upscale(x)x += h.pop()x = self.act(l(x))return xdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
dataset = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=torchvision.transforms.ToTensor())batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)n_epochs = 3
net = BasicUNet()
net.to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []for epoch in range(n_epochs):for x, y in train_dataloader:x = x.to(device)noise_amount = torch.rand(x.shape[0]).to(device)noisy_x = corrupt(x, noise_amount)pred = net(noisy_x)loss = loss_fn(pred, x)opt.zero_grad()loss.backward()opt.step()losses.append(loss.item())avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss: 05f}')
plt.plot(losses)
plt.show()
plt.close()n_steps = 5
x = torch.rand(8, 1, 28, 28).to(device)
step_history = [x.detach().cpu()]
pred_output_history = []
for i in range(n_steps):with torch.no_grad():pred = net(x)pred_output_history.append(pred.detach().cpu())min_factor = 1/(n_steps - i)x = x*(1-min_factor)+pred*min_factorstep_history.append(x.detach().cpu())
print(len(step_history))fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)axs[0, 0].set_title('x (model input)')
axs[0, 1].set_title('model prediction')
for i in range(n_steps):axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap='Greys')axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i], cmap='Greys')[0].clip(0, 1), cmap='Greys')plt.savefig('test_real_v1.png', dpi=400)

结果如下:
请添加图片描述
左侧为每次迭代时,模型的输入。右侧为基于加噪数据预测真实数据的结果。
可以看到,虽然从虚无中,模型一步步迭代,形成了相对清晰的结果,但这些结果不像是数字。这说明我们的模型还有改进的空间。一个最简单的方法就是增加迭代的步数,期望能有所改善。
我们将采样步数增加到40(并微调噪声代码),采样过程更改为:

n_steps = 40
x = torch.rand(64, 1, 28, 28).to(device)
for i in range(n_steps):noise_amount = torch.ones((x.shape[0],)).to(device) * (1-(i/n_steps))with torch.no_grad():pred = net(x)min_factor = 1/(n_steps - i)x = x*(1-min_factor) + pred*min_factor
fig, ax = plt.subplots(1, 1, figsize=(12,12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap='Greys')

下图为迭代40次结果,在 64 个样例中,已经可以依稀看到相对清晰的结果。
请添加图片描述

2.2 上科技

在健美运动中,只靠纯饮食是很难长出超乎常人的大块肌肉的。对此,健美圈大佬会摄入大量睾酮等激素,刺激身体发育。这种走捷径的方法常被叫做“上科技”。此处不是想教各位健身知识,而是借喻 AI 模型设计中走捷径的调包大法。

例如,上述结果很不理想,怎么快速提高模型表现呢?
我们可以 import 进来现成的模型嘛,大家是怎么训练的,咱也照做,保证短期内快速提高。照这个思路,我们需要升级以下几处:

2.2.1 升级模型表征模块

前面说了,我们的 BasicUNet 只有 30w 参数,只是一个玩具模型。我们可以直接 import 一个成熟的业内常用的模型,例如,UNet2DModel

from diffusers import UNet2DModelnet = UNet2DModel(sample_size=28,in_channels=1,out_channels=1,layers_per_block=2,block_out_channels=(32, 64, 64),down_block_types=('DownBlock2D',"AttnDownBlock2D","AttnDownBlock2D"),up_block_types=("AttnUpBlock2D","AttnUpBlock2D","UpBlock2D")
)

总代码:

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as pltdef corrupt(x, amount):noise = torch.rand_like(x)amount = amount.view(-1, 1, 1, 1)return x*(1-amount) + noise*amountdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
dataset = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=torchvision.transforms.ToTensor())a=1batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)n_epochs = 3
net = UNet2DModel(sample_size=28,in_channels=1,out_channels=1,layers_per_block=2,block_out_channels=(32, 64, 64),down_block_types=('DownBlock2D',"AttnDownBlock2D","AttnDownBlock2D"),up_block_types=("AttnUpBlock2D","AttnUpBlock2D","UpBlock2D")
)print(sum([p.numel() for p in net.parameters()]))
print(net)net.to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []for epoch in range(n_epochs):for x, y in train_dataloader:x = x.to(device)noise_amount = torch.rand(x.shape[0]).to(device)noisy_x = corrupt(x, noise_amount)pred = net(sample=noisy_x, timestep=0).sampleloss = loss_fn(pred, x)opt.zero_grad()loss.backward()opt.step()losses.append(loss.item())avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss: 05f}')
plt.plot(losses)
plt.show()
plt.close()n_steps = 40
x = torch.rand(64, 1, 28, 28).to(device)
for i in range(n_steps):noise_amount = torch.ones((x.shape[0],)).to(device) * (1-(i/n_steps))with torch.no_grad():pred = net(x)min_factor = 1/(n_steps - i)x = x*(1-min_factor) + pred*min_factor
fig, ax = plt.subplots(1, 1, figsize=(12,12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap='Greys')plt.savefig('test_v2.png', dpi=400)

效果如下:

请添加图片描述

2.2.2 升级加噪过程

其实有很多现成的加噪函数可以直接调。可以实现,前期加快点,后期加慢点的操作。

我们可以将加噪系数进行可视化:

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as pltnoise_scheduler = DDPMScheduler(num_train_timesteps=1000)
plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r'${\sqrt{\bar{\alpha}_t}}$')
plt.plot((1 - noise_scheduler.alphas_cumprod.cpu())**0.5, label=r'${1-\sqrt{\bar{\alpha}_t}}$')
plt.legend(fontsize="x-large")
plt.savefig('scheduler.png', dpi=400)

请添加图片描述

可视化加噪后的图片:

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as pltnoise_scheduler = DDPMScheduler(num_train_timesteps=1000)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
dataset = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=torchvision.transforms.ToTensor())
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
x, y = next(iter(train_dataloader))
x = x.to(device)[:8]
x = x*2. - 1.
print(f'X shape {x.shape}')
fig, axs = plt.subplots(3, 1, figsize=(16, 10))axs[0].imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0], cmap='Greys')
axs[0].set_title('clean X')timesteps = torch.linspace(0, 999, 8).long().to(device)
noise = torch.randn_like(x)
noisy_x = noise_scheduler.add_noise(x, noise, timesteps)
print(f'Noisy X shape {noisy_x.shape}')axs[1].imshow(torchvision.utils.make_grid(noisy_x.detach().cpu().clip(-1, 1), nrow=8)[0], cmap='Greys')
axs[1].set_title('Noisy X (clipped to (-1, 1))')axs[2].imshow(torchvision.utils.make_grid(noisy_x.detach().cpu(), nrow=8)[0], cmap='Greys')
axs[2].set_title('Noisy X')plt.savefig('visualize_noise.png', dpi=400)

请添加图片描述

需要注意的是,很多时候的随机噪声是以 0 为期望,1 为方差的高斯分布。所以需要将原先的灰度范围(0~1之间)缩放到误差所在区间,即我们的 clean x,需要盖上一层灰(上图最下层最左侧)

2.2.3 改变预测目标

之前,所有的模型都是 去噪 模型。这些模型接收一个加噪的图片,输出真实图片。但实践表明,接收一个加噪图片,输出噪声,这样的预测目标有助于提升模型表现。
之前的推理过程:
加噪图片->模型->真实图片
现在是:
加噪图片->模型->噪声,
加噪图片-噪声->真实图片
相当于多了一步。至于为什么要多着一步,问就是这样精度高,大家都这样做。
下面将上述三条改进措施进行集成:

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as pltdef corrupt(x, amount):noise = torch.rand_like(x)amount = amount.view(-1, 1, 1, 1)return x*(1-amount) + noise*amountdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
dataset = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=torchvision.transforms.ToTensor())a=1batch_size = 100
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)n_epochs = 3
net = UNet2DModel(sample_size=28,in_channels=1,out_channels=1,layers_per_block=2,block_out_channels=(32, 64, 64),down_block_types=('DownBlock2D',"AttnDownBlock2D","AttnDownBlock2D"),up_block_types=("AttnUpBlock2D","AttnUpBlock2D","UpBlock2D")
)print(sum([p.numel() for p in net.parameters()]))
print(net)net.to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []noise_scheduler = DDPMScheduler(num_train_timesteps=1000)for epoch in range(n_epochs):for x, y in train_dataloader:x = x.to(device)x = x * 2. - 1.timesteps = torch.linspace(0, 999, batch_size).long().to(device)noise = torch.randn_like(x).to(device)noisy_x = noise_scheduler.add_noise(x, noise, timesteps)# noise_amount = torch.rand(x.shape[0]).to(device)# noisy_x = corrupt(x, noise_amount)pred = net(sample=noisy_x, timestep=timesteps).sampleloss = loss_fn(pred, noise)opt.zero_grad()loss.backward()opt.step()losses.append(loss.item())avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss: 05f}')
plt.plot(losses)
plt.show()
plt.close()torch.save(net, f='model.pt')# net = torch.load('model.pt')
#
# x, y = next(iter(train_dataloader))
# x = x[:100].to(device)
# x = x * 2. - 1.
#
# timesteps = torch.linspace(0, 999, 100).long().to(device)
# noise = torch.randn_like(x).to(device)
# noisy_x = noise_scheduler.add_noise(x, noise, timesteps)
# with torch.no_grad():
#     pred_noise = net(noisy_x, timesteps).sample
#     real_pred = noisy_x - pred_noise
#     real_pred = (real_pred + 1)/2
#
# fig, ax = plt.subplots(1, 1, figsize=(12,12))
# ax.imshow(torchvision.utils.make_grid(real_pred.detach().cpu(), nrow=10)[0].clip(0, 1), cmap='Greys')
#
# plt.savefig('test_final_real.png', dpi=400)net = torch.load('model.pt')
x = torch.rand(100, 1, 28, 28).to(device)
x = x * 2. - 1.timesteps = torch.linspace(0, 999, 100).long().to(device)
noise = torch.randn_like(x).to(device)
noisy_x = noise_scheduler.add_noise(x, noise, timesteps)
with torch.no_grad():pred_noise = net(noisy_x, timesteps).samplereal_pred = noisy_x - pred_noisereal_pred = (real_pred + 1)/2fig, ax = plt.subplots(1, 1, figsize=(12,12))
ax.imshow(torchvision.utils.make_grid(real_pred.detach().cpu(), nrow=10)[0].clip(0, 1), cmap='Greys')plt.savefig('test_final_random.png', dpi=400)

此处采样过程设计了两种。
第一种对应真实的情况:

  • sample 100 张图片
  • 模拟真实的训练环境,先对输入进行伸缩,映射到期望为 0 ,方差为 1 的区域
  • 对这些图片加噪,噪声比例随序列号增加
  • 使用模型预测噪声,并将加噪后的图片减去该噪声
  • 将照片映射回原来的 (0, 1) 区间

结果如下:
请添加图片描述

可以看到,在前排,加噪比例较低的情况下,模型能一步作出很好的预测。但后面图片加噪比例变高以后,模型预测质量变差,这凸显了迭代的意义。

此外,模型中的 timesteps 决定了模型在不同阶段对噪声的预测置信度。我们在设计模型时希望模型知晓自己处于迭代的什么阶段。刚开始时,模型更多的,可能只是复原一些背景噪声,随后直接预测具有语义信息的数字,最后对数字进行描边,逐渐精细化。

为验证这一点,我们可以模型同样随机的噪声,并告知模型所处的 timesteps,初始随机噪声减去模型预测噪声,即可看到模型真正想还原的语义:
请添加图片描述

可以看到:

  • timesteps 接近去噪初始阶段的右下角只是一些背景信息
  • 中间位置含有部分语义信息
  • timesteps 接近去噪最后阶段的左上角,模型预测显然集中于中心,进行精修,而不是右下角的普遍噪声。

2.3 小结

本节中,我们介绍了常用扩散模型相比简单去噪模型的主要改进之处。其中,迭代采样是最核心的思想,通过迭代采样,我们能够从虚无的噪声背景中逐渐还原出具有语义信息的图片。随后,我们通过调包,对去噪模型进行了全方位升级(又叫上科技)。最后,我们对上述改进之处进行了综合,并通过两个案例向大家介绍扩散模型在不同加噪阶段的还原能力及其蕴含的语义信息。

这篇博客写得囫囵吞枣,只是希望大家能快速上手,感受扩散模型的魅力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/228212.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Gin 源码深度解析及实现

介绍 什么是 gin &#xff1f; 一个轻量级高性能 HTTP Web 框架。 Introduction | Gin Web Framework (gin-gonic.com) Gin 是一个用 Go (Golang) 编写的 HTTP Web 框架。 它具有类似 Martini 的 API&#xff0c;但性能比 Martini 快 40 倍。 为什么使用 gin &#xff1f; In…

C#进阶-IIS应用程序池崩溃的解决方案

IIS是微软开发的Web服务器软件&#xff0c;被广泛用于Windows平台上的网站托管。在使用IIS过程中&#xff0c;可能会遇到应用程序池崩溃的问题&#xff0c;原因可能有很多&#xff0c;包括代码错误、资源不足、进程冲突等。本文将为大家介绍IIS应用程序池崩溃的问题分析和解决方…

目标检测损失函数:IoU、GIoU、DIoU、CIoU、EIoU、alpha IoU、SIoU、WIoU原理及Pytorch实现

前言 损失函数是用来评价模型的预测值和真实值一致程度&#xff0c;损失函数越小&#xff0c;通常模型的性能越好。不同的模型用的损失函数一般也不一样。损失函数主要是用在模型的训练阶段&#xff0c;如果我们想让预测值无限接近于真实值&#xff0c;就需要将损失值降到最低…

机器学习系列--R语言随机森林进行生存分析(1)

随机森林&#xff08;Breiman 2001a&#xff09;&#xff08;RF&#xff09;是一种非参数统计方法&#xff0c;需要没有关于响应的协变关系的分布假设。RF是一种强大的、非线性的技术&#xff0c;通过拟合一组树来稳定预测精度模型估计。随机生存森林&#xff08;RSF&#xff0…

HTML与CSS

目录 1、HTML简介 2、CSS简介 2.1选择器 2.1.1标签选择器 2.1.2类选择器 2.1.3层级选择器(后代选择器) 2.1.4id选择器 2.1.5组选择器 2.1.6伪类选择器 2.2样式属性 2.2.1布局常用样式属性 2.2.2文本常用样式属性 1、HTML简介 超文本标记语言HTML是一种标记语言&…

Python 从入门到精通之通俗易懂学闭包

系列 Python从入门到精通之安装与快速入门-CSDN博客 Python从入门到精通之基本数据类型和变量-CSDN博客 Python从入门到精通之集合&#xff08;List列表、Tuple元组、Dict字典、Set&#xff09;-CSDN博客 Python从入门到精通之条件语句、循环语句和函数-CSDN博客 Python从…

中央集成式架构量产时代,openVOC方案将引发软件开发模式变革

2024年&#xff0c;中央计算区域控制架构正式进入规模化量产周期&#xff0c;汽车智能化正式迈入2.0时代&#xff0c;产业生态、应用创新、开发模式都将迎来巨大变革。 同时&#xff0c;随着ChatGPT引发的AIGC领域的爆发式增长&#xff0c;人工智能技术掀起全球万亿级信息化应…

Cookie的详解使用(创建,获取,销毁)

文章目录 Cookie的详解使用&#xff08;创建&#xff0c;获取&#xff0c;销毁&#xff09;1、Cookie是什么2、cookie的常用方法3、cookie的构造和获取代码演示SetCookieServlet.javaGetCookieServlet.javaweb.xml运行结果如下 4、Cookie的销毁DestoryCookieServletweb.xml运行…

[OCR]Python 3 下的文字识别CnOCR

目录 1 CnOCR 2 安装 3 实践 1 CnOCR CnOCR 是 Python 3 下的文字识别&#xff08;Optical Character Recognition&#xff0c;简称OCR&#xff09;工具包。 工具包支持简体中文、繁体中文&#xff08;部分模型&#xff09;、英文和数字的常见字符识别&#xff0c;支持竖…

记一次接口交互is开头的属性序列化后“is”丢失问题

问题背景&#xff1a; 今天在做项目联调时调用别人的第三方接口时&#xff0c;发现字段传递不对导致参数传递异常的问题&#xff0c;当时还很奇怪&#xff0c;明白传好着呢&#xff0c;怎么就好端端的出现字段不对的情况呢&#xff1f; 查看发现该字段为boolean类型的isIsRef…

DsPdf:GcPdf 7.0 for NET Crack

DsPdf:GcPdf 7.0 用于全面文档控制的功能丰富的 C# .NET PDF API 库 PDF 文档解决方案&#xff08;DsPdf&#xff0c;以前称为 GcPdf&#xff09;可让您快速、高效地生成文档&#xff0c;且无需依赖任何内存。 在 C# .NET 中生成、加载、编辑和保存 PDF 文档 支持多种语言的全…

爬虫详细教程第1天

爬虫详细教程第一天 1.爬虫概述1.1什么是爬虫&#xff1f;1.2爬虫工具——Python1.3爬虫合法吗&#xff1f;1.4爬虫的矛与盾1.4.1反爬机制1.4.2反爬策略1.4.3robots.txt协议 2.爬虫使用的软件2.1使用的开发工具: 3.第一个爬虫4.web请求4.1讲解一下web请求的全部过程4.2页面渲染…

test mock-03-wiremock 模拟 HTTP 服务的开源工具 flexible and open source API mocking

拓展阅读 test 之 jmockit-01-overview jmockit-01-test 之 jmockit 入门使用案例 mockito-01-overview mockito 简介及入门使用 PowerMock Mock Server ChaosBlade-01-测试混沌工程平台整体介绍 jvm-sandbox 入门简介 wiremock WireMock是一个流行的开源工具&#xf…

git(安装,常用命令,分支操作,gitee,IDEA集成git,IDEA集成gitee,IDEA集成github,远程仓库操作)

文章目录 1. Git概述1.1 何为版本控制1.2 为什么需要版本控制1.3 版本控制工具1.4 Git简史1.5 Git工作机制1.6 Git和代码托管中心 2. Git安装3. Git常用命令3.1 设置用户签名3.1.1 说明3.1.2 语法3.1.3 案例实操 3.2 初始化本地库3.2.1 基本语法3.2.2 案例实操3.2.3 结果查看 3…

【瞎折腾/3D】无父物体下物体的旋转与移动

目录 说在前面移动World SpaceLocal Space 旋转World SpaceLocal Space 代码 说在前面 测试环境&#xff1a;Microsoft Edge 120.0.2210.91three.js版本&#xff1a;0.160.0其他&#xff1a;本篇文章中只探讨了无父对象下的移动与旋转&#xff0c;有父对象的情况将在下篇文章中…

Python中的用户交互函数详解,提升用户体验!

更多Python学习内容&#xff1a;ipengtao.com 用户进行交互的Python应用程序&#xff0c;有许多常用的用户交互函数可以帮助创建更具吸引力和友好的用户界面。本文将介绍一些常用的Python用户交互函数&#xff0c;并提供详细的示例代码&#xff0c;以帮助大家更好地理解它们的用…

kubeadm来搭建k8s集群。

我们采用了二进制包搭建出的k8s集群&#xff0c;本次我们采用更为简单的kubeadm的方式来搭建k8s集群。 二进制的搭建更适合50台主机以上的大集群&#xff0c;kubeadm更适合中小型企业的集群搭建 主机配置建议&#xff1a;2c 4G 主机节点 IP …

学习动态规划解决不同路径、最小路径和、打家劫舍、打家劫舍iii

学习动态规划|不同路径、最小路径和、打家劫舍、打家劫舍iii 62 不同路径 动态规划&#xff0c;dp[i][j]表示从左上角到(i,j)的路径数量dp[i][j] dp[i-1][j] dp[i][j-1] import java.util.Arrays;/*** 路径数量* 动态规划&#xff0c;dp[i][j]表示从左上角到(i,j)的路径数量…

【JavaScript】垃圾回收与内存泄漏

✨ 专栏介绍 在现代Web开发中&#xff0c;JavaScript已经成为了不可或缺的一部分。它不仅可以为网页增加交互性和动态性&#xff0c;还可以在后端开发中使用Node.js构建高效的服务器端应用程序。作为一种灵活且易学的脚本语言&#xff0c;JavaScript具有广泛的应用场景&#x…

【ArcGIS微课1000例】0082:地震灾害图件制作之DEM晕渲图(山体阴影效果)

以甘肃积石山县6.2级地震为例,基于震中100km范围内的DEM数据,制作数字高程模型山体阴影晕渲图。 文章目录 一、效果展示二、实验数据三、晕渲图制作一、效果展示 基于数字高程模型制作的山体阴影晕渲图如下所示: 二、实验数据 本试验所需要的数据包括: 1. 震中位置矢量数…