从开始实现扩散概率模型 PyTorch 实现

目录

一、说明

二、从头开始实施

三、线性噪声调度器

四、时间嵌入

五、下层DownBlock类块

六、中间midBlock类块

七、UpBlock上层类块

八、UNet 架构

九、训练

十、采样

十一、配置(Default.yaml)

十二、数据集 (MNIST)


keyword:  Diffusion Probabilistic Models 

一、说明

        扩散过程由前向阶段组成,其中图像通过在每个步骤中添加高斯噪声逐渐损坏。经过许多步骤后,图像实际上变得与从正态分布中采样的随机噪声无法区分。这是通过在每个时间步骤 xₜ 应用过渡函数来实现的,其中 β 表示在 t-1 时添加到图像中的预定噪声量,以产生 t 时的图像。

        在前面的讨论中,我们确定设置 α=1−β 并计算每个时间步骤中这些 α 值的累积乘积,使我们能够在任何给定步骤 t 直接从原始图像过渡到噪声版本。在反向过程中,模型被训练以近似反向分布。由于正向和反向过程都是高斯的,因此目标是让模型预测反向分布的均值和方差。

        通过详细的推导,从最大化观测数据的对数似然性这一目标出发,我们得出需要最小化真实去噪分布(以 x₀ 为条件)与模型预测分布之间的 KL 散度(以特定均值和方差为特征)。方差固定为与目标分布的方差匹配,而均值则以相同形式重写。最小化 KL 散度简化为最小化预测噪声与实际噪声样本之间的平方差。

训练过程包括对图像进行采样、选择时间步长 t,以及添加从正态分布中采样的噪声。然后将 t 处的噪声图像传递给模型。从噪声时间表得出的累积乘积项确定随时间增加的噪声。损失函数是原始噪声样本与模型预测之间的均方误差 (MSE)。

二、从头开始实施

        对于图像生成,我们从学习到的反向分布中进行采样,从正态分布中的随机噪声样本 xₜ 开始。使用与 xₜ 和预测噪声相同的公式计算平均值,方差与地面真实去噪分布相匹配。使用重新参数化技巧,我们反复从这个反向分布中采样以生成 x₀。在 x₀ 处,没有添加额外的噪声;相反,平均值直接作为最终输出返回。

        为了实现扩散过程,我们需要处理正向和反向阶段的计算。我们将创建一个噪声调度程序来管理这些任务。在正向过程中,给定一个图像、一个噪声样本和一个时间步长 t,调度程序将使用正向方程返回图像的噪声版本。为了优化效率,它将预先计算并存储 α(1−β) 的值以及所有时间步长中 α 的累积乘积。

        作者采用了线性噪声调度,其中 β 在 1,000 个时间步骤内从 1×10⁻⁴ 线性缩放到 0.02。调度程序还处理反向过程:给定 xt 和模型预测的噪声,它将通过从反向分布中采样来计算 xₜ₋₁。这涉及使用各自的方程计算均值和方差,并通过重新参数化技巧生成样本。

        为了支持这些计算,调度程序还将存储 1-αₜ、1-累积乘积项以及该项的平方根的预先计算的值。

三、线性噪声调度器

import torchclass LinearNoiseScheduler:def __init__(self, num_timesteps, beta_start, beta_end):self.num_timesteps = num_timestepsself.beta_start = beta_startself.beta_end = beta_endself.betas = torch.linspace(beta_start, beta_end, num_timesteps)self.alphas = 1. - self.betasself.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)

使用传递给此类的参数初始化所有参数后,我们将定义 β 值从起始范围到结束范围线性增加,确保 βₜ 从 0 进展到最后的时间步骤。接下来,我们将设置正向和反向过程方程所需的所有变量。

  def add_noise(self, original, noise, t):original_shape = original.shapebatch_size = original_shape[0]sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)# Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W)for _ in range(len(original_shape) - 1):sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)for _ in range(len(original_shape) - 1):sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)# Apply and Return Forward process equationreturn (sqrt_alpha_cum_prod.to(original.device) * original+ sqrt_one_minus_alpha_cum_prod.to(original.device) * noise)

