扩散模型实战(四):从零构建扩散模型

推荐阅读列表:

扩散模型实战(一):基本原理介绍

扩散模型实战(二):扩散模型的发展

扩散模型实战(三):扩散模型的应用

本文以MNIST数据集为例,从零构建扩散模型,具体会涉及到如下知识点:

  • 退化过程(向数据中添加噪声)
  • 构建一个简单的UNet模型
  • 训练扩散模型
  • 采样过程分析

下面介绍具体的实现过程:

一、环境配置&python包的导入
     

最好有GPU环境,比如公司的GPU集群或者Google Colab,下面是代码实现:

# 安装diffusers库!pip install -q diffusers# 导入所需要的包import torchimport torchvisionfrom torch import nnfrom torch.nn import functional as Ffrom torch.utils.data import DataLoaderfrom diffusers import DDPMScheduler, UNet2DModelfrom matplotlib import pyplot as pltdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f'Using device: {device}')
# 输出Using device: cuda

此时会输出运行环境是GPU还是CPU

载MNIST数据集

       MNIST数据集是一个小数据集,存储的是0-9手写数字字体,每张图片都28X28的灰度图片,每个像素的取值范围是[0,1],下面加载该数据集,并展示部分数据:

dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor())train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)x, y = next(iter(train_dataloader))print('Input shape:', x.shape)print('Labels:', y)plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');
# 输出Input shape: torch.Size([8, 1, 28, 28])Labels: tensor([7, 8, 4, 2, 3, 6, 0, 2])

扩散模型的退化过程

       所谓退化过程,其实就是对输入数据加入噪声的过程,由于MNIST数据集的像素范围在[0,1],那么我们加入噪声也需要保持在相同的范围,这样我们可以很容易的把输入数据与噪声进行混合,代码如下:

def corrupt(x, amount):  """Corrupt the input `x` by mixing it with noise according to `amount`"""  noise = torch.rand_like(x)  amount = amount.view(-1, 1, 1, 1) # Sort shape so broadcasting works  return x*(1-amount) + noise*amount

接下来,我们看一下逐步加噪的效果,代码如下:

# Plotting the input datafig, axs = plt.subplots(2, 1, figsize=(12, 5))axs[0].set_title('Input data')axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')# Adding noiseamount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruptionnoised_x = corrupt(x, amount)# Plottinf the noised versionaxs[1].set_title('Corrupted data (-- amount increases -->)')axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys');

       从上图可以看出,从左到右加入的噪声逐步增多,当噪声量接近1时,数据看起来像纯粹的随机噪声。

构建一个简单的UNet模型

       UNet模型与自编码器有异曲同工之妙,UNet最初是用于完成医学图像中分割任务的,网络结构如下所示:

代码如下:

class BasicUNet(nn.Module):    """A minimal UNet implementation."""    def __init__(self, in_channels=1, out_channels=1):        super().__init__()        self.down_layers = torch.nn.ModuleList([             nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),            nn.Conv2d(32, 64, kernel_size=5, padding=2),            nn.Conv2d(64, 64, kernel_size=5, padding=2),        ])        self.up_layers = torch.nn.ModuleList([            nn.Conv2d(64, 64, kernel_size=5, padding=2),            nn.Conv2d(64, 32, kernel_size=5, padding=2),            nn.Conv2d(32, out_channels, kernel_size=5, padding=2),         ])        self.act = nn.SiLU() # The activation function        self.downscale = nn.MaxPool2d(2)        self.upscale = nn.Upsample(scale_factor=2)    def forward(self, x):        h = []        for i, l in enumerate(self.down_layers):            x = self.act(l(x)) # Through the layer and the activation function            if i < 2: # For all but the third (final) down layer:              h.append(x) # Storing output for skip connection              x = self.downscale(x) # Downscale ready for the next layer                      for i, l in enumerate(self.up_layers):            if i > 0: # For all except the first up layer              x = self.upscale(x) # Upscale              x += h.pop() # Fetching stored output (skip connection)            x = self.act(l(x)) # Through the layer and the activation function                    return x

我们来检验一下模型输入输出的shape变化是否符合预期,代码如下:

net = BasicUNet()x = torch.rand(8, 1, 28, 28)net(x).shape
# 输出torch.Size([8, 1, 28, 28])

再来看一下模型的参数量,代码如下:

sum([p.numel() for p in net.parameters()])
# 输出309057

至此,已经完成数据加载和UNet模型构建,当然UNet模型的结构可以有不同的设计。

、扩散模型训练

        扩散模型应该学习什么?其实有很多不同的目标,比如学习噪声,我们先以一个简单的例子开始,输入数据为带噪声的MNIST数据,扩散模型应该输出对应的最佳数字预测,因此学习的目标是预测值与真实值的MSE,训练代码如下:

