CSGO: Content-Style Composition in Text-to-Image Generation(代码的复现)

文章目录

  • CSGO简介
  • 论文的代码部署
    • 需要下载的模型权重:
    • 复现中存在的一些问题
  • 推理代码
  • 生成结果示意图

CSGO简介

CSGO: Content-Style Composition in Text-to-Image Generation(风格迁移)
本文是一篇风格迁移的论文:将内容参考图像和风格参考图像分别投影,然后注入到内容模块和风格模块,同时采用controlnet的方法将内容参考图像注入unet的上采样块当中。
在这里插入图片描述
github中项目的地址

论文的代码部署

需要下载的模型权重:

我们的方法与 SDXL、VAE、ControlNet 和图像编码器完全兼容。请下载它们并将它们放在 ./base_models 文件夹中。
按照readme里面的指引,下载到如下文件夹里面:

在这里插入图片描述

复现中存在的一些问题

①需要保证如下包的版本与readme一致

diffusers==0.25.1
torch==2.0.1
torchaudio==2.0.2
torchvision==0.15.2
transformers==4.40.2

② NotImplementedError: Cannot copy out of meta tensor; no data!
参考知乎这篇
大语言模型调用踩坑点记录
数据在显存和内存中切换,导致出问题(显存不够)
在这里插入图片描述
部分参数从gpu拷贝到cpu会报错,将改成low_cpu_mem_usage=False,可以正常推理

pipe = StableDiffusionXLControlNetPipeline.from_pretrained(base_model_path,controlnet=controlnet,torch_dtype=torch.float16,add_watermarker=False,use_safetensors=True,vae=vae,revision="fp16",##这个参数low_cpu_mem_usage=False)

③模型加载的问题
由于下载的模型权重都是fp16的格式的,然而这里模型的加载方式的参数是统一在最外面控制的,导致不同模型加载时,识别不了对应的模型文件:
加载模型是一些参数的设定

pipe = StableDiffusionXLControlNetPipeline.from_pretrained(base_model_path,controlnet=controlnet,torch_dtype=torch.float16,add_watermarker=False,#用safetensors格式的权重文件use_safetensors=True,vae=vae,revision="fp16",#device_map="auto"low_cpu_mem_usage=False)
#这两个参数同时为fp16才会去读fp16的文件revision="fp16"variant= "fp16"

④需要统一数据的数据类型
由于之前的文本编码器的权重读取的是fp32,导致后续出现数据的类型不相同不能做运算的情况。
将pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
base_model_path,
controlnet=controlnet,
torch_dtype=torch.float16,
add_watermarker=False,
use_safetensors=True,
vae=vae,
revision=“fp16”,
#device_map=“auto”
low_cpu_mem_usage=False

)中的.from_pretrained函数(这个函数在pipeline_utils.py文件夹里)进行修改,当模型是文本编码器时,修改传入的一些参数

 if name == "text_encoder":#如果是文本编码器,将varient设置为fp16variant = "fp16"if variant is not None:# for folder in os.listdir(cached_folder):folder_path = os.path.join(cached_folder, "text_encoder")is_folder = os.path.isdir(folder_path) and "text_encoder" in config_dictvariant_exists = is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))if variant_exists:model_variants["text_encoder"] = variantif name == "text_encoder_2":variant = "fp16"if variant is not None:# for folder in os.listdir(cached_folder):folder_path = os.path.join(cached_folder, "text_encoder_2")is_folder = os.path.isdir(folder_path) and "text_encoder_2" in config_dictvariant_exists = is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))if variant_exists:model_variants["text_encoder_2"] = variantloaded_sub_model = load_sub_model(library_name=library_name,class_name=class_name,importable_classes=importable_classes,pipelines=pipelines,is_pipeline_module=is_pipeline_module,pipeline_class=pipeline_class,torch_dtype=torch_dtype,provider=provider,sess_options=sess_options,device_map=device_map,max_memory=max_memory,offload_folder=offload_folder,offload_state_dict=offload_state_dict,model_variants=model_variants,name=name,from_flax=from_flax,variant=variant,low_cpu_mem_usage=low_cpu_mem_usage,cached_folder=cached_folder,revision=revision,)logger.info(f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}.")

⑤推理代码中的风格图像和内容图像都要是两个图像,而给的代码中是一个文本,一个图像

## 注意这里的两个图片都要转化为图片的格式,论文给的推理代码一个是文本,另一个是
style_image = Image.open("/mnt/CSGO-main/assets/{}".format(style_name)).convert('RGB')
content_image = Image.open('/mnt/test/image/{}'.format(content_name)).convert('RGB')

推理代码

import randomimport torch
from ip_adapter.utils import BLOCKS as BLOCKS
from ip_adapter.utils import controlnet_BLOCKS as controlnet_BLOCKS
from PIL import Image
from diffusers import (AutoencoderKL,ControlNetModel,StableDiffusionXLControlNetPipeline,)
from ip_adapter import CSGO#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")device = 'cuda:0'base_model_path =  "./base_models/stable-diffusion-xl-base-1.0"  
image_encoder_path = "./base_models/IP-Adapter/sdxl_models/image_encoder"
csgo_ckpt = "./CSGO/csgo4_32.bin"
pretrained_vae_name_or_path ='./base_models/vae'
controlnet_path = "./base_models/TTPLanet_SDXL_Controlnet_Tile_Realistic"
weight_dtype = torch.float16
weight_dtype = torch.float16vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path,torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16,use_safetensors=True)def get_device_map():return 'cuda' if torch.cuda.is_available() else 'cpu'device = get_device_map()pipe = StableDiffusionXLControlNetPipeline.from_pretrained(base_model_path,controlnet=controlnet,torch_dtype=torch.float16,add_watermarker=False,use_safetensors=True,vae=vae,revision="fp16",## 这里要加这个代码,不然会报错,因为显存不够,然后导致数据在显存和内存之间转换,报错low_cpu_mem_usage=False
)
pipe.enable_vae_tiling()target_content_blocks = BLOCKS['content']
target_style_blocks = BLOCKS['style']
controlnet_target_content_blocks = controlnet_BLOCKS['content']
controlnet_target_style_blocks = controlnet_BLOCKS['style']csgo = CSGO(pipe, image_encoder_path, csgo_ckpt, device, num_content_tokens=4,num_style_tokens=32,target_content_blocks=target_content_blocks, target_style_blocks=target_style_blocks,controlnet_adapter=True,controlnet_target_content_blocks=controlnet_target_content_blocks,controlnet_target_style_blocks=controlnet_target_style_blocks,content_model_resampler=True,style_model_resampler=True,)style_name = 'img_0.png'
content_name = 's_01_e_26_shot_005126_005200.png'## 注意这里的两个图片都要转化为图片的格式,论文给的推理代码一个是文本,另一个是
style_image = Image.open("/mnt/CSGO-main/assets/{}".format(style_name)).convert('RGB')
content_image = Image.open('/mnt/test/image/{}'.format(content_name)).convert('RGB')num_sample=1
caption = ''
#写个循环,看看各个参数对生成图片的影响
while True:tem = 0for ccs in range(5, 11, 1):ccs = ccs * 0.1content_scale = random.uniform(0.6, 1.5)style_scale = random.uniform(0.5, 1)images = csgo.generate(pil_content_image= content_image, pil_style_image=style_image,prompt=caption,negative_prompt= "text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",content_scale=1.0,style_scale=1.0,guidance_scale=10,num_images_per_prompt=num_sample,num_samples=1,num_inference_steps=50,seed=42,image=content_image.convert('RGB'),controlnet_conditioning_scale=0.6,)formatted_ccs = "{:.2f}".format(ccs)formatted_content_scale = "{:.2f}".format(content_scale)formatted_style_scale = "{:.2f}".format(style_scale)images[0].save(f"inference/ccs:{formatted_ccs}-cs:{formatted_content_scale}-ss:{formatted_style_scale}.png")tem = tem + 1if tem >= 100:break

