【人工智能前沿弄潮】——生成式AI系列:Diffusers应用 (2) 训练扩散模型(无条件图像生成)

无条件图像生成是扩散模型的一种流行应用,它生成的图像看起来像用于训练的数据集中的图像。与文本或图像到图像模型不同,无条件图像生成不依赖于任何文本或图像。它只生成与其训练数据分布相似的图像。通常,通过在特定数据集上微调预训练模型可以获得最佳结果。

本教程主要来自huggingface官方教程,结合一些自己的修改,以支持训练本地数据集。我们首先依据官方教程,利用史密森尼蝴蝶数据集的子集上从头开始训练UNet2DModel,以生我们自己的的🦋蝴蝶🦋。最后因为我是搞遥感方向的(测绘小卡拉米),所以利用遥感数据进行训练尝试,遥感影像使用的是煤矿区的无人机遥感影像,主要就是裸地和枯草,有的还有一些因为煤矿开采导致的地裂缝。

1、Train配置

为方便起见,创建一个包含训练超参数的TrainingConfig类(请随意调整它们):

from dataclasses import dataclass@dataclass
class TrainingConfig:image_size = 128  # the generated image resolutiontrain_batch_size = 16eval_batch_size = 16  # how many images to sample during evaluationnum_epochs = 50gradient_accumulation_steps = 1learning_rate = 1e-4lr_warmup_steps = 500save_image_epochs = 10save_model_epochs = 30mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precisionoutput_dir = "ddpm-butterflies-128"  # the model name locally and on the HF Hubpush_to_hub = True  # whether to upload the saved model to the HF Hubhub_private_repo = Falseoverwrite_output_dir = True  # overwrite the old model when re-running the notebookseed = 0config = TrainingConfig()

2、加载数据集

对于在hug 仓库空开的数据集可以使用🤗 Datasets依赖库轻松加载,比如本次的Smithsonian Butterflies:

from datasets import load_datasetconfig.dataset_name = "huggan/smithsonian_butterflies_subset"
dataset = load_dataset(config.dataset_name, split="train")

对于本地数据请用一下代码进行加载(请根据自己情况进行修改):

from datasets import load_datasetdata_dir = "/home/diffusers/datasets/isprsdataset"
dataset = load_dataset('imagefolder', data_dir=data_dir, split='train')

🤗 Datasets使用图像功能自动解码图像数据并将其加载为PIL. Image,我们可以将其可视化:

import matplotlib.pyplot as pltfig, axs = plt.subplots(1, 4, figsize=(16, 4))
for i, image in enumerate(dataset[:4]["image"]):axs[i].imshow(image)axs[i].set_axis_off()
fig.show()

在这里插入图片描述


3、图像预处理

由于图像大小不同,所以需要先对其进行预处理,也就是常规的图像增强:

  • 调整大小将图像大小更改为配置文件中定义的图像大小—image_size
  • RandomHorizontalFlip通过随机镜像图像来增强数据集。
  • Normalize对于将像素值重新缩放到[-1,1]范围内很重要,这是模型所期望的。
from torchvision import transformspreprocess = transforms.Compose([transforms.Resize((config.image_size, config.image_size)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.5835, 0.5820, 0.5841], [0.1149, 0.1111, 0.1064]), # isprs# transforms.Normalize([0.5], [0.5]),]
)

这里使用的是Pytorch自带的数据增强接口,这里我推荐大家使用albumentations数据增强库。

使用🤗Datasetsset_transform方法在训练期间动态应用预处理函数:

def transform(examples):images = [preprocess(image.convert("RGB")) for image in examples["image"]]return {"images": images}dataset.set_transform(transform)

现在将数据集包装在DataLoader中进行训练:

import torch
python
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)   

4、创建UNet2DModel

🧨 Diffusers 中的预训练模型可以使用您想要的参数从它们的模型类轻松创建。例如,要创建UNet2DModel

from diffusers import UNet2DModelmodel = UNet2DModel(sample_size=config.image_size,  # the target image resolutionin_channels=3,  # the number of input channels, 3 for RGB imagesout_channels=3,  # the number of output channelslayers_per_block=2,  # how many ResNet layers to use per UNet blockblock_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet blockdown_block_types=("DownBlock2D",  # a regular ResNet downsampling block"DownBlock2D","DownBlock2D","DownBlock2D","AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention"DownBlock2D",),up_block_types=("UpBlock2D",  # a regular ResNet upsampling block"AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention"UpBlock2D","UpBlock2D","UpBlock2D","UpBlock2D",),
)

