超详细!DALL · E 文生图模型实践指南

最近需要用到 DALL·E的推断功能,在现有开源代码基础上发现还有几个问题需要注意,谨以此篇博客记录之。

我用的源码主要是 https://github.com/borisdayma/dalle-mini 仓库中的Inference pipeline.ipynb 文件。

在这里插入图片描述

运行环境:Ubuntu服务器

⚠️注意:本博客仅涉及 DALL · E 推断,不涉及训练过程。


目录

  • 一、环境配置
  • 二、模型下载
  • 三、程序转换
  • 四、程序运行
  • 五、BUG清除指南


一、环境配置

建议使用anaconda新建一个dalle环境,然后在该环境中进行相关配置,避免与环境中的其他库产生版本冲突。

使用下述命令新建名为dalle的环境:

conda create -n dalle python==3.8.0

在终端分别运行下述命令,安装所需的python库:

# 安装 dalle运行需要的依赖库(注意版本只能是0.3.25)# Required only for colab environments + GPU
pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# 安装 dalle特定的库
pip install dalle-mini
# 安装 VQGAN
pip install -q git+https://github.com/patil-suraj/vqgan-jax.git

PS:如果由于网络连接问题无法通过pip命令下载VQGAN,就采取Plan-B:将仓库 https://github.com/patil-suraj/vqgan-jax 下载到服务器并解压,然后使用cd命令将当前目录到对应的仓库下载路径下,在终端运行python setup.py install安装VQGAN即可。


二、模型下载

由于网络连接问题,我采取「事先把模型下载到本地」的策略对模型进行直接调用,首先要明确的一点是,本项目中使用DALL · E 对图像进行编码,使用VQGAN对图像进行解码,所以我们需要分别下载DALL · E 和 VQGAN 两个模型。

DALL · E 模型下载地址:
mini版本:https://huggingface.co/dalle-mini/dalle-mini/tree/main
mega版本:https://huggingface.co/dalle-mini/dalle-mega/tree/main

VQGAN 模型下载地址:
https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/tree/main

下载完毕后,将模型部署到服务器,注意保存路径。


三、程序转换

相较于ipynb文件,我个人更加喜欢操作py文件,所以对于给定的ipynb文件,首先使用命令jupyter nbconvert --to script Inference pipeline.ipynb 将其转为同名py文件,该文件的主要内容如下(不含CLIP排序部分),其中模型路径 DALLE_MODEL和VQGAN_REPO 已改为本地路径(就是第二步中两个模型的保存路径),可以看到文件的注释也比较详细。

