AI绘画训练一个扩散模型-上集

介绍

AI绘画,其中最常见方案基于扩散模型,Stable Diffusion 在此基础上,增加了 VAE 模块和 CLIP 模块,本文搞了一个测试Demo,分为上下两集,第一集是denoising_diffusion_pytorch ,第二集是diffusers。
对于专业的算法同学而言,我更推荐使用 diffusers 来训练。原因是 diffusers 工具包在实际的 AI 绘画项目中用得更多,并且也更易于我们修改代码逻辑,实现定制化功能。
https://arxiv.org/abs/2112.10752

基础模块

  • 创建UNet模型和高斯扩散模型(Gaussian Diffusion)。

UNet是一个编码器-解码器结构的全卷积神经网络。Gaussian Diffusion用于定义噪声过程和损失函数。

  • 将模型加载到GPU上(如果有GPU)。

  • 使用随机初始化的图片进行一次训练,计算损失并反向传播。

这一步的目的是对模型进行一次预热,更新权重。

  • 使用diffusion模型采样生成图片。

这里采样1000步,也就是将噪声逐步减少,每步用UNet预测下一步的图像,最终输出生成的图片。

  • 如果图片在GPU上,将其移回到CPU。

  • 可视化第一张生成图片。

plt.imshow显示图片。

这样通过DDPM框架,可以从随机噪声生成符合数据分布的新图片。每次训练会使模型逐步逼近真实数据分布,从而产生更高质量的图片。

# 创建UNet和扩散模型from denoising_diffusion_pytorch import Unet, GaussianDiffusion
import torchmodel = Unet(dim = 64,dim_mults = (1, 2, 4, 8)
).cuda()diffusion = GaussianDiffusion(model,image_size = 128,timesteps = 1000   # number of steps
).cuda()# 使用随机初始化的图片进行一次训练
training_images = torch.randn(8, 3, 128, 128)
loss = diffusion(training_images.cuda())
loss.backward()# 采样1000步生成一张图片
sampled_images = diffusion.sample(batch_size = 4)
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import torchvision.transforms as transforms# 如果张量在 GPU上,需要移动到 CPU上
if sampled_images.is_cuda:sampled_images = sampled_images.cpu()# 检查我们生成的一张图
img = sampled_images[0].clone().detach().permute(1, 2, 0)plt.imshow(img)

数据集

  • 导入所需的库:PIL、io、datasets等。

  • 使用datasets库中的load_dataset方法加载Oxford Flowers数据集。

  • 创建一个目录来保存图片。

  • 遍历数据集的训练、验证、测试split,逐个图像获取图片bytes数据,并保存为PNG格式图片。

  • 使用PIL库的Image对象将bytes数据加载并保存为图片文件。

  • 使用tqdm显示循环进度。

# 数据集下载
from PIL import Image
from io import BytesIO
from datasets import load_dataset
import os
from tqdm import tqdmdataset = load_dataset("nelorth/oxford-flowers")# 创建一个用于保存图片的文件夹
images_dir = "./oxford-datasets/raw-images"
os.makedirs(images_dir, exist_ok=True)# 遍历所有图片并保存,针对oxford-flowers,整个过程要持续15分钟左右
for split in dataset.keys():for index, item in enumerate(tqdm(dataset[split])):image = item['image']image.save(os.path.join(images_dir, f"{split}_image_{index}.png"))

模型训练

  • 定义Unet模型架构和Gaussian Diffusion过程。

  • 加载数据,指定训练参数:

    • 训练总步数20000
    • batch size 16
    • 学习率2e-5
    • 梯度累积步数2
    • EMA指数衰减参数0.995
    • 使用混合精度训练
    • 每2000步保存一次模型
    • 创建Trainer进行模型训练。Trainer封装了训练循环逻辑。
  • 调用trainer.train()进行训练。

