扩散模型实战(八):微调扩散模型

推荐阅读列表:

扩散模型实战(一):基本原理介绍

扩散模型实战(二):扩散模型的发展

扩散模型实战(三):扩散模型的应用

扩散模型实战(四):从零构建扩散模型

扩散模型实战(五):采样过程

扩散模型实战(六):Diffusers DDPM初探

扩散模型实战(七):Diffusers蝴蝶图像生成实战

       微调在LLM中并不是新鲜的概念,从头开始训练一个扩散模型需要很长的时间,特别是使用高分辨率图像训练。那么其实我们可以在已经训练好的”去噪“扩散模型基础上使用微调数据集进行二次微调训练。

       本文将介绍基于蝴蝶数据集上微调人脸生成的扩散模型:

一、环境准备

1.1 安装相关库

!pip install -qq diffusers datasets accelerate wandb open-clip-torch

1.2 登录Huggingface Hub

如果需要开源微调好的模型到Huggingface Hub上,那么需要使用如下代码登录,否则可忽略此步骤:

from huggingface_hub import notebook_loginnotebook_login()

1.3 导入相关库

import numpy as npimport torchimport torch.nn.functional as Fimport torchvisionfrom datasets import load_datasetfrom diffusers import DDIMScheduler, DDPMPipelinefrom matplotlib import pyplot as pltfrom PIL import Imagefrom torchvision import transformsfrom tqdm.auto import tqdmdevice = (    "mps"    if torch.backends.mps.is_available()    else "cuda"    if torch.cuda.is_available()    else "cpu")

二、导入预训练的扩散模型

下面我们导入人脸生成的扩散模型,观察一下生成的效果,代码如下:

image_pipe = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")image_pipe.to(device);

查看生成的图像,代码如下:

images = image_pipe().imagesimages[0]

       生成的效果虽然不错,但是速度稍微有点慢,其实有更快的采样器可以加速这一过程,比如下面介绍的DDIM

三、DDIM-更快的采样器

       在生成图像的每一步中,模型都会接收一个带有噪声的输入,并且需要预测这个噪声,以此来估计没有噪声的完整图像是什么。这个过程被称为采样过程,在Diffusers库中,采样通过调度器控制的,之前的文章中介绍过DDPMScheduler调度器,本文介绍的DDIMScheduler可以通过更少的迭代周期来产生很好的采样样本(1000多步采样不是必须的)。

# 创建一个新的调度器并设置推理迭代次数scheduler = DDIMScheduler.from_pretrained("google/ddpm-celebahq-256")scheduler.set_timesteps(num_inference_steps=40)
scheduler.timesteps
# 输出tensor([975, 950, 925, 900, 875, 850, 825, 800, 775, 750, 725,         700, 675, 650, 625, 600, 575, 550, 525, 500, 475, 450, 425,         400, 375, 350, 325, 300, 275, 250, 225, 200, 175, 150,         125, 100,  75,  50,  25, 0])

       下面使用4幅随机噪声图像进行循环采样,并观察每一步的输入与输出的”去噪“图像,代码如下:

# 从随机噪声开始x = torch.randn(4, 3, 256, 256).to(device) # batch size为4,三通道,长、宽均为256像素的一组图像# 循环一整套时间步for i, t in tqdm(enumerate(scheduler.timesteps)):     # 准备模型输入:给“带躁”图像加上时间步信息    model_input = scheduler.scale_model_input(x, t)     # 预测噪声    with torch.no_grad():        noise_pred = image_pipe.unet(model_input, t)["sample"]     # 使用调度器计算更新后的样本应该是什么样子    scheduler_output = scheduler.step(noise_pred, t, x)     # 更新输入图像    x = scheduler_output.prev_sample     # 时不时看一下输入图像和预测的“去噪”图像    if i % 10 == 0 or i == len(scheduler.timesteps) - 1:        fig, axs = plt.subplots(1, 2, figsize=(12, 5))         grid = torchvision.utils.make_grid(x, nrow=4).permute(1, 2, 0)        axs[0].imshow(grid.cpu().clip(-1, 1) * 0.5 + 0.5)        axs[0].set_title(f"Current x (step {i})")         pred_x0 = (            scheduler_output.pred_original_sample        )         grid = torchvision.utils.make_grid(pred_x0, nrow=4).           permute(1, 2, 0)        axs[1].imshow(grid.cpu().clip(-1, 1) * 0.5 + 0.5)        axs[1].set_title(f"Predicted denoised images (step {i})")        plt.show()

     第二步生成图像的采样器是DDPMScheduler,我们可以使用新的DDIMScheduler来代替DDPMScheduler看看image_pipe生成的效果是否有提升,代码如下:

