代码解读:Diffusion Models中的长宽桶技术(Aspect Ratio Bucketing)

Diffusion Models专栏文章汇总:入门与实战

前言:自从SDXL提出了长宽桶技术之后,彻底解决了不同长宽比的图像输入问题,现在已经成为训练扩散模型必选的方案。这篇博客从代码详细解读如何在模型训练的时候运用长宽桶技术(Aspect Ratio Bucketing)。

目录

原理解读-原有训练的问题

长宽桶技术(Aspect Ratio Bucketing)

完整代码


原理解读-原有训练的问题

纵横比分桶训练可以极大地提高输出质量,现有图像生成模型的一个常见问题是,它们非常容易生成带有非自然作物的图像。这是因为这些模型被训练成生成方形图像。然而,大多数照片和艺术品都不是方形的。然而,该模型只能同时在相同大小的图像上工作,并且在训练过程中,通常的做法是同时在多个训练样本上操作,以优化所使用gpu的效率。作为妥协,选择正方形图像,在训练过程中,只裁剪出每个图像的中心,然后作为训练样例显示给图像生成模型。

例如,人类通常是没有脚或头的,剑只有一个刀刃,剑柄和剑尖在框架外。因为我们正在创建一个图像生成模型来配合我们的故事叙述体验,所以我们的模型能够产生适当的,未裁剪的角色是很重要的,并且生成的骑士不应该持有延伸到无限的金属状直线。

对裁剪图像进行训练的另一个问题是,它可能导致文本和图像之间的不匹配。例如,带有王冠标签的图像通常在中央裁剪后不再包含王冠,因此君主已经被斩首。我们发现使用随机作物代替中心作物只能略微改善这些问题。使用具有可变图像大小的稳定扩散是可能的,尽管可以注意到,远远超过512x512的原生分辨率往往会引入重复的图像元素,并且非常低的分辨率会产生无法识别的图像。

尽管如此,这向我们表明,在可变大小的图像上训练模型应该是可能的。在单个、可变大小的样本上进行训练是微不足道的,但也非常缓慢,而且由于使用小批量提供的缺乏正则化,更容易产生训练不稳定性。

长宽桶技术(Aspect Ratio Bucketing)

由于这个问题似乎没有现有的解决方案,我们已经为我们的数据集实现了自定义批生成代码,允许创建批处理,其中批处理中的每个项目具有相同的大小,但批处理的图像大小可能不同。

我们通过一种叫做宽高比桶的方法来做到这一点。另一种方法是使用固定的图像大小,缩放每个图像以适应这个固定的大小,并应用在训练期间被掩盖的填充。由于这会导致训练期间不必要的计算,我们没有选择遵循这种替代方法。

在下面,我们描述了我们自定义的宽高比桶的批量生成方案背后的原始想法。

首先,我们必须定义要将数据集的图像排序到哪个存储桶中。为此,我们定义的最大图像尺寸为512x768,最大尺寸为1024。由于最大图像大小为512x768,比512x512大,需要更多的VRAM,因此每个gpu的批处理大小必须降低,这可以通过梯度积累来补偿。

我们通过应用以下算法生成桶:

Set the width to 256.
While the width is less than or equal to 1024:Find the largest height such that height is less than or equal to 1024 and that width multiplied by height is less than or equal to 512 * 768.Add the resolution given by height and width as a bucket.Increase the width by 64.

同样的重复,宽度和高度互换。重复的桶将从列表中删除,并添加一个大小为512x512的桶。

接下来,我们将图像分配到相应的桶中。为此,我们首先将桶分辨率存储在NumPy数组中,并计算每个分辨率的长宽比。对于数据集中的每张图像,我们检索其分辨率并计算长宽比。图像宽高比从桶宽高比数组中减去,使我们能够根据宽高比差的绝对值有效地选择最接近的桶:

image_bucket = argmin(abs(bucket_aspects — image_aspect))

图像的桶号与其数据集中的项目ID相关联。如果图像的宽高比非常极端,甚至与最适合的桶相差太大,则从数据集中修剪图像。