# 模型训练
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainermodel = Unet(dim = 64,dim_mults = (1, 2, 4, 8)
).cuda()diffusion = GaussianDiffusion(model,image_size = 128,timesteps = 1000   # 加噪总步数
).cuda()trainer = Trainer(diffusion,'./oxford-datasets/raw-images',train_batch_size = 16,train_lr = 2e-5,train_num_steps = 20000,          # 总共训练20000步gradient_accumulate_every = 2,    # 梯度累积步数ema_decay = 0.995,                # 指数滑动平均decay参数amp = True,                       # 使用混合精度训练加速calculate_fid = False,            # 我们关闭FID评测指标计算(比较耗时)。FID用于评测生成质量。save_and_sample_every = 2000      # 每隔2000步保存一次模型
)trainer.train()
# 你可以等待上面的模型训练完成后,查看生成结果from glob import globresult_images = glob(r"./results/*.png")
print(result_images)
# 可视化图像看看
from PIL import Imageimg = Image.open("./results/sample-1.png")
img

测试

https://colab.research.google.com/github/NightWalker888/ai_painting_journey/blob/main/lesson12/train_diffusion_v2.ipynb#scrollTo=8BVjfBPI93Ar

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/224033.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据库开发之图形化工具以及表操作的详细解析

2.3 图形化工具 2.3.1 介绍 前面我们讲解了DDL中关于数据库操作的SQL语句,在我们编写这些SQL时,都是在命令行当中完成的。大家在练习的时候应该也感受到了,在命令行当中来敲这些SQL语句很不方便,主要的原因有以下 3 点&#xff…

截断整型提升算数转换

文章目录 🚀前言🚀截断🚀整型提升✈️整型提升是怎样的 🚀算术转换 🚀前言 大家好啊!这里阿辉补一下前面操作符遗漏的地方——截断、整型提升和算数转换 看这一篇要先会前面阿辉讲的数据的存储否则可能看不…

Dijkstra(迪杰斯特拉)算法总结

知识概览 Dijkstra算法适用于解决所有边权都是正数的最短路问题。Dijkstra算法分为朴素的Dijkstra算法和堆优化版的Dijkstra算法。朴素的Dijkstra算法时间复杂度为,适用于稠密图。堆优化版的Dijkstra算法时间复杂度为,适用于稀疏图。稠密图的边数m和是一…

React学习计划-React16--React基础(五)脚手架创建项目、todoList案例、配置代理、消息订阅与发布

一、使用脚手架create-react-app创建项目 react脚手架 xxx脚手架:用来帮助程序员快速创建一个基于xxx库的模板项目 包含了所有需要的配置(语法检查、jsx编译、devServe…)下载好了所有相关的依赖可以直接运行一个简单的效果 react提供了一个…

产品设计 之 创建完美产品需求文档的4个核心要点

客户描述他们想要的产品和最终交付的产品之间的误解一般很大,设计者和客户的角度不同,理解的程度也不同,就需要一个统一的交流中介。这里包含PRD。 为了说明理解误差的问题。下面这张有趣的图画可以精准阐述。 第一张图片展示了客户所描述…

Matlab仿真OOK、2FSK、2PSK、QPSK、4QAM在加性高斯白噪声信道中的误码率与归一化信噪比的关系

本文为学习所用,严禁转载。 本文参考链接 https://zhuanlan.zhihu.com/p/667382398 QPSK代码及高斯白噪声如何产生 https://ww2.mathworks.cn/help/signal/ref/butter.html 滤波器 https://www.python100.com/html/4LEF79KQK398.html 低通滤波器 本实验使用matlab仿…

【linux提权】利用setuid进行简单提权

首先先来了解一下setuid漏洞: SUID (Set UID)是Linux中的一种特殊权限,其功能为用户运行某个程序时,如果该程序有SUID权限,那么程序运行为进程时,进程的属主不是发起者,而是程序文件所属的属主。但是SUID权限的设置只…

「微服务模式」七种微服务反模式

什么是微服务 流行语经常为进化的概念提供背景,并且需要一个良好的“标签”来促进对话。微服务是一个新的“标签”,它定义了我个人一直在发现和使用的领域。文章和会议描述了一些事情,我慢慢意识到,过去几年我一直在发展自己的个人…

2023航天推进理论基础考试划重点(W老师)-液体火箭发动机1

适用于期末周求生欲满满的西北工业大学学生。 1、液体火箭发动机的基本组成及功能是什么? 推力室组件、推进剂供应系统、阀门与调节器、发动机总装元件等组成。 2、液体火箭发动机的分类和应用是什么?3、液体火箭发动机系统、分系统的概念是什么&…