image_pipe.scheduler = schedulerimages = image_pipe(num_inference_steps=40).imagesimages[0]

       上述介绍了生成人脸的扩散模型以及生成的效果,也介绍了更快的采样器DDIMScheduler,下面我们使用蝴蝶数据集来微调人脸生成扩散模型:

四、微调人脸生成扩散模型

4.1 加载蝴蝶数据集

dataset_name = "huggan/smithsonian_butterflies_subset"dataset = load_dataset(dataset_name, split="train")image_size = 256batch_size = 4preprocess = transforms.Compose(    [        transforms.Resize((image_size, image_size)),        transforms.RandomHorizontalFlip(),        transforms.ToTensor(),        transforms.Normalize([0.5], [0.5]),    ]) def transform(examples):    images = [preprocess(image.convert("RGB")) for image in         examples["image"]]    return {"images": images} dataset.set_transform(transform) train_dataloader = torch.utils.data.DataLoader(    dataset, batch_size=batch_size, shuffle=True)

输出4幅蝴蝶图像,便于观察

print("Previewing batch:")batch = next(iter(train_dataloader))grid = torchvision.utils.make_grid(batch["images"], nrow=4)plt.imshow(grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)

4.2 微调人脸生成扩散模型

num_epochs = 2lr = 1e-5grad_accumulation_steps = 2 optimizer = torch.optim.AdamW(image_pipe.unet.parameters(), lr=lr) losses = [] for epoch in range(num_epochs):    for step, batch in tqdm(enumerate(train_dataloader),       total=len(train_dataloader)):        clean_images = batch["images"].to(device)        # 随机生成一个噪声,稍后加到图像上        noise = torch.randn(clean_images.shape).to(clean_images.           device)        bs = clean_images.shape[0]         # 随机选取一个时间步        timesteps = torch.randint(            0,            image_pipe.scheduler.num_train_timesteps,            (bs,),            device=clean_images.device,        ).long()         # 根据选中的时间步和确定的幅值,在干净图像上添加噪声        # 此处为前向扩散过程        noisy_images = image_pipe.scheduler.add_noise(clean_images,            noise, timesteps)         # 使用“带噪”图像进行网络预测        noise_pred = image_pipe.unet(noisy_images, timesteps,            return_dict=False)[0]         # 对真正的噪声和预测的结果进行比较,注意这里是预测噪声        loss = F.mse_loss(            noise_pred, noise        )         # 保存损失值        losses.append(loss.item())         # 根据损失值更新梯度        loss.backward()         # 进行梯度累积,在累积到一定步数后更新模型参数        if (step + 1) % grad_accumulation_steps == 0:            optimizer.step()            optimizer.zero_grad()     print(        f"Epoch {epoch} average loss: {sum(losses[-len(train_           dataloader):])/len(train_dataloader)}"    ) # 画出损失曲线,效果如图所示plt.plot(losses) 

4.3 使用微调好的模型生成图像

x = torch.randn(8, 3, 256, 256).to(device)for i, t in tqdm(enumerate(scheduler.timesteps)):    model_input = scheduler.scale_model_input(x, t)    with torch.no_grad():        noise_pred = image_pipe.unet(model_input, t)["sample"]    x = scheduler.step(noise_pred, t, x).prev_samplegrid = torchvision.utils.make_grid(x, nrow=4)plt.imshow(grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)

从图中可以看出生成的图像有蝴蝶数据的风格。

4.4 保持微调好的扩散模型,并且上传到Huggingface Hub中