add_noise()函数表示正向过程。它以原始图像、噪声样本和时间步长 ttt 作为输入。图像和噪声的维度为 b×h×w,而时间步长为大小为 b 的一维张量。对于正向过程,我们计算给定时间步长的累积乘积项的平方根和 1-累积乘积项。这些值被重新整形为维度 b×1×1×1。最后,我们应用正向过程方程来生成噪声图像。

    def sample_prev_timestep(self, xt, noise_pred, t):x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) /torch.sqrt(self.alpha_cum_prod.to(xt.device)[t]))x0 = torch.clamp(x0, -1., 1.)mean = xt - ((self.betas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])if t == 0:return mean, x0else:variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])variance = variance * self.betas.to(xt.device)[t]sigma = variance ** 0.5z = torch.randn(xt.shape).to(xt.device)return mean + sigma * z, x0

        调度程序类中的下一个函数处理反向过程。它使用噪声图像 xₜ、模型的噪声预测和时间步长 t 作为输入,从学习到的反向分布中生成样本。我们保存原始图像预测 x₀​ 以供可视化,它是通过重新排列正向过程方程以使用噪声预测而不是实际噪声来计算 x₀ 获得的。

        对于逆向过程中的采样,我们使用逆均值方程计算均值。在 t=0 时,我们只需返回均值。对于其他时间步骤,噪声会添加到均值中,方差与以 x₀​ 为条件的地面真实去噪分布的方差相同。最后,我们使用计算出的均值和方差从高斯分布中采样,应用重新参数化技巧来生成结果。

        这样就完成了噪声调度程序,它管理添加噪声的正向过程和采样的反向过程。对于扩散模型,我们可以灵活地选择任何架构,只要它满足两个关键要求。第一,输入和输出形状必须相同,第二,必须有一种方法可以整合时间步长信息。

作者图片

        无论是在训练期间还是采样期间,时间步长信息始终是可访问的。包含此信息有助于模型更好地预测原始噪声,因为它表明输入图像中有多少是噪声。我们不仅向模型提供图像,还提供相应的时间步长。

        对于模型架构,我们将使用 UNet,这也是原作者的选择。为了确保一致性,我们将复制 Hugging Face 的 Diffusers 管道中使用的稳定扩散 UNet 中实现的块、激活、规范化和其他组件的精确规格。

作者图片

        时间步长由时间嵌入块处理,该块采用大小为b(批次大小)的时间步长的一维张量,并输出批次中每个时间步长的大小为t_emb_dim的表示。此块首先通过嵌入空间将整数时间步长转换为矢量表示。然后,此嵌入通过中间带有激活函数的两个线性层,产生最终的时间步长表示。对于嵌入空间,作者使用了 Transformers 中常用的正弦位置嵌入方法。在整个架构中,使用的激活函数是 S 形线性单元 (SiLU),但也可以选择其他激活函数。

作者图片

        UNet架构遵循简单的编码器-解码器设计。编码器由多个下采样块组成,每个块都会减少输入的空间维度(通常减半),同时增加通道数量。最终下采样块的输出由中间块的几层处理,所有层都以相同的空间分辨率运行。随后,解码器采用上采样块,逐步增加空间维度并减少通道数量,最终匹配原始输入大小。在解码器中,上采样块通过残差跳过连接以相同的分辨率集成相应下采样块的输出。虽然大多数扩散模型都遵循这种通用的 UNet 架构,但它们在各个块内的具体细节和配置上有所不同。

作者图片

        大多数变体中的下行块通常由ResNet 块、后跟自注意力块和下采样层组成。每个 ResNet 块都使用一系列操作构建:组归一化、激活层和卷积层。此序列的输出将通过另一组归一化、激活和卷积层。通过将第一个归一化层的输入与第二个卷积层的输出相结合来添加残差连接。这个完整的序列形成ResNet 块,可以将其视为通过残差连接连接的两个卷积块。

        在 ResNet 块之后,有一个规范化步骤、一个自注意力层和另一个残差连接。虽然模型通常使用多个 ResNet 层和自注意力层,但为简单起见,我们的实现将只使用每个层的一层。

        为了整合时间信息,每个 ResNet 块都包含一个激活层,后面跟着一个线性层,用于处理时间嵌入表示。时间嵌入表示为大小为t_emb_dim的张量,通过此线性层将其投影到与卷积层输出具有相同大小和通道数的张量中。这样就可以通过在空间维度上复制时间步长表示,将时间嵌入添加到卷积层的输出中。

作者图片

        另外两个块使用相同的组件,只是略有不同。上块完全相同,只是它首先将输入上采样为两倍空间大小,然后在整个通道维度上集中相同空间分辨率的下块输出。然后我们有相同的 resnet 层和自注意力块。中间块的层始终将输入保持为相同的空间分辨率。hugging face 版本首先有一个 resnet 块,然后是自注意力层和 resnet 层。对于这些 resnet 块中的每一个,我们都有一个时间步长投影层。现有的时间步长表示会经过这些块,然后被添加到 resnet 的第一个卷积层的输出中。

四、时间嵌入

import torch
import torch.nn as nndef get_time_embedding(time_steps, temb_dim):assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"# factor = 10000^(2i/d_model)factor = 10000 ** ((torch.arange(start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2)))# pos / factor# timesteps B -> B, 1 -> B, temb_dimt_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factort_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)return t_emb