交友系统设计:哪种地理空间邻近算法更快?

小熊学Java:https://javaxiaobear.cn 交友与婚恋是人们最基本的需求之一。随着互联网时代的不断发展,移动社交软件已经成为了人们生活中必不可少的一部分。然而,熟人社交并不能完全满足年轻人的社交与情感需求,于是陌生人交友平台…

vue3(六)-基础入门之自定义组件与插槽、ref通信

一、全局组件 html: <div id"app"><mytemplace></mytemplace> </div>javascript: <script>const { createApp } Vueconst app createApp({})app.component(mytemplace, {template: <div><button>返回</button>…

RPC 实战与原理

文章目录 什么是 RPC&#xff1f;RPC 有什么作用&#xff1f;RPC 步骤为什么需要序列化&#xff1f;零拷贝什么是零拷贝&#xff1f;为什么需要零拷贝&#xff1f;如何实现零拷贝&#xff1f;Netty 的零拷贝有何不同&#xff1f; 动态代理实现HTTP/2 特性为什么需要服务发现&am…

7. 结构型模式 - 代理模式

亦称&#xff1a; Proxy 意图 代理模式是一种结构型设计模式&#xff0c; 让你能够提供对象的替代品或其占位符。 代理控制着对于原对象的访问&#xff0c; 并允许在将请求提交给对象前后进行一些处理。 问题 为什么要控制对于某个对象的访问呢&#xff1f; 举个例子&#xff…

Python - 深夜数据结构与算法之 Graph

目录 一.引言 二.图的简介 1.Graph 图 2.Undirected graph 无向图 3.Directed Graph 有向图 4.DFS / BFS 遍历 三.经典算法实战 1.Num-Islands [200] 2.Land-Perimeter [463] 3.Largest-Island [827] 四.总结 一.引言 Graph 无论是应用还是算法题目在日常生活中比较…

方舟开发框架(ArkUI)概述

目录 1、基本概念 2、两种开发范式 3、开发框架的特性 4、UI开发&#xff08;ArkTS声明式开发范式&#xff09;概述 4.1、特点 4.2、整体架构 4.3、开发流程 方舟开发框架&#xff08;简称ArkUI&#xff09;为HarmonyOS应用的UI开发提供了完整的基础设施&#xff0c;包…

代驾系统开发:驶向未来的智能交通服务

随着科技的迅速发展&#xff0c;代驾系统的开发成为改善出行体验和提升交通服务智能化的重要一环。本文将聚焦于代驾系统开发的技术创新&#xff0c;为读者呈现其中涉及的一些令人振奋的技术代码。 1. 区块链技术的运用&#xff1a; 区块链技术被引入代驾系统&#xff0c;可…

智能优化算法应用:基于广义正态分布算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于广义正态分布算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于广义正态分布算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.广义正态分布算法4.实验参数设定…

轻量Http客户端工具VSCode和IDEA

文章目录 前言Visual Studio Code 的插件 REST Client编写第一个案例进阶&#xff0c;设置变量进阶&#xff0c;设置Token IntelliJ IDEA 的 HTTP请求构建http脚本HTTP的环境配置结果值暂存 前言 作为一个WEB工程师&#xff0c;在日常的使用过程中&#xff0c;HTTP请求是必不可…

SLAM算法与工程实践——SLAM基本库的安装与使用(6):g2o优化库(4)构建g2o的边

SLAM算法与工程实践系列文章 下面是SLAM算法与工程实践系列文章的总链接&#xff0c;本人发表这个系列的文章链接均收录于此 SLAM算法与工程实践系列文章链接 下面是专栏地址&#xff1a; SLAM算法与工程实践系列专栏 文章目录 SLAM算法与工程实践系列文章SLAM算法与工程实践…

在MacOS上Qt配置OpenCV并进行测试

目录 一.Qt环境准备 二.在Qt项目中加载Opencv库并编写代码测试 1.使用Opencv加载图片 &#xff08;1&#xff09;在Qt中创建一个新项目 &#xff08;2&#xff09;在.pro文件中链接OpenCV库 &#xff08;3&#xff09;添加新资源文件 &#xff08;4&#xff09;在mainw…