生成结果示意图

风格参考图像
在这里插入图片描述

文本: a cat

生成的图像
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/461836.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

安卓13默认连接wifi热点 android13默认连接wifi

总纲 android13 rom 开发总纲说明 文章目录 1.前言2.问题分析3.代码分析4.代码修改5.编译6.彩蛋1.前言 有时候我们需要让固件里面内置好,相关的wifi的ssid和密码,让固件起来就可以连接wifi,不用在手动操作。 2.问题分析 这个功能,使用普通的安卓代码就可以实现了。 3.代…

C++ 复习记录(个人记录)

1、构造函数(constructor)是什么 答:类里面定义一个函数, 和类名一样, 这样在我们生成一个对象之后,就会默认调用这个函数,初始化这个类。 子类B继承父类A的情况, 当你调用子类的对…

Oasis 500M:开源的实时生成交互式视频内容的 AI 模型

❤️ 如果你也关注大模型与 AI 的发展现状,且对 AI 应用开发非常感兴趣,我会快速跟你分享最新的感兴趣的 AI 应用和热点信息,也会不定期分享自己的想法和开源实例,欢迎关注我哦! 🥦 微信公众号&#xff5c…

微服务实战系列之玩转Docker(十六)

导览 前言Q:基于容器云如何实现高可用的配置中心一、etcd入门1. 简介2. 特点 二、etcd实践1. 安装etcd镜像2. 创建etcd集群2.1 etcd-node12.2 etcd-node22.3 etcd-node3 3. 启动etcd集群 结语系列回顾 前言 Docker,一个宠儿,一个云原生领域的…

固定翼无人机飞行操控技术详解

固定翼无人机飞行操控技术是一个复杂而精密的领域,涵盖了从起飞准备到实际飞行操作,再到安全降落的各个环节。以下是对固定翼无人机飞行操控技术的详细解析: 一、起飞准备 1. 设备检查: 确保无人机充满电,检查电池状…

文件描述符fd 和 缓冲区