第一个函数为给定的时间步长get_time_embedding生成时间嵌入。它受到 Transformer 模型中使用的正弦位置嵌入的启发。
time_steps:时间步长值的张量(形状:[B]其中B是批次大小)。每个值代表批次元素的一个离散时间步长。
temb_dim:时间嵌入的维数。这决定了每个时间步长的生成嵌入的大小。

确保这temb_dim是均匀的,因为正弦嵌入需要将嵌入分成两半,分别表示正弦和余弦分量。无缝扩展以处理任何批量大小或嵌入维度。

五、下层DownBlock类块

class DownBlock(nn.Module):def __init__(self, in_channels, out_channels, t_emb_dim,down_sample=True, num_heads=4, num_layers=1):super().__init__()self.num_layers = num_layersself.down_sample = down_sampleself.resnet_conv_first = nn.ModuleList([nn.Sequential(nn.GroupNorm(8, in_channels if i == 0 else out_channels),nn.SiLU(),nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,kernel_size=3, stride=1, padding=1),)for i in range(num_layers)])self.t_emb_layers = nn.ModuleList([nn.Sequential(nn.SiLU(),nn.Linear(t_emb_dim, out_channels))for _ in range(num_layers)])self.resnet_conv_second = nn.ModuleList([nn.Sequential(nn.GroupNorm(8, out_channels),nn.SiLU(),nn.Conv2d(out_channels, out_channels,kernel_size=3, stride=1, padding=1),)for _ in range(num_layers)])self.attention_norms = nn.ModuleList([nn.GroupNorm(8, out_channels)for _ in range(num_layers)])self.attentions = nn.ModuleList([nn.MultiheadAttention(out_channels, num_heads, batch_first=True)for _ in range(num_layers)])self.residual_input_conv = nn.ModuleList([nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)for i in range(num_layers)])self.down_sample_conv = nn.Conv2d(out_channels, out_channels,4, 2, 1) if self.down_sample else nn.Identity()def forward(self, x, t_emb):out = xfor i in range(self.num_layers):# Resnet block of Unetresnet_input = outout = self.resnet_conv_first[i](out)out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]out = self.resnet_conv_second[i](out)out = out + self.residual_input_conv[i](resnet_input)# Attention block of Unetbatch_size, channels, h, w = out.shapein_attn = out.reshape(batch_size, channels, h * w)in_attn = self.attention_norms[i](in_attn)in_attn = in_attn.transpose(1, 2)out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)out = out + out_attnout = self.down_sample_conv(out)return out

DownBlock 类结合了ResNet 块自注意力块和可选的下采样,并集成了时间嵌入来整合时间步长信息。将卷积层与残差连接相结合,以实现更好的梯度流和更高效的学习。将时间步长表示投影到特征空间中,使模型能够整合时间相关信息。通过对所有空间位置之间的关系进行建模来捕获长距离依赖关系。减少空间维度以专注于更深层中更大规模的特征。

参数

  • in_channels:输入通道数。
  • out_channels:输出通道数。
  • t_emb_dim:时间嵌入的维度。
  • down_sample:布尔值,确定是否在块末尾应用下采样。
  • num_heads:多头注意力层中的注意力头的数量。
  • num_layers:此块中的 ResNet + 注意力层的数量。

ResNet块

  • resnet_conv_first:ResNet 块的第一个卷积层。
  • t_emb_layers:时间嵌入投影层。
  • resnet_conv_second:ResNet 块的第二个卷积层。
  • residual_input_conv:用于残差连接的 1x1 卷积。

自注意力模块

  • attention_norms:在注意力机制之前对规范化层进行分组。
  • attentions:多头注意力层。

下采样

  • down_sample_conv:应用卷积来减少空间维度(如果down_sample=True)。

Forward Pass 方法定义了如何x通过块处理输入张量:out初始化为输入x。对于每一层,我们都有 ResNet Block 和 Self-Attention Block。

在 ResNet Block 中,我们有第一个 卷积层,它应用 GroupNorm、SiLU 激活和 3x3 卷积,以及一个时间嵌入函数,它将时间嵌入传递t_emb到线性层(投影到out_channels),并将此投影时间嵌入添加到out(在空间维度上广播)。然后我们有第二个卷积和一个残差连接,它将原始输入(resnet_input)添加到第二个卷积的输出。

在自注意力模块中,我们将空间维度扁平化为一个维度(h * w)以用于注意力机制。规范化输入并转置以匹配注意力层输入格式。多头注意力in_attn使用查询、键和值执行自注意力。重塑回转置并重塑回原始空间维度。残差连接下采样。

