从头开始制作扩散模型(实现快速扩散模型的简单方法)

一、说明

        本文是关于自己从头开始构建扩散模型的教程。我总是喜欢让事情变得简单易行,所以在这里,我们避免了复杂的数学。这不是一个正常的扩散模型。相反,我称之为快速扩散模型。将仅使用卷积神经网络(CNN)来制作扩散模型。在本文中,我不会为您提供任何现有的模型/权重/脚本文件。


您需要自己训练模型。
(我们正在使用TensorFlow提供的CIFAR-10数据集。

你可以在我的 GitHub
https://github.com/Seachaos/Tree.Rocks/blob/main/QuickDiffusionModel/QuickDiffusionModel.ipynb中找到代码

二、这个想法

        这就是扩散模型的工作原理:它就像基于一个完全嘈杂的图像,并逐渐提高图像质量,直到它变得清晰。
(如下图所示)

扩散模型示例改善了图像

        因此,我们可以创建一个深度学习模型,可以提高图像质量(从全噪声到清晰的图像),流程思想:

快速扩散模型流程

        为了更清晰地了解,请查看此附加流程图。

图像在扩散模型中的流动方式

        如上图所示,该模型正在尝试生成噪声逐渐减少的图像。现在,我们只需要训练一个深度学习模型来学习如何减少噪音。
        对于该任务,我们需要模型中的两个输入:

  • 输入图像 — 需要处理噪声图像
  • 时间戳 — 告诉模型什么是噪声状态,以便更容易学习

三、实现快速扩散模型

        首先,让我们导入我们需要的内容:

import numpy as npfrom tqdm.auto import trange, tqdm
import matplotlib.pyplot as pltimport tensorflow as tf
from tensorflow.keras import layers

        并准备我们的数据集, 在本教程中,我们将使用大量汽车图像(CIFAR-10)作为示例,以使事情尽可能简单快捷。
(但是,如果您有足够的样本,则可以选择您喜欢的任何图像。

(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train[y_train.squeeze() == 1]
X_train = (X_train / 127.5) - 1.0

        接下来,让我们定义变量。

IMG_SIZE = 32     # input image size, CIFAR-10 is 32x32
BATCH_SIZE = 128  # for training batch size
timesteps = 16    # how many steps for a noisy image into clear
time_bar = 1 - np.linspace(0, 1.0, timesteps + 1) # linspace for timesteps

        在这里,我们设置“时间步长”,这意味着我们的模型将学习通过训练过程生成从嘈杂(级别 0)到清晰(级别 16)的图像。

让我们看一张图片以获得更清晰的想法

plt.plot(time_bar, label='Noise')
plt.plot(1 - time_bar, label='Clarity')
plt.legend()
图像噪点和清晰度随时间步长的变化

        如您所见,从时间步长 0 到 16,噪音减少,清晰度逐渐提高。这就是我们希望我们的模型学习的内容。

        并为预览数据准备一些功能

def cvtImg(img):img = img - img.min()img = (img / img.max())return img.astype(np.float32)def show_examples(x):plt.figure(figsize=(10, 10))for i in range(25):plt.subplot(5, 5, i+1)img = cvtImg(x[i])plt.imshow(img)plt.axis('off')show_examples(X_train)

CIFAR-10 汽车

3.1 培训准备

        在这里,我们需要准备训练图像的代码。

        这个想法是从随机时间点获得两个图像(A和B),其中A是噪声图像,B是更清晰的图像。
我们的模型将学习根据该特定时间点将A转换为B(从嘈杂到更清晰)。
(再次作为此图)

图像 A 在上面,图像 B 在下面

        因此,我们在这里forward_noise功能。

def forward_noise(x, t):a = time_bar[t]      # base on tb = time_bar[t + 1]  # image for t + 1noise = np.random.normal(size=x.shape)  # noise maska = a.reshape((-1, 1, 1, 1))b = b.reshape((-1, 1, 1, 1))img_a = x * (1 - a) + noise * aimg_b = x * (1 - b) + noise * breturn img_a, img_bdef generate_ts(num):return np.random.randint(0, timesteps, size=num)# t = np.full((25,), timesteps - 1) # if you want see clarity
# t = np.full((25,), 0)             # if you want see noisy
t = generate_ts(25)             # random for training data
a, b = forward_noise(X_train[:25], t)
show_examples(a)

        如果你想了解它是如何工作的,我建议运行我注释掉的代码。( t = ... )

预览训练数据示例

3.2 构建 CNN 块

        我们将使用 U-Net 作为我们的模型,详细信息将在下面的代码中解释。

        模型架构,详细内容会在后面的代码中讲解,在构建模型之前,我们需要先定义块。
        这是 make 块的代码:

def block(x_img, x_ts):x_parameter = layers.Conv2D(128, kernel_size=3, padding='same')(x_img)x_parameter = layers.Activation('relu')(x_parameter)time_parameter = layers.Dense(128)(x_ts)time_parameter = layers.Activation('relu')(time_parameter)time_parameter = layers.Reshape((1, 1, 128))(time_parameter)x_parameter = x_parameter * time_parameter# -----x_out = layers.Conv2D(128, kernel_size=3, padding='same')(x_img)x_out = x_out + x_parameterx_out = layers.LayerNormalization()(x_out)x_out = layers.Activation('relu')(x_out)return x_out

        每个块包含两个带有时间参数的卷积网络,允许网络确定其当前的时间步长并输出相应的信息。
        您可以看到块流程图:
                (x_img 是输入图像,是噪声图像,x_ts 是时间步长的输入)

块的流向

搭建模型,现在我们可以构建我们的模型

def make_model():x = x_input = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3), name='x_input')x_ts = x_ts_input = layers.Input(shape=(1,), name='x_ts_input')x_ts = layers.Dense(192)(x_ts)x_ts = layers.LayerNormalization()(x_ts)x_ts = layers.Activation('relu')(x_ts)# ----- left ( down ) -----x = x32 = block(x, x_ts)x = layers.MaxPool2D(2)(x)x = x16 = block(x, x_ts)x = layers.MaxPool2D(2)(x)x = x8 = block(x, x_ts)x = layers.MaxPool2D(2)(x)x = x4 = block(x, x_ts)# ----- MLP -----x = layers.Flatten()(x)x = layers.Concatenate()([x, x_ts])x = layers.Dense(128)(x)x = layers.LayerNormalization()(x)x = layers.Activation('relu')(x)x = layers.Dense(4 * 4 * 32)(x)x = layers.LayerNormalization()(x)x = layers.Activation('relu')(x)x = layers.Reshape((4, 4, 32))(x)# ----- right ( up ) -----x = layers.Concatenate()([x, x4])x = block(x, x_ts)x = layers.UpSampling2D(2)(x)x = layers.Concatenate()([x, x8])x = block(x, x_ts)x = layers.UpSampling2D(2)(x)x = layers.Concatenate()([x, x16])x = block(x, x_ts)x = layers.UpSampling2D(2)(x)x = layers.Concatenate()([x, x32])x = block(x, x_ts)# ----- output -----x = layers.Conv2D(3, kernel_size=1, padding='same')(x)model = tf.keras.models.Model([x_input, x_ts_input], x)return modelmodel = make_model()
# model.summary()

这是一个U-Net,左、右、MLP部分可以参考上图(模型架构)。

不要忘记编译模型

optimizer = tf.keras.optimizers.Adam(learning_rate=0.0008)
loss_func = tf.keras.losses.MeanAbsoluteError()
model.compile(loss=loss_func, optimizer=optimizer)

        我们使用 Adam 作为优化器,使用 MeanAbsoluteError (MAE) 作为损失函数。

        预测结果:我们现在可以尝试我们的第一个预测。预测步骤如下:

  1. 创建嘈杂的图像
  2. 以时间步长输入到我们的模型中
  3. 继续这样做直到时间步结束

        所以这是这个函数:

def predict(x_idx=None):x = np.random.normal(size=(32, IMG_SIZE, IMG_SIZE, 3))for i in trange(timesteps):t = ix = model.predict([x, np.full((32), t)], verbose=0)show_examples(x)predict()

        未经训练的模型输出图像 上面是我们的未经训练的模型输出,如您所见,它没有任何用处。 这个函数还可以帮助我们查看每个步骤:

def predict_step():xs = []x = np.random.normal(size=(8, IMG_SIZE, IMG_SIZE, 3))for i in trange(timesteps):t = ix = model.predict([x, np.full((8),  t)], verbose=0)if i % 2 == 0:xs.append(x[0])plt.figure(figsize=(20, 2))for i in range(len(xs)):plt.subplot(1, len(xs), i+1)plt.imshow(cvtImg(xs[i]))plt.title(f'{i}')plt.axis('off')predict_step()
未经训练的模型输出步骤

四、训练模型

        这个训练功能很简单

def train_one(x_img):x_ts = generate_ts(len(x_img))x_a, x_b = forward_noise(x_img, x_ts)loss = model.train_on_batch([x_a, x_ts], x_b)return loss

        我们只需要提供x_tsx_img(x_a),使我们的模型能够学习如何生成x_b。

        并使其成为纪元函数

def train(R=50):bar = trange(R)total = 100for i in bar:for j in range(total):x_img = X_train[np.random.randint(len(X_train), size=BATCH_SIZE)]loss = train_one(x_img)pg = (j / total) * 100if j % 5 == 0:bar.set_description(f'loss: {loss:.5f}, p: {pg:.2f}%')

        最后,多次运行并逐渐降低学习率

for _ in range(10):train()# reduce learning rate for next trainingmodel.optimizer.learning_rate = max(0.000001, model.optimizer.learning_rate * 0.9)# show result predict()predict_step()plt.show()

        你可以得到一些这样的输出图像

快速扩散模型输出示例

五、结论

        本教程设计简单,允许您进行实验。您可以尝试自己的参数(如更改图像大小,CNN过滤器,时间步长或MLP等)和更多的时期训练以获得更好的结果。海沌

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/133369.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

API(九)基于协程的并发编程SDK

一 基于协程的并发编程SDK 场景: 收到一个请求会并发发起多个请求,使用openresty提供的协程说明: 这个是高级课程,如果不理解可以先跳过遗留: APSIX和Kong深入理解openresty 标准lua的协程 ① 早期提供的轻量级协程SDK ngx.thread ngx…

数据结构——【堆】

一、堆的相关概念 1.1、堆的概念 1、堆在逻辑上是一颗完全二叉树(类似于一颗满二叉树只缺了右下角)。 2、堆的实现利用的是数组,我们通常会利用动态数组来存放元素,这样可以快速拓容也不会很浪费空间,我们是将这颗完…

【C++】构造函数调用规则 ( 默认构造函数 | 默认无参构造函数 | 默认拷贝构造函数 | 构造函数调用规则说明 )

文章目录 一、默认构造函数1、默认无参构造函数2、默认拷贝构造函数 二、构造函数调用规则1、构造函数规则说明2、代码示例 - 只定义拷贝构造函数3、代码示例 - 只定义有参构造函数 一、默认构造函数 C 类中 2 种特殊的构造函数 , 分别是 : 默认无参构造函数 : 如果 C 类中 没…

PyTorch实现注意力机制及使用方法汇总,附30篇attention论文

还记得鼎鼎大名的《Attention is All You Need》吗?不过我们今天要聊的重点不是transformer,而是注意力机制。 注意力机制最早应用于计算机视觉领域,后来也逐渐在NLP领域广泛应用,它克服了传统的神经网络的的一些局限&#xff0c…

sqlserver存储过程报错:当前事务无法提交,而且无法支持写入日志文件的操作。请回滚该事务。

现象: 系统出现异常,手动执行过程提示如上。 问题排查: 1.直接执行的过程事务挂起(排除) 2.重启数据库实例(重启后无效) 3.过程中套用过程,套用的过程中使用事务,因为…

STM32-HAL库06-硬件IIC驱动FM24CL16B非易失存储器

STM32-HAL库06-IIC驱动FM24CL16B非易失存储器 一、所用材料: STM32VGT6自制控制板 STM32CUBEMX(HAL库软件) MDK5 二、所学内容: 通过HAL库的硬件IIC对FM24CL16B存储器进行写与读取操作。 三、CUBEMX配置: 第一步…

Virtualbox中Ubuntu根目录空间不足

现象 Virtualbox中Ubuntu根目录空间不足 解决 动态存储 虚拟机关闭先在虚拟介质管理里把硬盘Size调大开启Ubuntu用Disks或者GParted重新调整分区大小重新启动 步骤参考: https://zhuanlan.zhihu.com/p/319431032 https://blog.csdn.net/ningmengzhihe/article/details/1272…

Java 毕业设计-基于SpringBoot的在线文档管理系统

基于SpringBoot的在线文档管理系统 博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝30W,Csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 技术栈简介 文末获取源码 开发语言:Java 框架:sp…

rocketmq

🍓代码仓库 https://gitee.com/xuhx615/rocket-mqdemo.git 🍓基本概念 ⭐生产者(Producer):消息发布者⭐主题(Topic):topic用于标识同一类业务类型的消息⭐消息队列(MessageQueue&#xff09…

VirtualBox宿主机和虚拟机文件互传设置

一、如图1、2、3步骤,设置共享粘贴板和拖放为双向 二、 在启动的虚拟机设置的里面,安装增强插件,然后重启虚拟机。 三、在网络位置就可以看到了

Java基于SpringBoot的闲一品交易平台

博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝30W,Csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 大家好,我是程序员徐师兄、今天给大家谈谈基于android的app开发毕设题目,以及基于an…

学习Bootstrap 5的第八天

目录 加载器 彩色加载器 实例 闪烁加载器 实例 加载器大小 实例 加载器按钮 实例 分页 分页的基本结构 实例 活动状态 实例 禁用状态 实例 分页大小 实例 分页对齐 实例 面包屑(Breadcrumbs) 实例 加载器 彩色加载器 在 Bootstr…

竞赛 基于情感分析的网络舆情热点分析系统

文章目录 0 前言1 课题背景2 数据处理3 文本情感分析3.1 情感分析-词库搭建3.2 文本情感分析实现3.3 建立情感倾向性分析模型 4 数据可视化工具4.1 django框架介绍4.2 ECharts 5 Django使用echarts进行可视化展示5.1 修改setting.py连接mysql数据库5.2 导入数据5.3 使用echarts…

GeoSOS-FLUS未来土地利用变化情景模拟模型

软件简介 适用场景 GeoSOS-FLUS软件能较好的应用于土地利用变化模拟与未来土地利用情景 的预测和分析中,是进行地理空间模拟、参与空间优化、辅助决策制定的有效工 具。FLUS 模型可直接用于: 城市发展模拟及城市增长边界划定;城市内 部高分…

Debian 12快速安装图解

文章目录 Debian 12安装图解创建虚拟机安装系统登录并用光盘离线安装sudo、curl解决Linux下sudo更改文件权限报错保存快照debain添加在线源(配置清华源)参考 Debian 12安装图解 Debian选择CD安装非常慢,本次安装选择DVD离线安装。 下载 https://www.debian.org/CD…

Swift如何使用Vision来识别获取图片中的文字(OCR),通过SwiftUI视图和终端命令行,以及一系列注意事项

在过去的一年里,我发现苹果系统中的“文字搜图片”功能非常好用,这个功能不光 iPhone/iPad,Mac 也有,找一些图片真的很好用。但是遇到了一个问题:这个功能需要一段时间才能找到新的图片,而且没法手动刷新&a…

从一到无穷大 #15 Gorilla,论黄金26H与时序数据库缓存系统的可行性

本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。 本作品 (李兆龙 博文, 由 李兆龙 创作),由 李兆龙 确认,转载请注明版权。 引言 缓存系统的高效存在前提,在满足前提的情况下可以接受缺陷便没有理由不引入缓…

pdf添加水印

给pdf文件添加水印 引入依赖 <dependency><groupId>com.itextpdf</groupId><artifactId>itextpdf</artifactId><version>5.5.13.3</version></dependency>添加水印 package com.it2.pdfdemo02.util;import com.itextpdf.tex…

Qt应用程序连接达梦数据库-飞腾PC麒麟V10

目录 前言1 安装ODBC1.1 下载unixODBC源码1.2 编译安装1.4 测试 2 编译QODBC2.1 修改 qsqldriverbase.pri 文件2.2 修改 odbc.pro 文件2.3 编译并安装QODBC 3 Qt应用程序连接达梦数据库测试4 优化ODBC配置&#xff0c;方便程序部署4.1 修改pro文件&#xff0c;增加DESTDIR 变量…

高可用Kuberbetes部署Prometheus + Grafana

概述 阅读官方文档部署部署Prometheus Grafana GitHub - prometheus-operator/kube-prometheus at release-0.10 环境 步骤 下周官方github仓库 git clone https://github.com/prometheus-operator/kube-prometheus.git git checkout release-0.10 进入工作目录 cd kube…