目录 1.文件描述符 fd 1.1文件打开的返回值fd(重点) 1.2.如何理解Linux下的一切皆文件 1.3.文件fd的分配原则 && 输出重定向 1.4.dup2()函数 2.缓冲区 2.1. 概念 2.2. 存在的原因 2.3. 类型(刷新方案) 2.4. 存放的位置 1.文件描述符 fd …

【qt qtcreator使用】【正点原子】嵌入式Qt5 C++开发视频

QT creator 的使用 一.qtcreator的介绍  (1).ui界面介绍    [1].软件左侧界面部分    [2].软件界面下方部分    [3].UI设计界面 (2).debug的使用 (3).项目的配置 (4).帮助文档的使用 (5).构建多个项目 二.qtcreator 的设置 (1).qt编译套件的设置 (2).设置快…

Vue3和Springboot前后端简单部署

一、Vue3Springboot 的前后端简单部署 (在win下面部署) 1、前端实现部署 思想: 前端打包项目后、放到nginx中进行部署 1、nginx 安装 和 解压 1、下载 nginx.zip win版本 解压就可以 2、解压后、启动程序 3、访问 nginx 欢迎页面 http://localhost/ 80 端口 可以省略 直接访…

【大数据学习 | kafka】kafka的ack和一致性

1. ack级别 上文中我们提到过kafka是存在确认应答机制的,也就是数据在发送到kafka的时候,kafka会回复一个确认信息,这个确认信息是存在等级的。 ack0 这个等级是最低的,这个级别中数据sender线程复制完毕数据默认kafka已经接收到…

【分布式技术】分布式事务深入理解

文章目录 概述产生原因关键点 分布式事务解决方案3PC3PC的三个阶段:3PC相比于2PC的改进:3PC的缺点: TCCTCC事务的三个阶段:TCC事务的设计原则:TCC事务的适用场景:TCC事务的优缺点:如何解决TCC模…

Linux高阶——1027—

1、守护进程的基本流程 1、父进程创建子进程,父进程退出 守护进程是孤儿进程,但是是工程师人为创建的孤儿进程,低开销模式运行,对系统没有压力 2、子进程(守护进程)脱离控制终端,创建新会话 …

centos7配置keepalive+lvs

拓扑图 用户访问www.abc.com解析到10.4.7.8,防火墙做DNAT将访问10.4.7.8:80的请求转换到VIP 172.16.10.7:80,负载均衡器再将请求转发到后端web服务器。 实验环境 VIP:负载均衡服务器的虚拟ip地址 LB :负载均衡服务器 realserv…

服务器宝塔安装哪吒监控

哪吒文档地址:https://nezha.wiki/guide/dashboard.html 一、准备工作 OAuth : 我使用的gitee,github偶尔无法访问,不是很方便。第一次用了极狐GitLab,没注意,结果是使用90天,90天后gg了,无法登…

【动手学强化学习】part6-策略梯度算法

阐述、总结【动手学强化学习】章节内容的学习情况,复现并理解代码。 文章目录 一、算法背景1.1 算法目标1.2 存在问题1.3 解决方法 二、REINFORCE算法2.1 必要说明softmax()函数交叉熵策略更新思想 2.2 伪代码算法流程简述 2.3 算法代码2.4 运行结果2.5 算法流程说明…

单片机内存管理和启动文件

一、常见存储器介绍 FLASH又称为闪存,不仅具备电子可擦除可编程(EEPROM)的性能,还不会断电丢失数据同时可以快速读取数据,U盘和MP3里用的就是这种存储器。在以前的嵌入式芯片中,存储设备一直使用ROM(EPROM),随着技术的…

Python画图3个小案例之“一起看流星雨”、“爱心跳动”、“烟花绚丽”

源码如下: import turtle # 导入turtle库,用于图形绘制 import random # 导入random库,生成随机数 import math # 导入math库,进行数学计算turtle.setup(1.0, 1.0) # 设置窗口大小为屏幕大小 turtle.title("流星雨动画&…

SQL-lab靶场less1-4

说明:部分内容来源于网络,如有侵权联系删除 前情提要:搭建sql-lab本地靶场的时候发现一些致命的报错: 这个程序只能在php 5.x上运行,在php 7及更高版本上,函数“mysql_query”和一些相关函数被删除&#xf…

AutoGLM:智谱AI的创新,让手机成为你的生活全能助手

目录 引言一、AutoGLM:开启AI的Phone Use时代二、技术核心:AI从“语言理解”到“执行操作”三、实际应用案例:AutoGLM的智能力量1. 智能生活管理🍎2. 社交网络的智能互动🍑3. 办公自动化🍒4. 电子商务的购物…

ceph补充介绍

SDS-ceph ceph介绍 crushmap 1、crush算法通过计算数据存储位置来确定如何存储和检索,授权客户端直接连接osd 2、对象通过算法被切分成数据片,分布在不同的osd上 3、提供很多种的bucket,最小的节点是osd # 结构 osd (or device) host #主…

Scrapy源码解析:DownloadHandlers设计与解析

1、源码解析 代码路径:scrapy/core/downloader/__init__.py 详细代码解析,请看代码注释 """Download handlers for different schemes"""import logging from typing import TYPE_CHECKING, Any, Callable, Dict, Gener…