image_pipe.save_pretrained("my-finetuned-model")
from huggingface_hub import HfApi, ModelCard, create_repo, get_   full_repo_name# 配置Hugging Face Hub,上传文件model_name = "ddpm-celebahq-finetuned-butterflies-2epochs"  # 使用@param 脚本程序对上传到 # Hugging Face Hub的文件进行命名local_folder_name = "my-finetuned-model" # @param脚本程序生成的名字,# 你也可以通过 image_pipe.save_pretrained('savename')自行指定description = "Describe your model here" # @paramhub_model_id = get_full_repo_name(model_name)create_repo(hub_model_id)api = HfApi()api.upload_folder(     folder_path=f"{local_folder_name}/scheduler",path_in_repo="",                     repo_id=hub_model_id )api.upload_folder(     folder_path=f"{local_folder_name}/unet", path_in_repo="",      repo_id=hub_model_id )api.upload_file(     path_or_fileobj=f"{local_folder_name}/model_index.json",     path_in_repo="model_index.json",     repo_id=hub_model_id,) # 添加一个模型卡片,这一步虽然不是必需的,但可以给他人提供一些模型描述信息 content = f"""---license: mittags:- pytorch- diffusers- unconditional-image-generation- diffusion-models-class---# 用法from diffusers import DDPMPipelinepipeline = DDPMPipeline.from_pretrained(' {hub_model_id}')image = pipeline().images[0]image'''"""card = ModelCard(content)card.push_to_hub(hub_model_id)

微调Trick:

  • 设置合适的batch_size,batch_size要在不超过GPU显存的前提下,尽量大一些,这样可以提高GPU计算效果;如果特别小,可以采用梯度累积的方式来更新模型参数,达到和大batch_size类似的效果,也就是多运行几次loss.backward(),再调用optimizer.step()和optimizer.zero_grad();
  • 训练过程中,要时不时生成一些图像样本来观察模型性能;
  • 训练过程中,可以把损失值、生成的图像样本等信息记录在日志中,可以使用Weights and Biases、TensorBoard等工具;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/111559.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

yolov5和yolov7部署的研究

1.结论 onnx推理比torch快3倍, openvino比onnx快一丢丢。 | yolov7.pt 转 onnx python export.py --weights best_31.pt --grid --end2end --simplify --topk-all 10 --iou-thres 0.65 --conf-thres 0.65 --img-size 320 320 --max-wh 200可以看到yolov7的 onnx是包括nms…

【Unity小技巧】手戳一个简单易用的游戏UI框架(附源码)

文章目录 前言整套框架分为三大部分框架代码调用源码参考完结 前言 开发一款游戏美术成本是极其高昂的,以我们常见的宣传片CG为例,动辄就要成百上千万的价格,因此这种美术物料一般只会放在核心剧情节点,引爆舆论,做高…

MATLAB中符号变量的使用方法解析

简介 MATLAB中常常使用符号变量,这里定义符号变量的函数是syms 使用方法如下 syms x y z 其中,x、y、z 是符号变量,可以是任意字母、数字或下划线组合而成的字符串。 举例1: 代码 以下是一个简单的例子,演示如何…

WebSocket- 前端篇

官网代码 // 为了浏览器兼容websocketconst WebSocket window.WebSocket || window.MozWebSocket// 创建连接 this.socket new WebSocket(ws://xxx)// 连接成功this.socket.onopen (res)>{console.log(websocket 连接成功)this.socket.send(入参字段) // 传递的参数字段}…

强化自主可控,润开鸿发布基于RISC-V架构的开源鸿蒙终端新品

2023 RISC-V中国峰会于8月23日至25日在北京召开,峰会以“RISC-V生态共建”为主题,结合当下全球新形势,把握全球新时机,呈现RISC-V全球新观点、新趋势。本次大会邀请了RISC-V国际基金会、业界专家、企业代表及社区伙伴等共同探讨RISC-V发展趋势与机遇,吸引超过百余家业界企业、高…

出现ZooKeeper JMX enabled by default这种错误的解决方法

系列文章专栏 学习以来遇到的bug/问题专栏 文章目录 系列文章专栏 前言 一 问题描述 二 解决方法 2.1 可能的原因分析 2.2 小编的问题解决方法 First:检查/etc/profile里面zookeeper的环境变量配置 Second:检查 zookeeper/conf/zoo.cfg里面的d…

minikube mac 启动

系统信息如下 最开始使用的minikube是1.22.0版本,按照如下命令启动: minikube start --memory7851 --cpus4 --image-mirror-countrycn遇到了下面一些问题: 1、拉取coredns:v1.8.0镜像失败 Error response from daemon: manifest for regis…

Tensorflow调用训练好的yolov5模型进行推理

文章目录 1、安装TensorFlow-GPU版本1.2、验证是否安装正常 2、将训练好的pt文件转换成onnx文件2.2、什么是Onnx模型和Tensorflow模型2.1、将onnx文件转换成pb文件 1、安装TensorFlow-GPU版本 1、创建虚拟环境python3.8 conda create -n TF2.4 python3.82、进入虚拟环境 conda…

