[跑代码]BK-SDM: A Lightweight, Fast, and Cheap Version of Stable Diffusion

Installation(下载代码-装环境)

conda create -n bk-sdm python=3.8
conda activate bk-sdm
git clone https://github.com/Nota-NetsPresso/BK-SDM.git
cd BK-SDM
pip install -r requirements.txt
Note on the torch versions we've used
  • torch 1.13.1 for MS-COCO evaluation & DreamBooth finetuning on a single 24GB RTX3090
     

  • torch 2.0.1 for KD pretraining on a single 80GB A10
    火炬2.0.1在单个80GB A100上进行KD预训练

    • 如果A100上总批大小为256的预训练导致gpu内存不足,请检查torch版本并考虑升级到torch>2.0.0。
      我的版本也是torch2.0.1 单个A100(80G)理论上吃的下256batch

小的例子

PNDM采样器 50步去噪声

等效代码(仅修改SD-v1.4的U-Net,同时保留其文本编码器和图像解码器):

Distillation Pretraining

Our code was based on train_text_to_image.py of Diffusers 0.15.0.dev0. To access the latest version, use this link.
BK-SDM的diffusers版本0.15
我的diffusers版本比较高0.24.0

检测是否能够训练(先下载数据集get_laion_data.sh再运行代码kd_train_toy.sh)

1 一个玩具数据集(11K的img-txt对)下载到。

bash scripts/get_laion_data.sh preprocessed_11k

/data/laion_aes/preprocessed_11k (1.7GB in tar.gz;1.8GB数据文件夹)。
get_laion_data.sh

需要修改,实际就是下载这三个数据集,我自行下载

# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_11k.tar.gz
# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_212k.tar.gz
# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_2256k.tar.gz

我修改后下载文件名 https://... .../preprocessed_11k.tar.gz直接粘贴到网址里面也可以下载
wget $S3_URL -0 $FILe_PATH
$S3_URL 就是这个网址
$FILe_PATH 就是下载路径./data/laion_aes/preprocessed_11k

DATA_TYPE=$"preprocessed_11k"  # {preprocessed_11k, preprocessed_212k, preprocessed_2256k}
FILE_NAME="${DATA_TYPE}.tar.gz"DATA_DIR="./data/laion_aes/"
FILE_UNZIP_DIR="${DATA_DIR}${DATA_TYPE}"
FILE_PATH="${DATA_DIR}${FILE_NAME}"if [ "$DATA_TYPE" = "preprocessed_11k" ] || [ "$DATA_TYPE" = "preprocessed_212k" ]; thenecho "-> preprocessed_11k or 212k"S3_URL="https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/${FILE_NAME}"
elif [ "$DATA_TYPE" = "preprocessed_2256k" ]; thenS3_URL="https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.25plus/${FILE_NAME}"
elseecho "Something wrong in data folder name"exit
fiwget $S3_URL -O $FILE_PATH
tar -xvzf $FILE_PATH -C $DATA_DIR
echo "downloaded to ${FILE_UNZIP_DIR}"

2 一个小脚本可以用来验证代码的可执行性,并找到与你的GPU匹配的批处理大小。
批量大小为8 (=4×2),训练BK-SDM-Base 20次迭代大约需要5分钟和22GB的GPU内存。

bash scripts/kd_train_toy.sh
MODEL_NAME="CompVis/stable-diffusion-v1-4"
TRAIN_DATA_DIR="./data/laion_aes/preprocessed_11k" # please adjust it if needed
UNET_CONFIG_PATH="./src/unet_config"UNET_NAME="bk_small" # option: ["bk_base", "bk_small", "bk_tiny"]
OUTPUT_DIR="./results/toy_"$UNET_NAME # please adjust it if neededBATCH_SIZE=2
GRAD_ACCUMULATION=4StartTime=$(date +%s)CUDA_VISIBLE_DEVICES=1 accelerate launch src/kd_train_text_to_image.py \--pretrained_model_name_or_path $MODEL_NAME \--train_data_dir $TRAIN_DATA_DIR\--use_ema \--resolution 512 --center_crop --random_flip \--train_batch_size $BATCH_SIZE \--gradient_checkpointing \--mixed_precision="fp16" \--learning_rate 5e-05 \--max_grad_norm 1 \--lr_scheduler="constant" --lr_warmup_steps=0 \--report_to="all" \--max_train_steps=20 \--seed 1234 \--gradient_accumulation_steps $GRAD_ACCUMULATION \--checkpointing_steps 5 \--valid_steps 5 \--lambda_sd 1.0 --lambda_kd_output 1.0 --lambda_kd_feat 1.0 \--use_copy_weight_from_teacher \--unet_config_path $UNET_CONFIG_PATH --unet_config_name $UNET_NAME \--output_dir $OUTPUT_DIREndTime=$(date +%s)
echo "** KD training takes $(($EndTime - $StartTime)) seconds."