六、中间midBlock类块

class MidBlock(nn.Module):def __init__(self, in_channels, out_channels, t_emb_dim, num_heads=4, num_layers=1):super().__init__()self.num_layers = num_layersself.resnet_conv_first = nn.ModuleList([nn.Sequential(nn.GroupNorm(8, in_channels if i == 0 else out_channels),nn.SiLU(),nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,padding=1),)for i in range(num_layers+1)])self.t_emb_layers = nn.ModuleList([nn.Sequential(nn.SiLU(),nn.Linear(t_emb_dim, out_channels))for _ in range(num_layers + 1)])self.resnet_conv_second = nn.ModuleList([nn.Sequential(nn.GroupNorm(8, out_channels),nn.SiLU(),nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),)for _ in range(num_layers+1)])self.attention_norms = nn.ModuleList([nn.GroupNorm(8, out_channels)for _ in range(num_layers)])self.attentions = nn.ModuleList([nn.MultiheadAttention(out_channels, num_heads, batch_first=True)for _ in range(num_layers)])self.residual_input_conv = nn.ModuleList([nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)for i in range(num_layers+1)])def forward(self, x, t_emb):out = x# First resnet blockresnet_input = outout = self.resnet_conv_first[0](out)out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]out = self.resnet_conv_second[0](out)out = out + self.residual_input_conv[0](resnet_input)for i in range(self.num_layers):# Attention Blockbatch_size, channels, h, w = out.shapein_attn = out.reshape(batch_size, channels, h * w)in_attn = self.attention_norms[i](in_attn)in_attn = in_attn.transpose(1, 2)out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)out = out + out_attn# Resnet Blockresnet_input = outout = self.resnet_conv_first[i+1](out)out = out + self.t_emb_layers[i+1](t_emb)[:, :, None, None]out = self.resnet_conv_second[i+1](out)out = out + self.residual_input_conv[i+1](resnet_input)return out

该类MidBlock是位于扩散模型中 U-Net 架构中间的模块。它由ResNet 块自注意力层组成,并集成了时间嵌入来处理时间信息。这是用于去噪扩散等任务的模型的重要组成部分。此外,我们还有:

  • 时间嵌入:通过将时间信息(例如,扩散模型中的去噪步骤)投影到特征空间并将其添加到卷积特征中来合并时间信息。
  • 层迭代:在注意力ResNet 块之间交替,按num_layers这些组合的顺序处理输入。

七、UpBlock上层类块

class UpBlock(nn.Module):def __init__(self, in_channels, out_channels, t_emb_dim, up_sample=True, num_heads=4, num_layers=1):super().__init__()self.num_layers = num_layersself.up_sample = up_sampleself.resnet_conv_first = nn.ModuleList([nn.Sequential(nn.GroupNorm(8, in_channels if i == 0 else out_channels),nn.SiLU(),nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,padding=1),)for i in range(num_layers)])self.t_emb_layers = nn.ModuleList([nn.Sequential(nn.SiLU(),nn.Linear(t_emb_dim, out_channels))for _ in range(num_layers)])self.resnet_conv_second = nn.ModuleList([nn.Sequential(nn.GroupNorm(8, out_channels),nn.SiLU(),nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),)for _ in range(num_layers)])self.attention_norms = nn.ModuleList([nn.GroupNorm(8, out_channels)for _ in range(num_layers)])self.attentions = nn.ModuleList([nn.MultiheadAttention(out_channels, num_heads, batch_first=True)for _ in range(num_layers)])self.residual_input_conv = nn.ModuleList([nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)for i in range(num_layers)])self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,4, 2, 1) \if self.up_sample else nn.Identity()def forward(self, x, out_down, t_emb):x = self.up_sample_conv(x)x = torch.cat([x, out_down], dim=1)out = xfor i in range(self.num_layers):resnet_input = outout = self.resnet_conv_first[i](out)out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]out = self.resnet_conv_second[i](out)out = out + self.residual_input_conv[i](resnet_input)batch_size, channels, h, w = out.shapein_attn = out.reshape(batch_size, channels, h * w)in_attn = self.attention_norms[i](in_attn)in_attn = in_attn.transpose(1, 2)out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)out = out + out_attnreturn out

该类UpBlock是 U-Net 类架构的解码器阶段的一部分,通常用于扩散模型或其他图像生成/分割任务。它结合了上采样跳过连接ResNet 块自注意力来重建输出图像,同时保留早期编码器阶段的细粒度细节。

  • 上采样:通过转置卷积(ConvTranspose2d)实现,以增加特征图的空间分辨率。
  • 跳过连接:允许解码器重用编码器的详细特征,帮助重建。
  • ResNet Block:使用卷积层处理输入,集成时间嵌入,并添加残差连接以实现高效的梯度流。
  • 自我注意力:捕获远程空间依赖关系以保留全局上下文。
  • 时间嵌入:对时间信息进行编码并将其注入特征图,这对于处理动态数据的模型(如扩散模型)至关重要。