智安网络|探索物联网架构:构建连接物体与数字世界的桥梁

物联网是指通过互联网将各种物理设备与传感器连接在一起,实现相互通信和数据交换的网络系统。物联网架构是实现这一连接的基础和框架,它允许物体与数字世界之间的互动和协作。 一、物联网架构的概述 物联网架构是一种分层结构,它将物联网系…

python面试:使用cProfile剖析程序性能

我们需要安装tuna:pip install tuna 程序执行完毕后,我们会得到一个results.prof,在CMD中输入指令:“tuna results.prof”。 import time import cProfile import pstatsdef add(x, y):resulting_sum 0resulting_sum xresulti…

(Windows )本地连接远程服务器(Linux),免密码登录设置

在使用VScode连接远程服务器时,每次打开都要输入密码,以及使用ssh登录或其它方法登录,都要本地输入密码,这大大降低了使用感受,下面总结了免密码登录的方法,用起来巴适得很,起飞。 目录 PowerSh…

2024年java面试(四)--spring篇

文章目录 1.BeanFactory 和 FactoryBean 的区别2.BeanFactory和ApplicationContext有什么区别?3.RequestBody、RequestParam、ResponseBody4.cookie和session的区别5.Servlet的生命周期6.Jsp和Servlet的区别7.SpringMvc执行流程8.RequestMapping是怎么使用9.如果一个接口有多个…

爬虫逆向实战(二十七)--某某招标投标网站招标公告

一、数据接口分析 主页地址:某网站 1、抓包 通过抓包可以发现数据接口是page 2、判断是否有加密参数 请求参数是否加密? 通过查看“载荷”模块可以发现,请求参数是一整个密文 请求头是否加密? 无响应是否加密? 通…

Mac下使用Homebrew安装MySQL5.7

Mac下使用Homebrew安装MySQL5.7 1. 安装Homebrew & Oh-My-Zsh2. 查询软件信息3. 执行安装命令4. 开机启动5. 服务状态查询6. 初始化配置7. 登录测试7.1 终端登录7.2 客户端登录 参考 1. 安装Homebrew & Oh-My-Zsh mac下如何安装homebrew MacOS安装Homebrew与Oh-My-Zsh…

使用DataX对MySQL 8.1进行数据迁移

1. 环境准备 1.1 下载DataX 这里采用直接下载的方式:https://datax-opensource.oss-cn-hangzhou.aliyuncs.com/202308/datax.tar.gz,不过这个包是真的有点大。 1.2 安装Python Python下载地址:https://www.python.org/downloads/ 安装的时…

运维Shell脚本小试牛刀(一)

运维Shell脚本小试牛刀(一) 运维Shell脚本小试牛刀(二) 一: Shell中循环剖析 for 循环....... #!/bin/bash - # # # # FILE: countloop.sh # USAGE: ./countloop.sh # DESCRIPTION: # OPTIONS: ------- # …

泰迪大数据实训平台产品介绍

大数据产品包括:大数据实训管理平台、大数据开发实训平台、大数据编程实训平台等 大数据实训管理平台 泰迪大数据实训平台从课程管理、资源管理、实训管理等方面出发,主要解决现有实验室无法满足教学需求、传统教学流程和工具低效耗时和内部教学…

C++ 读写Excel LibXL库的使用附注册码(key)

LibXL是一款用于读写处理 Excel 文件的库,支持C, C++, C#,Python等语言。并且支持多个平台windows、Linux、Mac等,它提供了一系列的API,让开发人员可以方便地读取、修改和创建Excel文件。 一、关于库的key与使用 1.价值3000多的key 但是这个库并不是免费的,使用此库需要…

[Android AIDL] --- AIDL原理简析

上一篇文章已经讲述了如何在Android studio中搭建基于aidl的cs模型框架,只是用起来了,这次对aidl及cs端如何调用的原理进行简单分析 1 创建AIDL文件 AIDL 文件可以分为两类。 一类是用来定义接口方法,声明要暴露哪些接口给客户端调用&#…

Gitlab设置中文

1. 打开设置 2.选择首选项Preferences 3. 下滑选择本地化选项Localization,设置简体中文,然后保存更改save changes。刷新网页即可。