由于我们在多个GPU上进行训练,在每个epoch之前,我们对数据集进行了分片,以确保每个GPU在大小相等的不同子集上工作。为此,我们首先复制数据集中的项目id列表并对它们进行洗牌。如果这个复制的列表不能被gpu数量乘以批大小整除,则会对列表进行修剪,并删除最后的项以使其可整除。

然后,我们根据当前进程的全局排名选择1/world_size*bsz项id的不同子集。自定义批处理生成的其余部分将从这些过程中的任何一个过程中进行描述,并对数据集项id的子集进行操作。

对于当前的分片,每个bucket的列表是通过迭代打乱的数据集项目ID列表并将ID分配给分配给图像的bucket对应的列表来创建的。

处理完所有图像后,我们遍历每个bucket的列表。如果它的长度不能被批大小整除,则根据需要删除列表上的最后一个元素以使其可整除,并将它们添加到单独的捕获所有桶中。由于保证整个分片大小包含许多可被批大小整除的元素,因此保证生成一个长度可被批大小整除的所有bucket。

当请求批处理时,我们从加权分布中随机抽取一个桶。桶的权重设置为桶的大小除以所有剩余桶的大小。这确保了即使有大小差异很大的桶,自定义批生成在训练期间不会引入强烈的偏差,根据图像大小显示图像。如果在没有加权的情况下选择桶,那么小的桶将在训练过程中早期清空,只有最大的桶将在训练结束时保留。按大小对桶进行加权可以避免这种情况。

最后从所选的桶中取出一批项。取走的项目从桶中移除。如果桶现在为空,则在epoch的剩余时间内删除它。所选的项id和所选桶的分辨率现在被传递给图像加载函数。

完整代码