检查样本图像形状与模型输出形状是否匹配:

sample_image = dataset[0]["images"].unsqueeze(0)
print("Input shape:", sample_image.shape)print("Output shape:", model(sample_image, timestep=0).sample.shape)

接下来创建一个scheduler为图像添加一些噪点。


5、创建scheduler

根据您是使用模型进行训练还是推理,scheduler的行为会有所不同。在推理期间,scheduler从噪声中生成图像。在训练期间,scheduler从扩散过程中的特定点获取模型输出或样本,并根据噪声时间表和更新规则(比如我们本系列第一张所说的step)将噪声应用于图像。(我们可以看到,遥感影像生成的结果还行,已经能明显的看清楚地表和枯草,甚至能够出现可看清的地裂缝!)

让我们看看DDPMScheduler并使用add_noise方法向之前的sample_image添加一些随机噪声:

import torch
from PIL import Image
from diffusers import DDPMSchedulernoise_scheduler = DDPMScheduler(num_train_timesteps=1000)
noise = torch.randn(sample_image.shape)
timesteps = torch.LongTensor([50])
noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)Image.fromarray(((noisy_image.permute(0, 2, 3, 1) + 1.0) * 127.5).type(torch.uint8).numpy()[0])


在这里插入图片描述

模型的训练目标是预测添加到图像中的噪声。该步骤的损失可以通过以下方式计算,这里官方教程使用的是mse损失函数

import torch.nn.functional as Fnoise_pred = model(noisy_image, timesteps).sample
loss = F.mse_loss(noise_pred, noise)

6、训练模型

到目前为止,已经有了开始训练模型的大部分部分,剩下的就是把所有东西放在一起。 首先,您需要一个优化器和一个学习率调度器:

from diffusers.optimization import get_cosine_schedule_with_warmupoptimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
lr_scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer,num_warmup_steps=config.lr_warmup_steps,num_training_steps=(len(train_dataloader) * config.num_epochs),
)

然后,您需要一种评估模型的方法。对于评估,您可以使用DDPMPipeline生成一批示例图像并将其保存为网格格式(官方输出为格网,大家也可自行修改为单张保存):

from diffusers import DDPMPipeline
import math
import osdef make_grid(images, rows, cols):w, h = images[0].sizegrid = Image.new("RGB", size=(cols * w, rows * h))for i, image in enumerate(images):grid.paste(image, box=(i % cols * w, i // cols * h))return griddef evaluate(config, epoch, pipeline):# Sample some images from random noise (this is the backward diffusion process).# The default pipeline output type is `List[PIL.Image]`images = pipeline(batch_size=config.eval_batch_size,generator=torch.manual_seed(config.seed),).images# Make a grid out of the imagesimage_grid = make_grid(images, rows=2, cols=3)# Save the imagestest_dir = os.path.join(config.output_dir, "samples")os.makedirs(test_dir, exist_ok=True)image_grid.save(f"{test_dir}/{epoch + 1:04d}.png")

现在,可以使用🤗Accelerate将所有这些组件包装在一个训练循环中,以便于TensorBoard日志记录、梯度累积混合精度训练

def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):# Initialize accelerator and tensorboard loggingaccelerator = Accelerator(mixed_precision=config.mixed_precision,gradient_accumulation_steps=config.gradient_accumulation_steps,log_with="tensorboard",project_dir=os.path.join(config.output_dir, "logs"),)if accelerator.is_main_process:if config.push_to_hub:repo_name = get_full_repo_name(Path(config.output_dir).name)repo = Repository(config.output_dir, clone_from=repo_name)elif config.output_dir is not None:os.makedirs(config.output_dir, exist_ok=True)accelerator.init_trackers("train_example")# Prepare everything# There is no specific order to remember, you just need to unpack the# objects in the same order you gave them to the prepare method.model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(model, optimizer, train_dataloader, lr_scheduler)global_step = 0# Now you train the modelfor epoch in range(config.num_epochs):progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)progress_bar.set_description(f"Epoch {epoch + 1}")for step, batch in enumerate(train_dataloader):clean_images = batch["images"]# Sample noise to add to the imagesnoise = torch.randn(clean_images.shape).to(clean_images.device)bs = clean_images.shape[0]# Sample a random timestep for each imagetimesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device).long()# Add noise to the clean images according to the noise magnitude at each timestep# (this is the forward diffusion process)noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)with accelerator.accumulate(model):# Predict the noise residualnoise_pred = model(noisy_images, timesteps, return_dict=False)[0]loss = F.mse_loss(noise_pred, noise)accelerator.backward(loss)accelerator.clip_grad_norm_(model.parameters(), 1.0)optimizer.step()lr_scheduler.step()optimizer.zero_grad()progress_bar.update(1)logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}progress_bar.set_postfix(**logs)accelerator.log(logs, step=global_step)global_step += 1# After each epoch you optionally sample some demo images with evaluate() and save the modelif accelerator.is_main_process:pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:print(f'----------------------------------------------------- Evaluate Iter [{(epoch + 1) // config.save_image_epochs}] ------------------------------------------------------------------')evaluate(config, epoch, pipeline)if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:pipeline.save_pretrained(config.output_dir)

接下来使用🤗Acceleratenotebook_launcher函数启动训练了。将训练循环、所有训练参数和进程数(可以将此值更改为可用于训练的GPU数)传递给该函数:

from accelerate import notebook_launcherargs = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)notebook_launcher(train_loop, args, num_processes=1)