单GPU训练BK-SDM{Base, Small, Tiny}-0.22M数据训练
 

bash scripts/get_laion_data.sh preprocessed_212k
bash scripts/kd_train.sh

1 下载数据集preprocessed_212k
2 训练kd_train.sh
(256batch 训练BD-SM-Base 50K轮次需要300hours/53G单卡)
(64batch 训练BD-SM-Base 50K轮次需要60hours/28G单卡) 不理解?
 

单GPU训练BK-SDM{Base, Small, Tiny}-2.3M数据训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/206936.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Xilinx Zynq-7000系列FPGA实现视频拼接显示,提供两套工程源码和技术支持

目录 1、前言免责声明 2、相关方案推荐FPGA图像处理方案FPGA视频拼接叠加融合方案推荐 3、设计思路详解Video Mixer介绍 4、工程代码1:2路视频拼接 HDMI 输出PL 端 FPGA 逻辑设计PS 端 SDK 软件设计 5、工程代码2:4路视频拼接 HDMI 输出PL 端 FPGA 逻辑设…

maven 基础

maven常用命令 clean &#xff1a;清理 compile&#xff1a;编译 test&#xff1a;测试 package&#xff1a;打包 install&#xff1a;安装 maven坐标书写规范 <dependency> <groupId>mysql</groupId> <artifactId>mysql-connector-java</ar…

Javaweb之Vue组件库Element案例的详细解析

4.4.3.3 顶部标题 对于顶部&#xff0c;我们需要实现的效果如下图所示&#xff1a; 所以我们需要修改顶部的文本内容&#xff0c;并且提供背景色的css样式&#xff0c;具体代码如下&#xff1a; <el-header style"font-size:40px;background-color: rgb(238, 241, 24…

【腾讯云 HAI域探秘】借助高性能应用HAI——我也能使用【stable diffusion】制作高级视频封面了

目录 高性能应用服务HAI_GPU云服务器的申请与服务创建 官网地址&#xff1a;高性能应用服务HAI_GPU云服务器_腾讯云 通过高性能应用服务HAI——创建【stable diffusion】 WebUI效果&#xff1a; 服务器后台效果&#xff1a; stable-diffusion服务测试 启动接口服务 配置…

模拟算法【3】——1419.数青蛙

文章目录 &#x1f365;1. 题目&#x1f96e;2. 算法原理&#x1f361;3. 代码实现 &#x1f365;1. 题目 题目链接&#xff1a;1419. 数青蛙 - 力扣&#xff08;LeetCode&#xff09; 给你一个字符串 croakOfFrogs&#xff0c;它表示不同青蛙发出的蛙鸣声&#xff08;字符串 &…

17. Python 数据库操作之MySQL和SQLite实例

目录 1. 简介2. 使用PyMySQL2. 使用SQLite 1. 简介 数据库种类繁多&#xff0c;每种数据库的对外接口实现各不相同&#xff0c;为了方便对数据库进行统一的操作&#xff0c;大部分编程语言都提供了标准化的数据库接口&#xff0c;用户不需要了解每种数据的接口实现细节&#x…

Docker篇之docker部署harbor仓库

一、首先需要安装docker step1&#xff1a;安装docker #1、安装yun源 yum install -y yum-utils #2、配置yum源 yum-config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo # 如果上面源不稳定的话&#xff0c;更换为下列的aliyun源 yu…

SpringBoot 整合 Neo4j 实战(头歌)

文章目录 第1关&#xff1a;认识 Spring DATA Neo4J任务描述相关知识Spring DATA Neo4J - 简介Spring JDBC / Spring ORM 模块的缺点&#xff1a;Spring 数据模块的优点&#xff1a;Spring 数据模块功能&#xff1a;Spring DATA Neo4j 模块的附加功能&#xff1a; Spring DATA …

Modbus RTU协议及modbus库函数使用