# Dataloader (you can mess with batch size)batch_size = 128train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# How many runs through the data should we do?n_epochs = 3# Create the networknet = BasicUNet()net.to(device)# Our loss finctionloss_fn = nn.MSELoss()# The optimizeropt = torch.optim.Adam(net.parameters(), lr=1e-3) # Keeping a record of the losses for later viewinglosses = []# The training loopfor epoch in range(n_epochs):    for x, y in train_dataloader:        # Get some data and prepare the corrupted version        x = x.to(device) # Data on the GPU        noise_amount = torch.rand(x.shape[0]).to(device) # Pick random noise amounts        noisy_x = corrupt(x, noise_amount) # Create our noisy x        # Get the model prediction        pred = net(noisy_x)        # Calculate the loss        loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?        # Backprop and update the params:        opt.zero_grad()        loss.backward()        opt.step()        # Store the loss for later        losses.append(loss.item())    # Print our the average of the loss values for this epoch:    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')# View the loss curveplt.plot(losses)plt.ylim(0, 0.1);
# 输出Finished epoch 0. Average loss for this epoch: 0.024689Finished epoch 1. Average loss for this epoch: 0.019226Finished epoch 2. Average loss for this epoch: 0.017939

训练过程的loss曲线如下图所示:

六、扩散模型效果评估

我们选取一部分数据来评估一下模型的预测效果,代码如下:

#@markdown Visualizing model predictions on noisy inputs:# Fetch some datax, y = next(iter(train_dataloader))x = x[:8] # Only using the first 8 for easy plotting# Corrupt with a range of amountsamount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruptionnoised_x = corrupt(x, amount)# Get the model predictionswith torch.no_grad():  preds = net(noised_x.to(device)).detach().cpu()# Plotfig, axs = plt.subplots(3, 1, figsize=(12, 7))axs[0].set_title('Input data')axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys')axs[1].set_title('Corrupted data')axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='Greys')axs[2].set_title('Network Predictions')axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap='Greys');

从上图可以看出,对于噪声量较低的输入,模型的预测效果是很不错的,当amount=1时,模型的输出接近整个数据集的均值,这正是扩散模型的工作原理。

Note:我们的训练并不太充分,读者可以尝试不同的超参数来优化模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/98886.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于Mysql+Vue+Django的协同过滤和内容推荐算法的智能音乐推荐系统——深度学习算法应用(含全部工程源码)+数据集

目录 前言总体设计系统整体结构图系统流程图 运行环境Python 环境MySQL环境VUE环境 模块实现1. 数据请求和储存2. 数据处理计算歌曲、歌手、用户相似度计算用户推荐集 3. 数据存储与后台4. 数据展示 系统测试工程源代码下载其它资料下载 前言 本项目以丰富的网易云音乐数据为基…

一文彻底理解时间复杂度和空间复杂度(附实例)

目录 1 PNP&#xff1f;2 时间复杂度2.1 常数阶复杂度2.2 对数阶复杂度2.3 线性阶复杂度2.4 平方阶复杂度2.5 指数阶复杂度2.6 总结 3 空间复杂度 1 PNP&#xff1f; P类问题(Polynomial)指在多项式时间内能求解的问题&#xff1b;NP类问题(Non-Deterministic Polynomial)指在…

深入理解分布式架构,构建高效可靠系统的关键

深入探讨分布式架构的核心概念、优势、挑战以及构建过程中的关键考虑因素。 引言什么是分布式架构&#xff1f;分布式架构的重要性 分布式系统的核心概念节点和通信数据分区与复制一致性与一致性模型负载均衡与容错性 常见的分布式架构模式客户端-服务器架构微服务架构事件驱动…

python从入门到精通——完整教程

阅读全文点击《python从入门到精通——完整教程》 一、编程入门与进阶提高 Python编程入门 1、Python环境搭建&#xff08; 下载、安装与版本选择&#xff09;。 2、如何选择Python编辑器&#xff1f;&#xff08;IDLE、Notepad、PyCharm、Jupyter…&#xff09; 3、Pytho…

JetBrains IDE远程开发功能可供GitHub用户使用

JetBrains与GitHub去年已达成合作&#xff0c;提供GitHub Codespaces 与 JetBrains Gateway 之间的集成。 GitHub Codespaces允许用户创建安全、可配置、专属的云端开发环境&#xff0c;此集成意味着您可以通过JetBrains Gateway使用在 GitHub Codespaces 中运行喜欢的IDE进行…

JavaWeb-Listener监听器

目录 监听器Listener 1.功能 2.监听器分类 3.监听器的配置 4.ServletContext监听 5.HttpSession监听 6.ServletRequest监听 监听器Listener 1.功能 用于监听域对象ServletContext、HttpSession和ServletRequest的创建&#xff0c;与销毁事件监听一个对象的事件&#x…

面试热题(不同的二分搜索树)

