如何魔改 diffusers 中的 pipelines

如何魔改 diffusers 中的 pipelines

整个 Stable Diffusion 及其 pipeline 长得就很适合 hack 的样子。不管是通过简单地调整采样过程中的一些参数,还是直接魔改 pipeline 内部甚至 UNet 内部的 Attention,都可以实现很多有趣的功能或采样生图结果。

本文主要介绍两种魔改 diffusers pipelines 的方式,一是通过注册 callback 函数,来在采样生图过程中执行某些操作,二是直接自己写 custom pipelines。

pipeline callbacks

可参考官方文档:Pipelines Callback

通过在 pipe 推理生图时传入自定义的回调函数,不用动底层代码,可以在每个时间步结束时动态地执行一些我们想要的动作,比如在特定的时间步修改特定的采样参数或张量。目前仅支持 callback_on_step_end 在单步结束时执行回调函数。我们通过两个例子来介绍如何通过 callback 函数来魔改 diffusers pipelines。

Dynamic classifier-free guidance

classifier-free guidance (cfg)用于通过 prompt 来引导图像生成的内容。在 diffusers 中,会同时用 CLIP 文本编码器同时编码 prompt 的 embeds 和空 prompt (空字符串或 negative prompt)的 embeds,然后拼接起来,一起通过交叉注意力与 UNet 交互。

通过 callback 函数,我们可以按照自己的需求动态控制 cfg,比如说,我想在特定的时间步之后停止使用 cfg,从而节省计算开销,并且性能不会有很大的损失。 callback 函数需要接收以下参数:

  • pipeline:通过 pipe 可以访问和编辑许多重要的采样参数,如 pipe.num_timestepspipe.guidance_scale 等。在本例中,我们就可以通过将 pipe._guidance_scale 来停用 cfg。

  • timestep 和 step_index:这两个参数可以让我们知道本次采样过程一共有多少时间步,以及当前我们位于哪一步。从而可以根据当前位于整个采样过程中的位置,来选择进行什么操作。在本例中,我们可以设置在整个采样过程 40% 及以后的位置,停用 cfg。

  • callback_kwargs:callback_kwargs 是一个字典,包含了在采样生图的过程中你可以编辑的张量。具体包含哪些张量,需要再调用 pipe 采样生图时通过 callback_on_step_end_tensor_inputs 参数传入。不同的 pipe 可能包含不同的可编辑张量,具体可以通过 pipe 的 _callback_tensor_inputs 属性查看。在本例中,我们需要在停用 cfg 之后调整 prompt_embeds 张量的批尺寸,丢掉空 prompt 部分。这是因为 sd pipe 是根据 _guidance_scale 的值来判断是否进行 cfg,所以我们将这个值改为零,就不会进行 cfg 了,所以需要将 prompt_embeds 不带 prompt 的部分丢掉,只保存带 prompt 的部分。

返回值方面,回调函数必须返回(修改好的) callback_kwargs。

最终,我们的 callback 函数是这样的:

def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs, percent):if step_index == int(pipe.num_timesteps * percent):prompt_embeds = callback_kwargs['prompt_embeds']prompt_embeds = prompt_embeds.chunk(2)[-1]pipe._guidance_scale = 0.0callback_kwargs['prompt_embeds'] = prompt_embedsreturn callback_kwargs

然后在推理生图时传入该参数:

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")prompt = "a photo of an astronaut riding a horse on mars"generator = torch.Generator(device="cuda").manual_seed(1)
out = pipeline(prompt,generator=generator,callback_on_step_end=callback_dynamic_cfg,callback_on_step_end_tensor_inputs=['prompt_embeds']
)out.images[0].save("out_custom_cfg.png")
Display image after each generation step

在搭建生图 UI 时,一般需要支持用户在看到前几个时间步的结果不符合预期时,手动终止采样生图过程。这就需要两个功能,一是展示每一步的生图结果,二是支持在过程中终止本次生图。下面以展示中间步结果为例,介绍回调函数的使用。