import numpy as np
import pickle
import timedef get_prng(seed):return np.random.RandomState(seed)class BucketManager:def __init__(self, bucket_file, valid_ids=None, max_size=(768,512), divisible=64, step_size=8, min_dim=256, base_res=(512,512), bsz=1, world_size=1, global_rank=0, max_ar_error=4, seed=42, dim_limit=1024, debug=False):with open(bucket_file, "rb") as fh:self.res_map = pickle.load(fh)if valid_ids is not None:new_res_map = {}valid_ids = set(valid_ids)for k, v in self.res_map.items():if k in valid_ids:new_res_map[k] = vself.res_map = new_res_mapself.max_size = max_sizeself.f = 8self.max_tokens = (max_size[0]/self.f) * (max_size[1]/self.f)self.div = divisibleself.min_dim = min_dimself.dim_limit = dim_limitself.base_res = base_resself.bsz = bszself.world_size = world_sizeself.global_rank = global_rankself.max_ar_error = max_ar_errorself.prng = get_prng(seed)epoch_seed = self.prng.tomaxint() % (2**32-1)self.epoch_prng = get_prng(epoch_seed) # separate prng for sharding use for increased thread resilienceself.epoch = Noneself.left_over = Noneself.batch_total = Noneself.batch_delivered = Noneself.debug = debugself.gen_buckets()self.assign_buckets()self.start_epoch()def gen_buckets(self):if self.debug:timer = time.perf_counter()resolutions = []aspects = []w = self.min_dimwhile (w/self.f) * (self.min_dim/self.f) <= self.max_tokens and w <= self.dim_limit:h = self.min_dimgot_base = Falsewhile (w/self.f) * ((h+self.div)/self.f) <= self.max_tokens and (h+self.div) <= self.dim_limit:if w == self.base_res[0] and h == self.base_res[1]:got_base = Trueh += self.divif (w != self.base_res[0] or h != self.base_res[1]) and got_base:resolutions.append(self.base_res)aspects.append(1)resolutions.append((w, h))aspects.append(float(w)/float(h))w += self.divh = self.min_dimwhile (h/self.f) * (self.min_dim/self.f) <= self.max_tokens and h <= self.dim_limit:w = self.min_dimgot_base = Falsewhile (h/self.f) * ((w+self.div)/self.f) <= self.max_tokens and (w+self.div) <= self.dim_limit:if w == self.base_res[0] and h == self.base_res[1]:got_base = Truew += self.divresolutions.append((w, h))aspects.append(float(w)/float(h))h += self.divres_map = {}for i, res in enumerate(resolutions):res_map[res] = aspects[i]self.resolutions = sorted(res_map.keys(), key=lambda x: x[0] * 4096 - x[1])self.aspects = np.array(list(map(lambda x: res_map[x], self.resolutions)))self.resolutions = np.array(self.resolutions)if self.debug:timer = time.perf_counter() - timerprint(f"resolutions:\n{self.resolutions}")print(f"aspects:\n{self.aspects}")print(f"gen_buckets: {timer:.5f}s")def assign_buckets(self):if self.debug:timer = time.perf_counter()self.buckets = {}self.aspect_errors = []skipped = 0skip_list = []for post_id in self.res_map.keys():w, h = self.res_map[post_id]aspect = float(w)/float(h)bucket_id = np.abs(self.aspects - aspect).argmin()if bucket_id not in self.buckets:self.buckets[bucket_id] = []error = abs(self.aspects[bucket_id] - aspect)if error < self.max_ar_error:self.buckets[bucket_id].append(post_id)if self.debug:self.aspect_errors.append(error)else:skipped += 1skip_list.append(post_id)for post_id in skip_list:del self.res_map[post_id]if self.debug:timer = time.perf_counter() - timerself.aspect_errors = np.array(self.aspect_errors)print(f"skipped images: {skipped}")print(f"aspect error: mean {self.aspect_errors.mean()}, median {np.median(self.aspect_errors)}, max {self.aspect_errors.max()}")for bucket_id in reversed(sorted(self.buckets.keys(), key=lambda b: len(self.buckets[b]))):print(f"bucket {bucket_id}: {self.resolutions[bucket_id]}, aspect {self.aspects[bucket_id]:.5f}, entries {len(self.buckets[bucket_id])}")print(f"assign_buckets: {timer:.5f}s")def start_epoch(self, world_size=None, global_rank=None):if self.debug:timer = time.perf_counter()if world_size is not None:self.world_size = world_sizeif global_rank is not None:self.global_rank = global_rank# select ids for this epoch/rankindex = np.array(sorted(list(self.res_map.keys())))index_len = index.shape[0]index = self.epoch_prng.permutation(index)index = index[:index_len - (index_len % (self.bsz * self.world_size))]#print("perm", self.global_rank, index[0:16])index = index[self.global_rank::self.world_size]self.batch_total = index.shape[0] // self.bszassert(index.shape[0] % self.bsz == 0)index = set(index)self.epoch = {}self.left_over = []self.batch_delivered = 0for bucket_id in sorted(self.buckets.keys()):if len(self.buckets[bucket_id]) > 0:self.epoch[bucket_id] = np.array([post_id for post_id in self.buckets[bucket_id] if post_id in index], dtype=np.int64)self.prng.shuffle(self.epoch[bucket_id])self.epoch[bucket_id] = list(self.epoch[bucket_id])overhang = len(self.epoch[bucket_id]) % self.bszif overhang != 0:self.left_over.extend(self.epoch[bucket_id][:overhang])self.epoch[bucket_id] = self.epoch[bucket_id][overhang:]if len(self.epoch[bucket_id]) == 0:del self.epoch[bucket_id]if self.debug:timer = time.perf_counter() - timercount = 0for bucket_id in self.epoch.keys():count += len(self.epoch[bucket_id])print(f"correct item count: {count == len(index)} ({count} of {len(index)})")print(f"start_epoch: {timer:.5f}s")def get_batch(self):if self.debug:timer = time.perf_counter()# check if no data left or no epoch initializedif self.epoch is None or self.left_over is None or (len(self.left_over) == 0 and not bool(self.epoch)) or self.batch_total == self.batch_delivered:self.start_epoch()found_batch = Falsebatch_data = Noneresolution = self.base_reswhile not found_batch:bucket_ids = list(self.epoch.keys())if len(self.left_over) >= self.bsz:bucket_probs = [len(self.left_over)] + [len(self.epoch[bucket_id]) for bucket_id in bucket_ids]bucket_ids = [-1] + bucket_idselse:bucket_probs = [len(self.epoch[bucket_id]) for bucket_id in bucket_ids]bucket_probs = np.array(bucket_probs, dtype=np.float32)bucket_lens = bucket_probsbucket_probs = bucket_probs / bucket_probs.sum()bucket_ids = np.array(bucket_ids, dtype=np.int64)if bool(self.epoch):chosen_id = int(self.prng.choice(bucket_ids, 1, p=bucket_probs)[0])else:chosen_id = -1if chosen_id == -1:# using leftover images that couldn't make it into a bucketed batch and returning them for use with basic square imageself.prng.shuffle(self.left_over)batch_data = self.left_over[:self.bsz]self.left_over = self.left_over[self.bsz:]found_batch = Trueelse:if len(self.epoch[chosen_id]) >= self.bsz:# return bucket batch and resolutionbatch_data = self.epoch[chosen_id][:self.bsz]self.epoch[chosen_id] = self.epoch[chosen_id][self.bsz:]resolution = tuple(self.resolutions[chosen_id])found_batch = Trueif len(self.epoch[chosen_id]) == 0:del self.epoch[chosen_id]else:# can't make a batch from this, not enough images. move them to leftovers and try againself.left_over.extend(self.epoch[chosen_id])del self.epoch[chosen_id]assert(found_batch or len(self.left_over) >= self.bsz or bool(self.epoch))if self.debug:timer = time.perf_counter() - timerprint(f"bucket probs: " + ", ".join(map(lambda x: f"{x:.2f}", list(bucket_probs*100))))print(f"chosen id: {chosen_id}")print(f"batch data: {batch_data}")print(f"resolution: {resolution}")print(f"get_batch: {timer:.5f}s")self.batch_delivered += 1return (batch_data, resolution)def generator(self):if self.batch_delivered >= self.batch_total:self.start_epoch()while self.batch_delivered < self.batch_total:yield self.get_batch()if __name__ == "__main__":# prepare a pickle with mapping of dataset IDs to resolutions called resolutions.pkl to use thiswith open("resolutions.pkl", "rb") as fh:ids = list(pickle.load(fh).keys())counts = np.zeros((len(ids),)).astype(np.int64)id_map = {}for i, post_id in enumerate(ids):id_map[post_id] = ibm = BucketManager("resolutions.pkl", debug=True, bsz=8, world_size=8, global_rank=3)print("got: " + str(bm.get_batch()))print("got: " + str(bm.get_batch()))print("got: " + str(bm.get_batch()))print("got: " + str(bm.get_batch()))print("got: " + str(bm.get_batch()))print("got: " + str(bm.get_batch()))print("got: " + str(bm.get_batch()))bm = BucketManager("resolutions.pkl", bsz=8, world_size=1, global_rank=0, valid_ids=ids[0:16])for _ in range(16):bm.get_batch()print("got from future epoch: " + str(bm.get_batch()))bms = []for rank in range(16):bm = BucketManager("resolutions.pkl", bsz=8, world_size=16, global_rank=rank)bms.append(bm)for epoch in range(5):print(f"epoch {epoch}")for i, bm in enumerate(bms):print(f"bm {i}")first = Truefor ids, res in bm.generator():if first and i == 0:#print(ids)first = Falsefor post_id in ids:counts[id_map[post_id]] += 1print(np.bincount(counts))

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/380633.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