# dalle-mini
DALLE_MODEL = "/newdata/SD/dalle-mini/dalle-mini"
DALLE_COMMIT_ID = None
# VQGAN model
VQGAN_REPO = "/newdata/SD/dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"import jax
import jax.numpy as jnp# check how many devices are available
jax.local_device_count()# Load models & tokenizer
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
# Load dalle-mini
model, params = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False)
# Load VQGAN
vqgan, vqgan_params = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False)# Model parameters are replicated on each device for faster inference.
from flax.jax_utils import replicate
params = replicate(params)
vqgan_params = replicate(vqgan_params)# Model functions are compiled and parallelized to take advantage of multiple devices.
from functools import partial# model inference
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):return model.generate(**tokenized_prompt,prng_key=key,params=params,top_k=top_k,top_p=top_p,temperature=temperature,condition_scale=condition_scale,)# decode image
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):return vqgan.decode_code(indices, params=params)# Keys are passed to the model on each device to generate unique inference per device.
import random# create a random key
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)# ## 🖍 Text Prompt
# Our model requires processing prompts.from dalle_mini import DalleBartProcessor 
# from transformers import AutoProcessor
processor = DalleBartProcessor.from_pretrained("/newdata/SD/dalle-mini/dalle-mini", revision=DALLE_COMMIT_ID)  # force_download=True, , local_only=True
# Let's define some text prompts
prompts = ["sunset over a lake in the mountains","the Eiffel tower landing on the moon",
]
# print(prompts)
# Note: we could use the same prompt multiple times for faster inference.
tokenized_prompts = processor(prompts)
# Finally we replicate the prompts onto each device.
tokenized_prompt = replicate(tokenized_prompts)# ## 🎨 We generate images using dalle-mini model and decode them with the VQGAN.# number of predictions per prompt
n_predictions = 8# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)
gen_top_k = None
gen_top_p = None
temperature = None
cond_scale = 10.0  # 越高,生成的图像越接近 promptfrom flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trangeprint(f"Prompts: {prompts}\n")
# generate images
images = []
for i in trange(max(n_predictions // jax.device_count(), 1)):# get a new keykey, subkey = jax.random.split(key)  #  jax.device_count()=1,returns the number of available jax devices# generate imagesencoded_images = p_generate(tokenized_prompt,shard_prng_key(subkey),params,gen_top_k,gen_top_p,temperature,cond_scale,)# remove BOSencoded_images = encoded_images.sequences[..., 1:]decoded_images = p_decode(encoded_images, vqgan_params)decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))for idx, decoded_img in enumerate(decoded_images):img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))images.append(img)
... 

四、程序运行

使用命令 python /newdata/SD/inference_dalle-mini.py 运行程序。理想情况下就能够直接得到dalle生成的图像啦!


五、BUG清除指南

由于外部环境因素和一些不当操作,本人在运行该程序过程中还是遇到一些问题,主要有三个,在此将抱错信息与解决方法一并分享给大家。

  • 因网络问题导致特定文件下载失败,报错信息如下:
...
requests.exceptions.ConnectTimeout: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /dalle-mini/dalle-mini/resolve/main/enwiki-words-frequency.txt (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7faae4168460>, 'Connection to huggingface.co timed out. (connect timeout=10)'))"), '(Request ID: 61b7c191-3fb8-4dfa-9025-e9acd4ee4d28)')The above exception was the direct cause of the following exception:Traceback (most recent call last):File "/newdata/SD/inference_dalle-mini.py", line 84, in <module>processor = DalleBartProcessor.from_pretrained("/newdata/SD/dalle-mini/dalle-mini", revision=DALLE_COMMIT_ID)  # force_download=True, , local_only=TrueFile "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/utils.py", line 25, in from_pretrainedreturn super(PretrainedFromWandbMixin, cls).from_pretrained(File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/processor.py", line 62, in from_pretrainedreturn cls(tokenizer, config.normalize_text, config.max_text_length)File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/processor.py", line 21, in __init__self.text_processor = TextNormalizer()File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py", line 215, in __init__self._hashtag_processor = HashtagProcessor()File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py", line 25, in __init__#     wiki_word_frequency = hf_hub_download(File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/huggingface_hub/utils/_validators.py", line 118, in _inner_fnreturn fn(*args, **kwargs)File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/huggingface_hub/file_download.py", line 1363, in hf_hub_downloadraise LocalEntryNotFoundError(
huggingface_hub.utils._errors.LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.

顺着上面的报错信息,定位到/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py文件的如下内容:

...
class HashtagProcessor:# Adapted from wordninja library# We use our wikipedia word count + a good heuristic to make it workdef __init__(self):wiki_word_frequency = hf_hub_download("dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt")self._word_cost = (l.split()[0]for l in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines())
...

于是问题的根源就在于,程序运行到这里时,没有找到本地的enwiki-words-frequency.txt文件(经检查该文件其实是存在本地的,不知为何没有找到,很迷),于是尝试通过联网从huggingface官网下载,但由于网络状况欠佳,联网失败,于是报错。解决办法如下:

...
class HashtagProcessor:# Adapted from wordninja library# We use our wikipedia word count + a good heuristic to make it workdef __init__(self):wiki_word_frequency = "/newdata/SD/dalle-mini/dalle-mini/enwiki-words-frequency.txt"self._word_cost = (l.split()[0]for l in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines())
...

也就是将enwiki-words-frequency.txt文件的本地路径直接赋值给wiki_word_frequency变量,其余部份保持不变,问题解决。


  • 因安装不当导致的版本冲突问题
FIx for "Couldn't invoke ptxas --version"

这个错误的产生是不同python库安装时带来的版本冲突导致的,DALLE-mini要求jax和jaxlib版本必须为0.3.25,但是通过pip imstall dalle-mini 命令安装后的jaxlib版本为0.4.13,但使用pip install jaxlib的方式并不能找到0.3.25版本的jaxlib,而且会产生与flax、orbax-checkpoint等其他库的版本不兼容问题……在尝试多种方法合理降低jaxlib版本均失败后,发现答案就在ipynb中……也就是:pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

💡启示:要以官方说明文档为主,可以少走很多弯路!!!


  • 彩蛋:一个非常奇怪的错误:
The above exception was the direct cause of the following exception:Traceback (most recent call last):File "/newdata/SD/inference_dalle-mini.py", line 130, in <module>decoded_images = p_decode(encoded_images, vqgan_params)
ValueError: pmap got inconsistent sizes for array axes to be mapped:* most axes (101 of them) had size 512, e.g. axis 0 of argument params['decoder']['conv_in']['bias'] of type float32[512];* some axes (71 of them) had size 3, e.g. axis 0 of argument params['decoder']['conv_in']['kernel'] of type float32[3,3,256,512];* some axes (69 of them) had size 256, e.g. axis 0 of argument params['decoder']['up_1']['block_0']['norm1']['bias'] of type float32[256];* some axes (67 of them) had size 128, e.g. axis 0 of argument params['decoder']['norm_out']['bias'] of type float32[128];* some axes (35 of them) had size 1, e.g. axis 0 of argument indices of type int32[1,2,256];* one axis had size 16384: axis 0 of argument params['quantize']['embedding']['embedding'] of type float32[16384,256]

后来发现,是因为之前调试的时候不小心把下面这行代码注释掉了……这个bug排得最辛苦,还挺无语的😂

vqgan_params = replicate(vqgan_params)

PS:程序运行过程中还有一些警告,由下述警告也可以看出jax是属于tensoeflow派别的。

2023-11-07 11:30:35.139851: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-07 11:30:35.257514: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-07 11:30:35.258648: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-07 11:30:35.628768: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_SYSTEM_DRIVER_MISMATCH: system has unsupported display driver / cuda driver combination
2023-11-07 11:30:35.628915: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:313] kernel version 525.53.0 does not match DSO version 530.41.3 -- cannot find working devices in this configuration
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Prompts: ['sunset over a lake in the mountains', 'the Eiffel tower landing on the moon']0%|          | 0/8 [00:00<?, ?it/s]
/root/anaconda3/envs/dalle/lib/python3.8/site-packages/jax/_src/ops/scatter.py:87: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=float16 to dtype=float32. In future JAX releases this will result in an error.warnings.warn("scatter inputs have incompatible types: cannot safely cast "

后记:第一次接触到基于jax框架编写的程序,还挺新鲜的,感觉和pytorch有一些不一样的地方。了解到jax是tensorflow的轻量级版本。上述博客内容中如果有个人理解不当之处,还望各位批评指正!

参考链接

  1. python pathlib中Path 的使用(解决不同操作系统的路径问题)_python pathlib.path-CSDN博客
  2. python - vmap gives inconsistent shape error when trying to calculate gradient per sample - Stack Overflow
  3. https://github.com/google/jax/issues/9933

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/183766.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MVC、MVP、MVVM区别

MVC、MVP、MVVM区别 MVC&#xff08;Model-View-Controller&#xff09; 。是一种设计模式&#xff0c;通常用于组织与应用程序的数据流。它通常包括三个组件&#xff1a;模型&#xff08;Model&#xff09;、视图&#xff08;View&#xff09;和控制器&#xff08;Controller&…

JWT原理分析——JWT

了解为什么会有JWT的出现&#xff1f; 首先不得不提到一个知识叫做跨域身份验证&#xff0c;JWT的出现就是为了更好的解决这个问题&#xff0c;但是在没有JWT的时候&#xff0c;我们一般怎么做呢&#xff1f;一般使用Cookie和Session&#xff0c;流程大体如下所示&#xff1a;…

基于ssm在线考试系统设计与实现(代码+文档+数据库)

ssm在线考试系统&#xff0c;java在线考试系统&#xff0c;在线考试系统 运行环境&#xff1a; JAVA版本&#xff1a;JDK1.8 IDE类型&#xff1a;IDEA、Eclipse都可运行 数据库类型&#xff1a;MySql&#xff08;8.x版本都可&#xff09; 硬件环境&#xff1a;Windows 角色&…

【GitLab CI/CD、SpringBoot、Docker】GitLab CI/CD 部署SpringBoot应用,部署方式Docker

介绍 本文件主要介绍如何将SpringBoot应用使用Docker方式部署&#xff0c;并用Gitlab CI/CD进行构建和部署。 环境准备 已安装Gitlab仓库已安装Gitlab Runner&#xff0c;并已注册到Gitlab和已实现基础的CI/CD使用创建Docker Hub仓库&#xff0c;教程中使用的是阿里云的Docker…

腾讯云16核服务器配置有哪些?CPU型号处理器主频性能

腾讯云16核服务器配置大全&#xff0c;CVM云服务器可选择标准型S6、标准型SA3、计算型C6或标准型S5等&#xff0c;目前标准型S5云服务器有优惠活动&#xff0c;性价比高&#xff0c;计算型C6云服务器16核性能更高&#xff0c;轻量16核32G28M带宽优惠价3468元15个月&#xff0c;…

后端接口接收对象和文件集合,formdata传递数组对象

0 问题 后端接口需要接收前端传递过来的对象和文件集合&#xff1b;对象中存在数组对象 1 前端和后端 前端只能使用formdata来传递参数&#xff0c;后端不使用RequestBody注解 2 formdata传递数组对象 2.1 多个参数对象数组 addForm: {contactInfo: [{contactPerson: ,…

Apifox日常使用(一键本地联调)

背景说明&#xff1a;现在的项目一般都是前后分离&#xff0c;线上出bug或者在进行联调时&#xff0c;有些时候后端需要重复模拟前端数据格式&#xff0c;在使用Apifox的情况下&#xff0c;如何快速造出后端需要的数据呢&#xff1f; 随便找一个网站&#xff0c;点开f12&#…

FastGPT | 3分钟构建属于自己的AI智能助手

这是一篇使用指南&#xff01;&#xff01;&#xff01; FastGPT是什么&#xff1f; FastGPT 是一个基于 LLM 大语言模型的知识库问答系统&#xff0c;提供开箱即用的数据处理、模型调用等能力。同时可以通过 Flow 可视化进行工作流编排&#xff0c;从而实现复杂的问答场景&…

Solidity快速入门之函数输出

返回值return和returns Solidity有两个关键字与函数输出相关&#xff1a;return和returns&#xff0c;他们的区别在于&#xff1a; returns加在函数名后面&#xff0c;用于声明返回的变量类型及变量名&#xff1b;return用于函数主体中&#xff0c;返回想要返回的变量&#x…

Java入门篇 之 类与对象

本篇碎碎念&#xff1a;博主作为一个三本学生&#xff0c;庆幸自己上了个本科&#xff0c;但是在支付高昂学费的时候认识到&#xff0c;自己要好好学习&#xff0c;不好好学习&#xff0c;难道以后给人端盘子咩&#xff1b;无论是专科还是本科&#xff0c;都不可以自暴自弃&…

亚马逊云服务器成为了我的首选服务器

背景 作为一名计算机专业的大学生 当完成了自己的前后端项目或者是做出了属于自己的 网址&#xff0c;购买服务器是必不可少的亚马逊云服务器 相比于其他华为云 阿里云 以及腾讯云 等等 有着自己独特的优势 价格原因 学生党最在意的往往还是价格 一个良心亲民的价格 往往可以…

伐木猪小游戏

欢迎来到程序小院 伐木猪 玩法&#xff1a;控制小猪点击屏幕左右砍树&#xff0c;不能碰到树枝&#xff0c;考验手速与眼力&#xff0c;记录分数&#xff0c;快去挑战伐木吧^^。开始游戏https://www.ormcc.com/play/gameStart/199 html <script type"text/javascript…

《微服务架构设计模式》之三:微服务架构中的进程通信

概述 交互方式 客户端和服务端交互方式可以从两个维度来分&#xff1a; 维度1&#xff1a;一对一和多对多 一对一&#xff1a;每个客户端请求由一个实例来处理。 一对多&#xff1a;每个客户端请求由多个实例来处理。维度2&#xff1a;同步和异步 同步模式&#xff1a;客户端…

【Linux】JREE项目部署与发布

目录 一.jdk安装配置 1.1.传入资源 1.2. 解压 1.3. 配置 二.Tomcat安装 2.1.解压开启 2.2. 开放端口 三.MySQL安装 3.1.解压安装 3.2.登入配置 四.后端部署 今天就到这里了哦&#xff01;&#xff01;希望能帮到你哦&#xff01;&#xff01;&#xff01; 一.jdk…

VR全景在旅游中应用有哪些?VR云游的优势是什么?

近日受到剧烈日冕物质抛射活动影响&#xff0c;漠河再现极光美景&#xff0c;极光舞动的灿烂星空下&#xff0c;正在封冻的黑龙江上&#xff0c;无数的冰排随波而去&#xff0c;天地之间光影流动好不美丽。相信很多人都想了解、观赏祖国的大好风光&#xff0c;但是碍于没时间、…

【遍历二叉树算法描述】

文章目录 遍历二叉树算法描述先序遍历二叉树的操作定义中序遍历二叉树的操作定义后序遍历二叉树的操作定义 遍历二叉树算法描述 1.遍历定义&#xff1a;顺着某一条搜索路径寻访二叉树中的结点&#xff0c;使得每一个结点均被访问一次&#xff0c;而且仅访问一次&#xff08;又…

【Qt之绘制兔纸】

效果 代码 class drawRabbit: public QWidget { public:drawRabbit(QWidget *parent nullptr) : QWidget(parent) {}private:void paintEvent(QPaintEvent *event) {QPainter painter(this);painter.setRenderHint(QPainter::Antialiasing, true);// 绘制兔子的耳朵painter.s…

C语言-指针讲解(2)

文章目录 1.野指针1.1 什么是野指针1.2 造成野指针的原因有哪些呢1.2.1造成野指针具体代码实例&#xff1a; 1.3 如何避免野指针呢?1.3.1如何对指针进行初始化&#xff1f;1.3.2如何才能小心指针越界?1.3.3 指针变量不再使用时&#xff0c;如何及时置NULL,在指针使用之前检查…

单链表的实现

单链表的实现 单链表的链表的概念及结构概念结构链表结构的分类链表常用的结构 无头单向不循环链表头文件 SList.h结构体 struct SListNode 源文件 SList.c创建结点 SLNode* SLBuyNode(SLDataType x)初始化链表 void SLInit(SLNode** pphead)链表尾部插入 void SLPushBack(SLNo…

C语言:计算 1! + 2! + 3! + ... + n!

题目&#xff1a; 从键盘输入一个值n&#xff0c;计算 1的阶乘 至 n的阶乘 的和&#xff0c; 如&#xff1a;输入10&#xff0c;计算 1的阶乘 至 n的阶乘 的和 --> 计算&#xff1a;1! 2! 3! ... 10! 思路一&#xff1a; 效率比较低&#xff0c;会重复计算之前计算过的…