八、UNet 架构

class Unet(nn.Module):def __init__(self, model_config):super().__init__()im_channels = model_config['im_channels']self.down_channels = model_config['down_channels']self.mid_channels = model_config['mid_channels']self.t_emb_dim = model_config['time_emb_dim']self.down_sample = model_config['down_sample']self.num_down_layers = model_config['num_down_layers']self.num_mid_layers = model_config['num_mid_layers']self.num_up_layers = model_config['num_up_layers']assert self.mid_channels[0] == self.down_channels[-1]assert self.mid_channels[-1] == self.down_channels[-2]assert len(self.down_sample) == len(self.down_channels) - 1# Initial projection from sinusoidal time embeddingself.t_proj = nn.Sequential(nn.Linear(self.t_emb_dim, self.t_emb_dim),nn.SiLU(),nn.Linear(self.t_emb_dim, self.t_emb_dim))self.up_sample = list(reversed(self.down_sample))self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))self.downs = nn.ModuleList([])for i in range(len(self.down_channels)-1):self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i+1], self.t_emb_dim,down_sample=self.down_sample[i], num_layers=self.num_down_layers))self.mids = nn.ModuleList([])for i in range(len(self.mid_channels)-1):self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i+1], self.t_emb_dim,num_layers=self.num_mid_layers))self.ups = nn.ModuleList([])for i in reversed(range(len(self.down_channels)-1)):self.ups.append(UpBlock(self.down_channels[i] * 2, self.down_channels[i-1] if i != 0 else 16,self.t_emb_dim, up_sample=self.down_sample[i], num_layers=self.num_up_layers))self.norm_out = nn.GroupNorm(8, 16)self.conv_out = nn.Conv2d(16, im_channels, kernel_size=3, padding=1)def forward(self, x, t):# Shapes assuming downblocks are [C1, C2, C3, C4]# Shapes assuming midblocks are [C4, C4, C3]# Shapes assuming downsamples are [True, True, False]# B x C x H x Wout = self.conv_in(x)# B x C1 x H x W# t_emb -> B x t_emb_dimt_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)t_emb = self.t_proj(t_emb)down_outs = []for idx, down in enumerate(self.downs):down_outs.append(out)out = down(out, t_emb)# down_outs  [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4]# out B x C4 x H/4 x W/4for mid in self.mids:out = mid(out, t_emb)# out B x C3 x H/4 x W/4for up in self.ups:down_out = down_outs.pop()out = up(out, down_out, t_emb)# out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W]out = self.norm_out(out)out = nn.SiLU()(out)out = self.conv_out(out)# out B x C x H x Wreturn out

该类是U-Net 架构Unet的实现,专为图像处理任务而设计,例如分割或生成,通常用于扩散模型。该网络包括下采样中级处理上采样阶段。它利用时间嵌入执行动态任务(例如扩散模型),利用跳过连接保留空间信息,利用 GroupNorm 进行归一化。

作者图片

  • 时间嵌入:实现时间动态。
  • 跳过连接:通过连接将细粒度的空间细节集成到解码器中。
  • 灵活的架构:允许通过model_config不同的深度、分辨率和功能丰富度进行定制。
  • 规范化和激活:GroupNorm 确保稳定的训练,而 SiLU 激活则改善非线性。
  • 输出一致性:确保输出图像保留原始的空间尺寸和通道数。

九、训练