UNiapp 微信小程序渐变不生效

开始用的一直是这个&#xff0c;调试一直没问题&#xff0c;但是重新启动就没生效&#xff0c;经查询这个不适合小程序使用&#xff1a;不适合没生效 background-image:linear-gradient(to right, #33f38d8a,#6dd5ed00); 正确使用下面这个&#xff1a; 生效&#xff0c;适合…

Python list comprehension (列表推导式 - 列表解析式 - 列表生成式)

Python list comprehension {列表推导式 - 列表解析式 - 列表生成式} 1. Python list comprehension (列表推导式 - 列表解析式 - 列表生成式)2. Example3. ExampleReferences Python 中的列表解析式并不是用来解决全新的问题&#xff0c;只是为解决已有问题提供新的语法。 列…

(10)深入理解pandas的核心数据结构:DataFrame高效数据清洗技巧

目录 前言1. DataFrame数据清洗1.1 处理缺失值&#xff08;NaNs&#xff09;1.1.1 数据准备1.1.2 读取数据1.1.3 查找具有 null 值或缺失值的行和列1.1.4 计算每列缺失值的总数1.1.5 删除包含 null 值或缺失值的行1.1.6 利用 .fillna&#xff08;&#xff09; 方法用Portfolio …

Windows搭建RTMP视频流服务器