我们知道,在 SD 中,去噪采样生图过程是发生在隐空间的,以 SDXL 为例,隐空间的特征图尺寸为 128 × 128 × 3 128\times 128\times 3 128×128×3 。我们需要将其转换到像素空间,才能看到这一步的实际生图结果。SDXL 还有点特殊,需要先将四通道的隐空间转换为 RGB 三通道,详情见 Explaining the SDXL latent space 。

def latents_to_rgb(latents):weights = ((60, -60, 25, -70),(60,  -5, 15, -50),(60,  10, -5, -35))weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device))biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device)rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1)image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()image_array = image_array.transpose(1, 2, 0)return Image.fromarray(image_array)def decode_tensors(pipe, step, timestep, callback_kwargs):latents = callback_kwargs["latents"]image = latents_to_rgb(latents)image.save(f"{step}.png")return callback_kwargs

然后再推理生图时传入该参数回调函数:

from diffusers import AutoPipelineForText2Image
import torch
from PIL import Imagepipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",torch_dtype=torch.float16,variant="fp16",use_safetensors=True
).to("cuda")image = pipeline(prompt = "A croissant shaped like a cute bear.",negative_prompt = "Deformed, ugly, bad anatomy",callback_on_step_end=decode_tensors,callback_on_step_end_tensor_inputs=["latents"],
).images[0]

在这里插入图片描述

Custom Pipelines

可参考官方文档:Custome Pipelines、contribute-pipeline

在 diffusers 中,我们可以很方便的自定义并加载自己的定制化 pipelines。在实现自己的自定义 pipelines 时,需要继承基类 DiffusionPipeline

加载 custom pipelines

1 从 diffusers 仓库中加载自定义的 pipeline

从 hf hub 加载自定义的 pipeline 非常简单,只需要将 model_id 传入 custom_pipeline 参数,然后就会加载该仓库中对应的 pipeline.py。

from diffusers import DiffusionPipelinepipeline = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
)

2 从本地文件加载自定义的 pipeline

如果想从本地文件加载 pipeline,需要将 pipeline.py 所在的文件目录(注意是目录名)传给 custom_pipeline 参数。

from diffusers import DiffusionPipelinepipeline = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="path/to/dir"
)

3 加载官方收录的社区 custom pipelines

这里是合入 diffuses 官方仓库的一些社区的自定义 pipelines。我们只需要将对应文件名(不含 py 后缀,如 clip_guided_stable_diffusion)传给 custom_pipeline 参数。

由于自定义 pipelines 的通常比较复杂,所以我们也可以通过官方 pipeline 来加载模型,再将模型传入自定义 pipelines。

from diffusers import DiffusionPipeline
from transformers import CLIPFeatureExtractor, CLIPModelclip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
clip_model = CLIPModel.from_pretrained(clip_model_id)pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4",custom_pipeline="clip_guided_stable_diffusion",clip_model=clip_model,feature_extractor=feature_extractor,
)
实现 custom pipelines

我们可以继承 DiffusionPipeline 基类并实现自己的 custom pipeline,这样所有人就都可以加载我们实现的 pipeline。一个 custom pipeline 的框架大致如下:

import torch
from diffusers import DiffusionPipelineclass MyPipeline(DiffusionPipeline):def __init__(self, unet, scheduler):super().__init__()self.register_modules(unet=unet, scheduler=scheduler)@torch.no_grad()def __call__(self, batch_size: int = 1, num_inference_steps: int = 50):# Sample gaussian noise to begin loopimage = torch.randn((batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size))image = image.to(self.device)# set step valuesself.scheduler.set_timesteps(num_inference_steps)for t in self.progress_bar(self.scheduler.timesteps):# 1. predict noise model_outputmodel_output = self.unet(image, t).sample# 2. predict previous mean of image x_t-1 and add variance depending on eta# eta corresponds to η in paper and should be between [0, 1]# do x_t -> x_t-1image = self.scheduler.step(model_output, t, image, eta).prev_sampleimage = (image / 2 + 0.5).clamp(0, 1)image = image.cpu().permute(0, 2, 3, 1).numpy()return image