给你一个整数 n &#xff0c;求恰由 n 个节点组成且节点值从 1 到 n 互不相同的 二叉搜索树 有多少种&#xff1f;返回满足题意的二叉搜索树的种数。 经典的面试题&#xff0c;这部分涉及了组合数学中的卡特兰数&#xff0c;如果对其不清楚的同学可以去看我以前的博客卡特兰数 …

安防监控/视频集中存储/云存储平台EasyCVR v3.3增加首页告警类型

安防监控/视频集中存储/云存储EasyCVR视频汇聚平台&#xff0c;可支持海量视频的轻量化接入与汇聚管理。平台能提供视频存储磁盘阵列、视频监控直播、视频轮播、视频录像、云存储、回放与检索、智能告警、服务器集群、语音对讲、云台控制、电子地图、平台级联、H.265自动转码等…

开学有哪些好用电容笔值得买?ipad触控笔推荐平价

因为有了Apple Pencil,使得iPad就成了一款便携的生产力配件&#xff0c;其优势在于&#xff0c;电容笔搭配上iPad可以让专业的绘画师在iPad上作画&#xff0c;而且还能画出各种粗细不一的线条&#xff0c;对于有书写需求的学生党来讲&#xff0c;还是很有帮助的。但本人不敢想像…

基于CNN卷积神经网络的口罩检测识别系统matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 matlab2022a 3.部分核心程序 ............................................................ % 循环处理每张输入图像 for…

PHP基础

PHP&#xff08;外文名:PHP:Hypertext Preprocessor&#xff0c;中文名&#xff1a;“超文本预处理器”&#xff09;是一种免费开源的、创建动态交互性站点的强有力的服务器端脚本语言 <h1>My Name is LiSi!</h1> <script>console.log("This message is…

星际争霸之小霸王之小蜜蜂(四)--事件监听-让小蜜蜂动起来

目录 前言 一、监听按键并作出判断 二、持续移动 三、左右移动 总结&#xff1a; 前言 今天开始正式操控我们的小蜜蜂了&#xff0c;之前学java的时候是有一个函数监听鼠标和键盘的操作&#xff0c;我们通过传过来不同的值进行判断&#xff0c;现在来看看python是否一样的实现…

深度学习最强奠基作ResNet《Deep Residual Learning for Image Recognition》论文解读(上篇)

1、摘要 1.1 第一段 作者说深度神经网络是非常难以训练的&#xff0c;我们使用了一个残差学习框架的网络来使得训练非常深的网络比之前容易得很多。 把层作为一个残差学习函数相对于层输入的一个方法&#xff0c;而不是说跟之前一样的学习unreferenced functions 作者提供了…

SRM系统询价竞价管理:优化采购流程的全面解析

SRM系统的询价竞价管理模块是现代企业采购管理中的重要工具。通过该模块&#xff0c;企业可以实现供应商的询价、竞价和合同管理等关键环节的自动化和优化。 一、概述 SRM系统是一种用于管理和优化供应商关系的软件系统。它通过集成各个环节&#xff0c;包括供应商信息管理、询…

算法leetcode|72. 编辑距离(rust重拳出击)

文章目录 72. 编辑距离&#xff1a;样例 1&#xff1a;样例 2&#xff1a;提示&#xff1a; 分析&#xff1a;题解&#xff1a;rust&#xff1a;二维数组&#xff08;易懂&#xff09;滚动数组&#xff08;更加优化的内存空间&#xff09; go&#xff1a;c&#xff1a;python&a…

MySQL 数据库存储引擎

一、存储引擎概念 数据库存储引擎是数据库底层软件组件&#xff0c;数据库管理系统--DBMS使用数据引擎进行创建、查询、更新和删除数据操作。不同得存储引擎提供不同得存储机制、索引技巧、锁定水平等功能&#xff0c;使用不同得存储引擎&#xff0c;还可以获得特定的功能。现…

快解析Linux搭建FTP服务器:轻松实现文件传输

在Linux操作系统中&#xff0c;搭建FTP服务器是一种常见且重要的操作。快解析提供了便捷的解决方案&#xff0c;帮助用户快速搭建FTP服务器&#xff0c;实现高效的文件传输和共享。本文将介绍Linux搭建FTP服务器的定义、作用以及其独特的优势&#xff0c;助您了解并利用这一强大…

A - Bone Collector(01背包)

Many years ago , in Teddy’s hometown there was a man who was called “Bone Collector”. This man like to collect varies of bones , such as dog’s , cow’s , also he went to the grave … The bone collector had a big bag with a volume of V ,and along his tr…

超越函数界限:探索JavaScript函数的无限可能

&#x1f3ac; 岸边的风&#xff1a;个人主页 &#x1f525; 个人专栏 :《 VUE 》 《 javaScript 》 ⛺️ 生活的理想&#xff0c;就是为了理想的生活 ! 目录 &#x1f4da; 前言 &#x1f4d8; 1. 函数的基本概念 &#x1f4df; 1.1 函数的定义和调用 &#x1f4df; 1.2 …