参考了一篇文章&#xff0c;见文末。 博客中nginx下载地址失效&#xff0c;附上一个有效的地址&#xff1a; Index of /download/ 另外&#xff0c;在搭建过程中&#xff0c;遇到的问题总结如下&#xff1a; 1 两个压缩包下载解压并重命名后&#xff0c;需要 将nginx-rtmp…

如何使用简鹿水印助手或 Photoshop 给照片添加文字

在社交媒体中&#xff0c;为照片添加个性化的文字已经成为了一种流行趋势。无论是添加注释、引用名言还是表达情感&#xff0c;文字都能够为图片增添额外的意义和风格。本篇文章将使用“简鹿水印助手”和“Adobe Photoshop”这两种工具给照片添加文字的详细步骤。 使用简鹿水印…

【python基础】组合数据类型:元组、列表、集合、映射

文章目录 一. 序列类型1. 元组类型2. 列表类型&#xff08;list&#xff09;2.1. 列表创建2.2 列表操作2.3. 列表元素遍历 ing元素列表求平均值删除散的倍数 二. 集合类型&#xff08;set&#xff09;三. 映射类型&#xff08;map&#xff09;1. 字典创建2. 字典操作3. 字典遍历…

【EI检索】第二届机器视觉、图像处理与影像技术国际会议(MVIPIT 2024)

一、会议信息 大会官网&#xff1a;www.mvipit.org 官方邮箱&#xff1a;mvipit163.com 会议出版&#xff1a;IEEE CPS 出版 会议检索&#xff1a;EI & Scopus 检索 会议地点&#xff1a;河北张家口 会议时间&#xff1a;2024 年 9 月 13 日-9 月 15 日 二、征稿主题…

【香橙派开发板测试】:在黑科技Orange Pi AIpro部署YOLOv8深度学习纤维分割检测模型

文章目录 &#x1f680;&#x1f680;&#x1f680;前言一、1️⃣ Orange Pi AIpro开发板相关介绍1.1 &#x1f393; 核心配置1.2 ✨开发板接口详情图1.3 ⭐️开箱展示 二、2️⃣配置开发板详细教程2.1 &#x1f393; 烧录镜像系统2.2 ✨配置网络2.3 ⭐️使用SSH连接主板 三、…

Web开发:图片九宫格与非九宫格动态切换效果(HTML、CSS、JavaScript)

目录 一、业务需求 二、实现思路 三、实现过程 1、基础页面 2、图片大小调整 3、图片位置调整 4、鼠标控制切换 5、添加过渡 四、完整代码 一、业务需求 默认显示基础图片&#xff1b; 当鼠标移入&#xff0c;使用九宫格效果展示图片&#xff1b; 当鼠标离开&#…

CTF-Web习题:[BJDCTF2020]ZJCTF,不过如此

题目链接&#xff1a;[BJDCTF2020]ZJCTF&#xff0c;不过如此 解题思路 访问靶场链接&#xff0c;出现的是一段php源码&#xff0c;接下来做一下代码审阅&#xff0c;发现这是一道涉及文件包含的题 主要PHP代码语义&#xff1a; file_get_contents($text,r); 把$text变量所…

基于NeRF的路面重建算法——RoME / EMIE-MAP / RoGS