确保 xxxPipeline 这个类的相关实现都在这一个文件,而且该文件中只包含 xxxPipeline 一个类。因为 pipeline 的识别加载是自动的。

接下来,我们以一个最简单的 one-step pipeline 为例,简单介绍自己实现 custom pipeline 的过程。在这个one-step pipeline 中,只会用到 UNet 一个模型,并将 timestep 固定为 1,只进行一次模型前向。

首先新建一个 one_step_unet.py 文件,然后在其中继承 DiffusionPipeline 基类并实现 UnetSchedulerOneForwardPipeline 类。

初始化:定义 __init__ 方法

在初始化方法中,我们简单地的 one-step pipeline 只需要接收 unet 和 scheduler 两个初始化参数,我们在初始化方法中将变量定义好。注意,为了使得 save_pretrained 方法能够将我们的模型完整地保存下来,需要通过 register_modules 方法将我们想要保存的 unet 和 scheduler 注册进来。

from diffusers import DiffusionPipeline
import torchclass UnetSchedulerOneForwardPipeline(DiffusionPipeline):def __init__(self, unet, scheduler):super().__init__()self.register_modules(unet=unet, scheduler=scheduler)

前向推理:定义 __call__ 方法

定义好初始化方法 __init__ 之后,再来实现 pipeline 推理生图的 __call__ 方法。在这里,我们可以任意发挥,任意组合,魔改扩散模型的采样过程,实现自己想要的功能。在我们的 one-step pipeline 中,这里要做的非常简单:采样一个噪声图,UNet 进行一次前向。

@torch.no_grad()
def __call__(self):image = torch.randn(1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.sanple_size)timestep = 1model_output = self.unet(image, timestep).samplescheduler_output = self.scheduler.step(model_output, timestep, image).prev_samplereturn scheduler_output

这就 ok 了,我们已经实现好了自定义的 one-step pipeline。

推理

我们传入 unet 和 scheduler,实例化一个刚刚自定义好的 UnetSchedulerOneForwardPipeline,然后进行推理生图:

from diffusers import DDPMScheduler, UNet2DModelscheduler = DDPMScheduler()
unet = UNet2DModel()pipeline = UnetSchedulerOneForwardPipeline(unet=unet, scheduler=scheduler)output = pipeline()

如果我们的 custom pipeline 结果如果跟某个已有的 pipeline 的预训练权重是完全一样的,我们还可以直接通过 from_pretrained 方法来加载它们的权重。比如说,我们的 UnetSchedulerOneForwardPipeline 就可以直接加载 google/ddpm-cifar10-32 的权重:

pipeline = UnetSchedulerOneForwardPipeline.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)output = pipeline()
分享 custom pipelines

要共享自己的 custom pipeline 有三个方法:

  1. 将自己实现的 custom pipeline 推送到 hf hub 仓库,需要将文件名命名为 pipeline.py

  2. 像 diffusers 官方仓库提交 PR,合并之后可以我们的 pipeline 就会出现在这里,别人可以通过文件名来加载

  3. 共享出自己的源码文件,如 clip_guided_stable_diffusion.py,别人也可以导入我们的 pipeline

而作为使用者,使用 custom pipeline 的方式有两种:

# 方式 1
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
)# 方式 2
from cus_pipe import CustomPipeline    # cus_pipe is copied from hf-internal-testing/diffusers-dummy-pipeline
pipe = CustomPipeline.from_pretrained("google/ddpm-cifar10-32")

这两种使用方式应该是等价的,这可以从源码中看到:

if custom_pipeline is not None:pipeline_class = get_class_from_dynamic_module(custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=custom_pipeline)
elif cls != DiffusionPipeline:pipeline_class = cls
else:diffusers_module = importlib.import_module(cls.__module__.split(".")[0])pipeline_class = getattr(diffusers_module, config_dict["_class_name"])