import torch
import yaml
import argparse
import os
import numpy as np
from tqdm import tqdm
from torch.optim import Adam
from dataset.mnist_dataset import MnistDataset
from torch.utils.data import DataLoader
from models.unet_base import Unet
from scheduler.linear_noise_scheduler import LinearNoiseSchedulerdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def train(args):with open(args.config_path, 'r') as file:try:config = yaml.safe_load(file)except yaml.YAMLError as exc:print(exc)print(config)diffusion_config = config['diffusion_params']dataset_config = config['dataset_params']model_config = config['model_params']train_config = config['train_params']# Create the noise schedulerscheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],beta_start=diffusion_config['beta_start'],beta_end=diffusion_config['beta_end'])# Create the datasetmnist = MnistDataset('train', im_path=dataset_config['im_path'])mnist_loader = DataLoader(mnist, batch_size=train_config['batch_size'], shuffle=True, num_workers=4)# Instantiate the modelmodel = Unet(model_config).to(device)model.train()# Create output directoriesif not os.path.exists(train_config['task_name']):os.mkdir(train_config['task_name'])# Load checkpoint if foundif os.path.exists(os.path.join(train_config['task_name'],train_config['ckpt_name'])):print('Loading checkpoint as found one')model.load_state_dict(torch.load(os.path.join(train_config['task_name'],train_config['ckpt_name']), map_location=device))# Specify training parametersnum_epochs = train_config['num_epochs']optimizer = Adam(model.parameters(), lr=train_config['lr'])criterion = torch.nn.MSELoss()# Run trainingfor epoch_idx in range(num_epochs):losses = []for im in tqdm(mnist_loader):optimizer.zero_grad()im = im.float().to(device)# Sample random noisenoise = torch.randn_like(im).to(device)# Sample timestept = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device)# Add noise to images according to timestepnoisy_im = scheduler.add_noise(im, noise, t)noise_pred = model(noisy_im, t)loss = criterion(noise_pred, noise)losses.append(loss.item())loss.backward()optimizer.step()print('Finished epoch:{} | Loss : {:.4f}'.format(epoch_idx + 1,np.mean(losses),))torch.save(model.state_dict(), os.path.join(train_config['task_name'],train_config['ckpt_name']))print('Done Training ...')if __name__ == '__main__':parser = argparse.ArgumentParser(description='Arguments for ddpm training')parser.add_argument('--config', dest='config_path',default='config/default.yaml', type=str)args = parser.parse_args()train(args)

加载配置:从 YAML 文件读取训练配置(如数据集路径、超参数和模型设置)。

设置组件

  • 初始化噪声调度器,用于在不同的时间步添加噪声。
  • 创建一个MNIST 数据集加载器
  • 实例化U-Net模型

检查点管理:检查现有检查点,如果可用则加载。创建保存检查点和输出所需的目录。

训练循环:每个时期:

  • 遍历数据集,根据采样的时间步长向图像添加噪声。
  • 使用模型预测噪声并计算损失(预测噪声和实际噪声之间的 MSE)。
  • 使用反向传播更新模型参数并保存模型检查点。

优化:使用 Adam 优化器和 MSE 损失函数来训练模型。

完成:打印 epoch 损失并在每个 epoch 结束时保存模型。

十、采样

import torch
import torchvision
import argparse
import yaml
import os
from torchvision.utils import make_grid
from tqdm import tqdm
from models.unet_base import Unet
from scheduler.linear_noise_scheduler import LinearNoiseSchedulerdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def sample(model, scheduler, train_config, model_config, diffusion_config):xt = torch.randn((train_config['num_samples'],model_config['im_channels'],model_config['im_size'],model_config['im_size'])).to(device)for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):# Get prediction of noisenoise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))# Use scheduler to get x0 and xt-1xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))# Save x0ims = torch.clamp(xt, -1., 1.).detach().cpu()ims = (ims + 1) / 2grid = make_grid(ims, nrow=train_config['num_grid_rows'])img = torchvision.transforms.ToPILImage()(grid)if not os.path.exists(os.path.join(train_config['task_name'], 'samples')):os.mkdir(os.path.join(train_config['task_name'], 'samples'))img.save(os.path.join(train_config['task_name'], 'samples', 'x0_{}.png'.format(i)))img.close()def infer(args):# Read the config file #with open(args.config_path, 'r') as file:try:config = yaml.safe_load(file)except yaml.YAMLError as exc:print(exc)print(config)diffusion_config = config['diffusion_params']model_config = config['model_params']train_config = config['train_params']# Load model with checkpointmodel = Unet(model_config).to(device)model.load_state_dict(torch.load(os.path.join(train_config['task_name'],train_config['ckpt_name']), map_location=device))model.eval()# Create the noise schedulerscheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],beta_start=diffusion_config['beta_start'],beta_end=diffusion_config['beta_end'])with torch.no_grad():sample(model, scheduler, train_config, model_config, diffusion_config)if __name__ == '__main__':parser = argparse.ArgumentParser(description='Arguments for ddpm image generation')parser.add_argument('--config', dest='config_path',default='config/default.yaml', type=str)args = parser.parse_args()infer(args)

加载配置:从 YAML 文件读取模型、扩散和训练参数。

模型设置:加载训练好的 U-Net 模型检查点。初始化噪声调度程序以指导反向扩散过程。

采样过程

  • 从随机噪声开始,并在指定的时间步内迭代地对其进行去噪。
  • 在每个时间步:
  • 使用模型预测噪音。
  • 使用调度程序计算去噪图像(x0)并更新当前噪声图像(xt)。
  • 将中间去噪图像作为 PNG 文件保存在输出目录中。