基于NeRF的路面重建算法——RoME / EMIE-MAP / RoGS 1. RoMe1.1 Mesh Initialization / Waypoint Sampling1.2 Optimization1.3 Experiments 2. EMIE-MAP2.1 Road Surface Representation based on Explicit mesh and Implicit Encoding2.2 Optimizing Strategies2.3 Experimen…

Uniapp鸿蒙项目实战

Uniapp鸿蒙项目实战 24.7.6 Dcloud发布了uniapp兼容鸿蒙的文档&#xff1a;Uniapp开发鸿蒙应用 在实际使用中发现一些问题&#xff0c;开贴记录一下 设备准备 windows电脑准备&#xff08;家庭版不行&#xff0c;教育版、企业版、专业版也可以&#xff0c;不像uniapp说的只有…

Promise 详解(原理篇)

目录 什么是 Promise 实现一个 Promise Promise 的声明 解决基本状态 添加 then 方法 解决异步实现 解决链式调用 完成 resolvePromise 函数 解决其他问题 添加 catch 方法 添加 finally 方法 添加 resolve、reject、race、all 等方法 如何验证我们的 Promise 是否…

分布式搜索之Elasticsearch入门

Elasticsearch 是什么 Elasticsearch 是一个分布式、RESTful 风格的搜索和数据分析引擎&#xff0c;能够解决不断涌现出的各种用例。作为 Elastic Stack 的核心&#xff0c;它集中存储您的数据&#xff0c;帮助您发现意料之中以及意料之外的情况。 Elastic Stack 又是什么呢&a…

企业须善用数字化杠杆经营获取数字化时代红利

​在当今数字化时代&#xff0c;企业面临着新机遇和新挑战。数字化技术的迅速发展正在重塑商业格局&#xff0c;企业若能善用数字化杠杆经营&#xff0c;将能够在激烈的市场竞争中脱颖而出&#xff0c;获取丰厚的时代红利。 数字化杠杆的内涵 数字化杠杆是指企业借助数字化技术…

SAPUI5基础知识16 - 深入理解MVC架构

1. 背景 经过一系列的练习&#xff0c;相信大家对于SAPUI5的应用程序已经有了直观的认识&#xff0c;我们在练习中介绍了视图、控制器、模型的概念和用法。在本篇博客中&#xff0c;让我们回顾总结下这些知识点&#xff0c;更深入地理解SAPUI5的MVC架构。 首先&#xff0c;让…

Android 性能优化之卡顿优化

文章目录 Android 性能优化之卡顿优化卡顿检测TraceView配置缺点 StricktMode配置违规代码 BlockCanary配置问题代码缺点 ANRANR原因ANRWatchDog监测解决方案 Android 性能优化之卡顿优化 卡顿检测 TraceViewStricktModelBlockCanary TraceView 配置 Debug.startMethodTra…

XMl基本操作

引言 使⽤Mybatis的注解⽅式&#xff0c;主要是来完成⼀些简单的增删改查功能. 如果需要实现复杂的SQL功能&#xff0c;建议使⽤XML来配置映射语句&#xff0c;也就是将SQL语句写在XML配置⽂件中. 之前&#xff0c;我们学习了&#xff0c;用注解的方式来实现MyBatis 接下来我们…

【STM32】按键控制LED光敏传感器控制蜂鸣器(江科大)

一、按键控制LED LED.c #include "stm32f10x.h" // Device header/*** 函 数&#xff1a;LED初始化* 参 数&#xff1a;无* 返 回 值&#xff1a;无*/ void LED_Init(void) {/*开启时钟*/RCC_APB2PeriphClockCmd(RCC_APB2Periph_GPIOA, ENAB…

数据结构(稀疏数组)

简介 稀疏数组是一种数据结构&#xff0c;用于有效地存储和处理那些大多数元素都是零或者重复值的数组。在稀疏数组中&#xff0c;只有非零或非重复的元素会被存储&#xff0c;从而节省内存空间。 案例引入 假如想把下面这张表存入文件&#xff0c;我们会怎么做&#xff1f;…