一、与Modbus TCP的区别 在一般工业场景使用modbus RTU的场景还是更多一些&#xff0c;modbus RTU基于串行协议进行收发数据&#xff0c;包括RS232/485等工业总线协议。 与modbus TCP不同的是RTU没有报文头MBAP字段&#xff0c;但是在尾部增加了两个CRC检验字节&#xff08;CRC…

【Web】UUCTF 2022 新生赛 个人复现

目录 ①websign ②ez_rce ③ez_upload ④ez_unser ⑤ezsql ⑥ezpop ⑦funmd5 ⑧phonecode ⑨ezrce ①websign 右键打不开&#xff0c;直接抓包发包看源码 ②ez_rce “反引号” 在PHP中会被当作SHELL命令执行 ?codeprintf(l\s /); ?codeprintf(ta\c /ffffffffffl…

特征变换1

编译工具&#xff1a;PyCharm 有些编译工具不用写print可以直接将数据打印出来&#xff0c;pycharm需要写print才会打印出来。 概念 1.特征类型 特征的类型&#xff1a;“离散型”和“连续型” 机器学习算法对特征的类型是有要求的&#xff0c;不是任意类型的特征都可以随意…

Spring RabbitMQ那些事(2-两种方式实现延时消息订阅)

目录 一、序言二、死信交换机和消息TTL实现延迟消息1、死信队列介绍2、代码示例(1) 死信交换机配置(2) 消息生产者(3) 消息消费者 3、测试用例 三、延迟消息交换机实现延迟消息1、安装延时消息插件2、代码示例(1) 延时消息交换机配置(2) 消息生产者(3) 消息消费者 3、测试用例 …

set和map + multiset和multimap(使用+封装(RBTree))

set和map 前言一、使用1. set(1)、模板参数列表(2)、常见构造(3)、find和count(4)、insert和erase(5)、iterator(6)、lower_bound和upper_bound 2. multiset3. map(1)、模板参数列表(2)、构造(3)、modifiers和operations(4)、operator[] 4. multimap 二、封装RBTree迭代器原理R…

科技与教育:未来教育的新趋势

在21世纪&#xff0c;科技的快速发展正在深刻地改变教育行业。从在线学习平台到虚拟现实教室&#xff0c;科技为教育带来了革命性的变化。本文将探讨科技如何影响现代教育&#xff0c;并预测未来教育的发展趋势。 一、科技在教育中的应用 在线学习平台&#xff1a;通过平台如C…

JSON 与 FastJSON

JSON 与 FastJSON JSON JavaScript Object Notation&#xff08;JavaScript 对象表示法&#xff09;是目前最常用的执行对象序列化的方式。 虽然 json 最初是为了在 JavaScript 语言中使用的&#xff0c;但实际上 json 本身跟语言没有任何关系&#xff0c;各种编程语言都可以使…

微服务--08--Seata XA模式 AT模式

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 分布式事务Seata 1.XA模式1.1.两阶段提交1.2.Seata的XA模型1.3.优缺点 AT模式2.1.Seata的AT模型2.2.流程梳理2.3.AT与XA的区别 分布式事务 > 事务–01—CAP理论…

Flutter使用flutter_gen管理资源文件

pub地址&#xff1a; https://pub.dev/packages/flutter_gen 1.添加依赖 在你的pubspec.yaml文件中添加flutter_gen作为开发依赖 dependencies:build_runner:flutter_gen_runner: 2.配置pubspec.yaml 在pubspec.yaml文件中&#xff0c;配置flutter_gen的参数。指定输出路…

msvcp140.dll的解决方法有哪些。详细解析五种可以修复msvcp140.dll丢失的方法

引言&#xff1a; 在日常使用电脑的过程中&#xff0c;我们可能会遇到一些错误提示&#xff0c;其中之一就是“msvcp140.dll丢失”。那么&#xff0c;什么是msvcp140.dll文件&#xff1f;它的作用是什么&#xff1f;当它丢失时会对电脑产生什么影响&#xff1f;本文将详细介绍…

使用elementPlus去除下拉框蓝色边框

// 下拉框去除蓝色边框 .el-select {--el-select-input-focus-border-color: none !important; }

仅仅通过提示词,GPT-4可以被引导成为多个领域的特定专家

The Power of Prompting&#xff1a;提示的力量&#xff0c;仅通过提示&#xff0c;GPT-4可以被引导成为多个领域的特定专家。微软研究院发布了一项研究&#xff0c;展示了在仅使用提策略的情况下让GPT 4在医学基准测试中表现得像一个专家。研究显示&#xff0c;GPT-4在相同的基…