总结

diffusers 的 api 设计非常友好,我们可以通过 pipeline callback 和 custom pipeline 等方式定制化实现自己想要的功能,其中前者不用动底层代码,简单优雅,后者则是功能强大,现在最新的 AIGC 相关的论文基本都是通过 custom diffusion 的方式公开自己的源码,非常方便。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/300929.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

从0到1搭建文档库——sphinx + git + read the docs

sphinx git read the docs 目录 一、sphinx 1 sphinx的安装 2 本地构建文件框架 1)创建基本框架(生成index.rst ;conf.py) conf.py默认内容 index.rst默认内容 2)生成页面(Windows系统下&#xf…

项目:自主实现Boost搜索引擎

文章目录 写在前面开源仓库和项目上线其他文档说明 项目背景项目的宏观原理技术栈与环境搜索引擎原理正排索引倒排索引 去标签和数据清洗模块html文件名路径保存函数html数据解析函数文件写入函数 建立索引模块检索和读取信息建立索引建立正排索引建立倒排索引jieba工具的使用倒…

基于FPGA的按键消抖

按键工作原理 当KEY1按下时,整条电路就会导通,这个时候KEY1就是低电平; 当KEY1松开时,整条电路就会断开,这个时候KEY1就是高定平; 我们可以通过判断KEY1的高低电平来判断按键是否被按下。 为什么按键消…

关于ansible的模块 ③

转载说明:如果您喜欢这篇文章并打算转载它,请私信作者取得授权。感谢您喜爱本文,请文明转载,谢谢。 接《关于Ansible的模块①》和《关于Ansible的模块②》,继续学习ansible的user模块。 user模块可以增、删、改linux远…

idea 开发serlvet汽车租赁管理系统idea开发sqlserver数据库web结构计算机java编程layUI框架开发

一、源码特点 idea开发 java servlet 汽车租赁管理系统是一套完善的web设计系统sqlserver数据库 系统采用serlvetdaobean mvc 模式开发,对理解JSP java编程开发语言有帮助,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。 java se…

移动平台相关(安卓)

目录 安卓开发 Unity打包安卓 ​编辑​编辑 BuildSettings PlayerSettings OtherSettings 身份证明 配置 脚本编译 优化 PublishingSettings 调试 ReMote Android Logcat AndroidStudio的调试 Java语法 ​编辑​编辑​编辑 变量 运算符 ​编辑​编辑​编辑​…

猫咪也爱美食!这款猫粮让你的猫主子欲罢不能!

👋 亲爱的宠友们,最近我也在为家里的猫主子换猫粮的事情犯愁呢!我知道,给猫咪选择一款合适的猫粮真的是个挺重要的事情。我家猫咪现在吃的是福派斯牛肉高脂猫粮,感觉还不错。 🐱 首先说说我家猫咪的情况吧。…

Java方法引用

方法引用概述 把已经有的东西拿过来用,当做函数式接口中的抽象方法的方法体 import java.util.*;public class test {public static void main(String[] args) {//需求:创建一个数组,进行倒序排列Integer[] arr {3,5,4,1,6,2};//匿名内部类Arrays.sort(arr, new Comparator&l…

某虚假交友APP(信息窃取)逆向分析

应用初探 在群里水群的时候 群u发了一个交友APP 于是拿来分析一下 可以看到应用打开后又一个登录的界面 需要用户输入手机号与验证码进行登录 #在线云沙箱分析 将APK放入某安信云沙箱中分析 提示应用请求了过多的敏感权限 逆向分析 直接拖入Jadx分析 好在程序没有加固 也没…

Vue 有哪些主要的指令修饰符

目录 1. 什么是指令修饰符 2. 指令修饰符有哪些 2.1. 按键修饰符 2.2. v-model修饰符 2.3. 事件修饰符 1. 什么是指令修饰符 通过 "." 指明一些指令 后缀,不同 后缀 封装了不同的处理操作 目的:简化代码 2. 指令修饰符有哪些 2.1. 按键…

【SQL Sever】3. 用户管理 / 权限管理

1. 创建登录名/用户/角色 在SQL Server中,创建用户通常涉及几个步骤。 首先,你需要创建一个登录名,然后你可以基于这个登录名在数据库中创建一个用户。 以下是如何做到这一点的步骤和相应的SQL语句: 创建登录名 首先&#xff0c…

docker使用arthas基本教程

供参考也是自己的笔记 docker容器下使用遇到的问题:大致是连接不上1号进程 我这边主要的问题是用户权限问题,docker容器使用aaa用户启动,那个在docker容器内,需要使用aaa用于启动 docker 容器如何使用arthas #实现下载好arthas …

C语言第四十一弹---猜数字游戏

✨个人主页: 熬夜学编程的小林 💗系列专栏: 【C语言详解】 【数据结构详解】 猜数字游戏 1、随机数生成 1.1、rand 1.2、srand 1.3、time 1.4、设置随机数的范围 2、猜数字游戏的分析和设计 2.1、猜数字游戏功能说明 2.2、猜数字游戏…

js笔记(学习存档)

JS的调用方式与执行顺序 使用方式 HTML页面中的任意位置加上<script type"module"></script>标签即可。 常见使用方式有以下几种&#xff1a; 直接在<script type"module"></script>标签内写JS代码。直接引入文件&#xff1a;…

DSOX3034T是德科技DSOX3034T示波器

181/2461/8938产品概述&#xff1a; 特点: 带宽:350 MHz频道:4存储深度:4 Mpts采样速率:5 GSa/s更新速率:每秒1000000个波形波形数学和FFT自动探测接口用于连接、存储设备和打印的USB主机和设备端口 触摸: 8.5英寸电容式触摸屏专为触摸界面设计 发现: 业界最快的无损波形更…

WPS快速将插入Excle数据插入Word

前置条件&#xff1a; 一张有标题、数据的excle表格word中的表格与excle表格标题对应或包含电脑已经安装WPS软件 第一步、根据word模板设计excle模板&#xff0c;标头对应 第二步、word上面选【引用】--【邮件】&#xff0c;选打开数据源&#xff0c;找到excle文件&#xff0c;…

Vue3与TypeScript中动态加载图片资源的解决之道

在前端开发中&#xff0c;Vue.js已成为一个备受欢迎的框架&#xff0c;尤其是在构建单页面应用时。Vue3的发布更是带来了许多性能优化和新特性&#xff0c;而TypeScript的加入则进一步提升了代码的可维护性和健壮性。然而&#xff0c;在实际的项目开发中&#xff0c;我们有时会…

手机软件何时统一--桥接模式

1.1 凭什么你的游戏我不能玩 2007年苹果手机尚未出世&#xff0c;机操作系统多种多样&#xff08;黑莓、塞班、Tizen等&#xff09;&#xff0c;互相封闭。而如今&#xff0c;存世的手机操作系统只剩下苹果OS和安卓&#xff0c;鸿蒙正在稳步进场。 1.2 紧耦合的程序演化 手机…

鸿蒙学习记录

问题小测记录 总结链接&#xff1a;小测总结 学习笔记&#xff1a;鸿蒙开发学习记录 1、 main_pages.json存放页面page路径配置信息。 2、在stage模型中&#xff0c;下列配置文件属于AppScope文件夹的是&#xff1f; app.json5 3、module.json5配置文件中&#xff0c;包含…

Django之REST Client插件

一、接口测试工具介绍 在开发前后端分离项目时,无论是开发后端,还是前端,基本都是需要测试API接口的内容,而目前我们需要开发遵循RESTFul规范的项目,也是必然的(自己不开发前端页面)。 在网上有很多这样的工具,常用的postman,但还是需要下载安装。在这我们介绍一个VSCod…