推理:执行采样过程并保存结果而不改变模型。

十一、配置(Default.yaml)

dataset_params:im_path: 'data/train/images'diffusion_params:num_timesteps : 1000beta_start : 0.0001beta_end : 0.02model_params:im_channels : 1im_size : 28down_channels : [32, 64, 128, 256]mid_channels : [256, 256, 128]down_sample : [True, True, False]time_emb_dim : 128num_down_layers : 2num_mid_layers : 2num_up_layers : 2num_heads : 4train_params:task_name: 'default'batch_size: 64num_epochs: 40num_samples : 100num_grid_rows : 10lr: 0.0001ckpt_name: 'ddpm_ckpt.pth'

该配置文件提供了扩散模型的训练和推理的设置。

数据集参数im_path:指定训练图像的路径( )。

扩散参数:设置扩散过程的时间步数和噪声参数的范围(beta_startbeta_end)。

模型参数

  • 定义模型架构,包括:
  • 输入图像通道(im_channels)和大小(im_size)。
  • 下采样、中间处理和上采样的通道数。
  • 每一级是否发生下采样(down_sample)。
  • 各种块的嵌入尺寸和层数。

训练参数

  • 指定训练配置,如任务名称、批量大小、时期、学习率和检查点文件名。
  • 包括采样设置,例如用于可视化的样本数量和网格行数。

十二、数据集 (MNIST)

import glob
import osimport torchvision
from PIL import Image
from tqdm import tqdm
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Datasetclass MnistDataset(Dataset):self.split = splitself.im_ext = im_extself.images, self.labels = self.load_images(im_path)def load_images(self, im_path):assert os.path.exists(im_path), "images path {} does not exist".format(im_path)ims = []labels = []for d_name in tqdm(os.listdir(im_path)):for fname in glob.glob(os.path.join(im_path, d_name, '*.{}'.format(self.im_ext))):ims.append(fname)labels.append(int(d_name))print('Found {} images for split {}'.format(len(ims), self.split))return ims, labelsdef __len__(self):return len(self.images)def __getitem__(self, index):im = Image.open(self.images[index])im_tensor = torchvision.transforms.ToTensor()(im)# Convert input to -1 to 1 range.im_tensor = (2 * im_tensor) - 1return im_tensor

初始化:采用分割名称、图像文件扩展名(im_ext)和图像路径(im_path)。调用load_images以加载图像路径及其相应的标签。

图像加载load_images遍历 处的目录结构im_path,假设子目录已标记(例如,数字类别的01、...)。收集图像文件路径并根据文件夹名称分配标签。

数据集长度__len__返回图像的总数。

数据检索__getitem__通过索引检索图像,将其转换为张量,并将像素值缩放到范围 -1,1-1,1-1,1。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/491868.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CCF-GESP 等级考试 2024年12月认证C++七级真题解析

2024年12月真题 一、单选题(每题2分,共30分) 正确答案:D 解析:考察字符类型和ASCII码值。 字符类型参与运算,是它所对应的ASCII码值在参与运算,运算结果为整数值。小写字母 b 的ASCII码为98&am…

递归实现指数型枚举(递归)

92. 递归实现指数型枚举 - AcWing题库 每个数有选和不选两种情况 我们把每个数看成每层,可以画出一个递归搜索树 叶子节点就是我们的答案 很容易写出每dfs函数 dfs传入一个u表示层数 当层数大于我们n时,去判断每个数字的选择情况,输出被选…

事务-介绍与操作四大特性