训练完成后,看看扩散模型生成的最终🦋图像(🦋我隔10个epoch生成一次,在下面给大家瞅瞅)和遥感影像(因为我电脑的原因,遥感影像跑了一半停了,不过也保存了一些,感慨一下,扩散模型太吃显存了,比之前跑分割检测啥的更加依赖,可能是我图像整的太大了,之后裁小一点试一试,感觉生成模型用于遥感领域,又困难,也有无限可能!这只是一个简单的扩散生成示例模型,还得再深入研究研究,以后再和大家分享其他更新又有意思的生成模型。

import globsample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
Image.open(sample_images[-1])

在这里插入图片描述

请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/89011.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

jxls导出问题

![请添加图片描述](https://img-blog.csdnimg.cn/bc74c4207818491c93b75e19b3333451.png 为什么最后导出的文件还是按原样导出啊,没有填充数据 ![在这里插入图片描述](https://img-blog.csdnimg.cn/d4500b9a98c042f6b64a5d0650071303.png

云安全攻防(八)之 Docker Remote API 未授权访问逃逸

Docker Remote API 未授权访问逃逸 基础知识 Docker Remote API 是一个取代远程命令行界面(rcli)的REST API,其默认绑定2375端口,如管理员对其配置不当可导致未授权访问漏洞。攻击者利用 docker client 或者 http 直接请求就可以…

【PostgreSQL的CLOG解析】

同样还是这张图,之前发过shared_buffer和os cache、wal buffer和work mem的文章,今天的主题是图中的clog,即 commit log,PostgreSQL10之前放在数据库目录的pg_clog下面。PostgreSQL10之后修更名为xact,数据目录变更为pg_xact下面&…

Vue+SpringBoot项目开发:登录页面美化,登录功能实现(三)

写在开始:一个搬砖程序员的随缘记录上一章写了从零开始VueSpringBoot后台管理系统:Vue3TypeScript项目搭建 VueTypeScript的前端项目已经搭建完成了 这一章的内容是引入element-plus和axios实现页面的布局和前后端数据的串联,实现一个登陆的功能&#x…

CSS变形与动画(一):transform变形 与 transition过渡动画 详解(用法 + 代码 + 例子 + 效果)

文章目录 变形与动画transform 变形translate 位移scale 缩放rotate 旋转skew 倾斜多种变形设置变形中心点 transition 过渡动画多种属性变化 变形与动画 transform 变形 包括:位移、旋转、缩放、倾斜。 下面的方法都是transform里的,记得加上。 展示效…

Apache Maven:从构建到部署,一站式解决方案

目录 一、Maven介绍 1. Maven是什么? 2.Maven的作用? 二、Maven仓库介绍 2.1 库的分类 三、Maven安装与配置 3.1 Maven安装 3.2 Maven环境配置 3.3 仓库配置 四、Eclipse与Maven配置 五、Maven项目测试 5.1 新建Maven项目步骤及注意事项 5.…

C/C++test两步完成CMake项目静态分析

您可能一直在静态分析中使用CMake。但您是否尝试过将Parasoft C/Ctest与CMake一起使用吗?以下是如何使用C/Ctest在基于CMake的项目中运行静态分析的详细说明。 CMake是用于构建、测试和打包软件的最流行的工具之一。Parasoft C/Ctest通过简化构建管理过程&#xff…

RabbitMQ基础(2)——发布订阅/fanout模式 topic模式 rabbitmq回调确认 延迟队列(死信)设计

目录 引出点对点(simple)Work queues 一对多发布订阅/fanout模式以登陆验证码为例pom文件导包application.yml文件rabbitmq的配置生产者生成验证码,发送给交换机消费者消费验证码 topic模式配置类增加配置生产者发送信息进行发送控制台查看 rabbitmq回调确认配置类验…

Redis_缓存1_缓存类型

14.redis缓存 14.1简介 穿透型缓存: 缓存与后端数据交互在一起,对服务端的调用隐藏细节。如果从缓存中可以读到数据,就直接返回,如果读不到,就到数据库中去读取,从数据库中读到数据,也是先更…

制造执行系统(MES)在新能源领域的应用

制造执行系统(MES)在新能源领域有许多应用,特别是在管理、监控和优化新能源生产过程方面。新能源包括太阳能、风能、生物质能、地热能等。以下是一些MES在新能源方面的应用领域: 生产计划与调度:MES可以协助规划和调度…

谷粒商城第十一天-品牌管理中关联分类

目录 一、总述 二、前端部分 1. 调整查询调用 2. 关联分类 三、后端部分 四、总结 一、总述 之前是在商品的分类管理中直接使用的若依的逆向代码 有下面的几个问题: 1. 表格上面的参数填写之后,都是按照完全匹配进行搜索,没有模糊匹配…

计算机网络—HTTP

这里写目录标题 HTTP是什么HTTP常见状态码HTTP常见字段GET与POST的区别Get和Post是安全和幂等吗PUT幂等,不安全DELETE幂等,不是安全 HTTP缓存技术HTTP缓存实现技术 HTTP1.0优缺点和性能HTTP1.1优缺点和性能HTTP2优缺点和性能HTTP3优缺点和性能HTTP和HTTP…

vuex学习总结

一、vuex工作原理 工作流程:需求:改变组件count的sun变量的值,先调用dispatch函数传入jia函数和要改变的值给actions(这个actions里面必须有jia这个函数);actions收到后调用commit函数将jia方法和值传给mut…

做BI领域的ChatGPT,思迈特升级一站式ABI平台

8月8日,以「指标驱动 智能决策」为主题,2023 Smartbi V11系列新品发布会在广州丽思卡尔顿酒店开幕。 ​ 后疫情时代,BI发展趋势的观察与应对 在发布会上,思迈特CEO吴华夫在开场致辞中表示,当前大环境背景下&#xf…

Stable Diffusion教程(9) - AI视频转动漫

配套抖音视频教程:https://v.douyin.com/UfTcrcJ/ 安装mov2mov插件 打开webui点击扩展->从网址安装输入地址,然后点击安装 https://github.com/Scholar01/sd-webui-mov2mov 最后重启webui 下载模型 从国内liblib AI 模型站下载模型 LiblibAI哩…

已有公司将ChatGPT集成到客服中心以增强用户体验

Ozonetel正在利用ChatGPT来改善客户体验。该公司表示,他们通过使用ChatGPT收集与客户互动过程收集的“语料”能够更有针对性地提高服务效率,提供个性化的用户体验,并实现更高的客户满意度。[1] 通过这套解决方案,客服中心将拥有一…

办理流量卡也是有条件的,这五种情况就不能办理流量卡!

流量卡资费虽然便宜,但也不是谁都可以办得,以下这几种情况是办不了的! 看到网上的流量卡资费便宜,也想随手申请一张,别想得太简单了,流量卡也不是那么好办理的,换句话来讲,办理流量…

【量化课程】07_量化回测

文章目录 7.1 pandas计算策略评估指标数据准备净值曲线年化收益率波动率最大回撤Alpha系数和Beta系数夏普比率信息比率 7.2 聚宽平台量化回测实践平台介绍策略实现 7.3 Backtrader平台量化回测实践Backtrader简介Backtrader量化回测框架实践 7.4 BigQuant量化框架实战BigQuant简…

特语云用Linux和MCSM面板搭建 我的世界基岩版插件服 教程

Linux系统 用MCSM和DockerWine 搭建 我的世界 LiteLoaderBDS 服务器 Minecraft Bedrock Edition 也就是我的世界基岩版,这是 Minecraft 的另一个版本。Minecraft 基岩版可以运行在 Win10、Android、iOS、XBox、switch。基岩版不能使用 Java 版的服务器,…

Spring BeanPostProcessor 接口的作用和使用

BeanPostProcessor 接口是 Spring 框架中的一个扩展接口,用于在 Spring 容器实例化、配置和初始化 bean 的过程中提供自定义的扩展点。通过实现这个接口,您可以在 bean 实例创建的不同生命周期阶段插入自己的逻辑,从而实现对 bean 行为的定制…