一.数据准备: 1.员工表: -- 员工管理 create table tb_emp (id int unsigned primary key auto_increment comment ID,username varchar(20) not null unique comment 用户名,password varchar(32) default 123456 comment 密码,n…

[白月黑羽]关于风机协议工具的解答

架构 python3.8pyqt5 先来看下原题: 视频中软件的效果 先来看下程序的效果如何,看上去大概相似 对应代码已经上传到了gitcode https://gitcode.com/m0_37662818/fan_protocol_tool/overview 实现中的难点是双悬浮可视化,同时要高亮悬浮对…

HCIA-Access V2.5_4_1_1路由协议基础_IP路由表

大型网络的拓扑结构一般会比较复杂,不同的部门,或者总部和分支可能处在不同的网络中,此时就需要使用路由器来连接不同的网络,实现网络之间的数据转发。 本章将介绍路由协议的基础知识、路由表的分类、静态路由基础与配置、VLAN间…

ISCTF复现-misc

File_Format 下载附件后用010打开查看文件头会发现是个exe文件 格式:文件描述(后缀名),文件头(hex):文件头标识(十六进制)PNG (png),文件头(hex):89504E47 PNGImageFile…

Windows设置所有软件默认以管理员身份运行

方法一、修改注册表 winr打开运行,输入“regedit”打开注册表; 打开此路径“计算机HKEY_LOCAL_MACHINESOFTWAREMicrosoftWindowsCurrentVersionPoliciesSystem”; 在右侧找到“EnableLUA”,将其值改为0,重启电脑。 …

【题解】—— LeetCode一周小结50

🌟欢迎来到 我的博客 —— 探索技术的无限可能! 🌟博客的简介(文章目录) 【题解】—— 每日一道题目栏 上接:【题解】—— LeetCode一周小结49 9.判断国际象棋棋盘中一个格子的颜色 题目链接:…

Docker安全性与最佳实践

一、引言:Docker安全性的重要性 Docker作为一种容器化技术,已成为现代应用程序部署和开发的核心工具。然而,随着容器化应用的普及,Docker的安全性问题也日益突出。容器本身的隔离性、网络配置、权限管理等方面的安全隐患&#xf…

利用notepad++删除特定关键字所在的行

1、按组合键Ctrl H,查找模式选择 ‘正则表达式’,不选 ‘.匹配新行’ 2、查找目标输入 : ^.*关键字.*\r\n (不保留空行) ^.*关键字.*$ (保留空行)3、替换为:(空) 配置界面参考下图: ​​…

上传图片的预览

解决:在上传图片时,1显示已有的图片 2显示准备替换的图片 前 后 在这个案例中可以预览到 【已有与准备替换】 2张图片 具体流程 1创建一个共享组件 与manage.py同级别路径的文件 manage.py custom_widgets.py# custom_widgets.py from django import forms from dja…

MySQL学习之DDL操作

目录 数据库的操作 创建 查看 选择 删除 修改 数据类型 表的创建 表的修改 表的约束 主键 PRIMARY KEY 唯一性约束 UNIQUE 非空约束 NOT NULL 外键约束 约束小结 索引 索引分类 常规索引 主键索引 唯一索引 外键索引 优点 缺点 视图 创建 删除 修改…

国际网络专线是什么?有什么优势?

国际网络专线作为一种独立的网络连接方式,通过卫星或海底光缆等物理链路,将全球不同国家和地区的网络直接互联,为企业提供了可靠的通信渠道。本文将详细探讨国际网络专线的优势以及其广泛的应用场景。 国际网络专线的优势解析 1. 专属连接&am…

密码编码学与网络安全(第五版)答案

通过如下代码分别统计一个字符的频率和三个字符的频率,"8"——"e",“;48”——“the”,英文字母的相对使用频率,猜测频率比较高的依此为),t,*,5,分别对应s,o,n,…

【功能安全】随机硬件失效导致违背安全目标的评估(FMEDA)

目录 01 随机硬件失效介绍 02 FMEDA介绍 03 FMEDA模板 01 随机硬件失效介绍 GBT 34590 part5

mybatis 的动态sql 和缓存

动态SQL 可以根据具体的参数条件,来对SQL语句进行动态拼接。 比如在以前的开发中,由于不确定查询参数是否存在,许多人会使用类似于where 1 1 来作为前缀,然后后面用AND 拼接要查询的参数,这样,就算要查询…

Web APIs - 第5章笔记

目标: 依托 BOM 对象实现对历史、地址、浏览器信息的操作或获取 具备利用本地存储实现学生就业表案例的能力 BOM操作 综合案例 JavaScript的组成 ECMAScript: 规定了js基础语法核心知识。 比如:变量、分支语句、循环语句、对象等等 Web APIs : DO…

AI视频配音技术创新应用与商业机遇

随着人工智能技术的飞速发展,AI视频配音技术已经成为内容创作者和营销人员的新宠。这项技术不仅能够提升视频内容的吸引力,还能为特定行业带来创新的解决方案。本文将探讨AI视频配音技术的应用场景,并讨论如何合法合规地利用这一技术。 AI视频…

vlan和vlanif

文章目录 1、为什么会有vlan的存在2、vlan(虚拟局域网)1、vlan原理1. 为什么这样划分了2、如何实现不同交换机相同的vlan实现互访呢3、最优化的解决方法,vlan不同交换机4、vlan标签和vlan数据帧 5、vlan实现2、基于vlan的划分方式1、基于接口的vlan划分方式2、基于m…

Java每日一题(1)

给定n个数a1,a2,...an,求它们两两相乘再相加的和。 即:Sa1*a2a1*a3...a1*ana2*a3...an-2*an-1an-2*anan-1*an 第一行输入的包含一个整数n。 第二行输入包含n个整数a1,a2,...an。 样例输入 4 1 3 6 9 样例输出 117 答案 import java.util.Scanner; // 1:无…