Hunuan-DiT代码阅读

一 整体架构

该模型是以SD为基础的文生图模型,具体扩散模型原理参考https://zhouyifan.net/2023/07/07/20230330-diffusion-model/,代码地址https://github.com/Tencent/HunyuanDiT,这里介绍 Full-parameter Training

二 输入数据处理

这里主要包括图像和文本数据输入处理

2.1 图像处理

这里代码参考 hydit/data_loader/arrow_load_stream.py,生成1024*1024的图片,对于输入图片进行random_crop,之后包括随机水平翻转,转tensor,以及Normalize(减均值0.5, 除以标准差0.5,为什么是这个,是因为通过PIL Image读图之后转到tensor范围是0-1之间,不是opencv读出来像素值在0-255之间),得到最终image( B ∗ 3 ∗ 1024 ∗ 1024 B*3*1024*1024 B310241024

2.2 文本处理

输入的文本,通过BertTokenizer,进行映射,同时补齐长度到77,不够的补0,同时生成相应的attention_mask;同时还有T5TokenizerFast,对于T5的输入,会随机小于uncond_p_t5(目前给出的设置uncond_p_t5=5),输入为空,否则为文本输入,补齐长度256,同时生成相应的attention_mask

2.3 图像编码

对于输入图像,采用VAE encoder 进行编码,生成隐空间特征latents( B ∗ 4 ∗ 128 ∗ 128 B*4*128*128 B4128128,就是输入8倍下采样,计算过程latents = vae.encode(image).latent_dist.sample().mul_(vae_scaling_factor),具体VAE相关后续补充)

2.4 文本编码

包括两个部分,一个是CLIP的text编码,采用bert layer,生成encoder_hidden_states( B ∗ 77 ∗ 1024 B*77*1024 B771024);第二部分是mT5的text编码,生成encoder_hidden_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B2562048

2.5 位置编码

这里是采用根据预设的分辨率,提前生成好的位置编码,这里采用ROPE,生成cos_cis_img, sin_cis_img (分别都是 4096 ∗ 88 4096*88 409688)

最终生成图像编码latents,文本编码(encoder_hidden_states以及对应的attention_mask,encoder_hidden_states_t5以及对应的attention_mask),以及位置编码cos_cis_img, sin_cis_img

三 DIT模型

3.1 add noise过程

  • 根据上一步的输出latents,作为x_start,随机选取一个time step,根据q_sample,得到增加噪声之后的输出x_t(具体公式参考如下,x0对应x_start,xt对应x_t)
    在这里插入图片描述

3.2 HunYuanDiT模型训练过程

  • 对于输入的文本编码,包括text_states( B ∗ 77 ∗ 1024 B*77*1024 B771024),text_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B2562048)以及相应的attention_mask,对于text_states_t5通过Linear+Silu+Linear,转成 B ∗ 256 ∗ 1024 B*256*1024 B2561024,然后对着两个进行concat,得到text_states( B ∗ 333 ∗ 1024 B*333*1024 B3331024),对于attention_mask也concat得到clip_t5_mask( B ∗ 333 B*333 B333);这里会生成一个可学习的text_embedding_padding特征( B ∗ 333 ∗ 1024 B*333*1024 B3331024),对于clip_t5_mask中通过补0得到的特征全部替换成text_embedding_padding特征
  • 对于输入time step 先走timestep_embedding(就是sinusoidal编码),然后通过Linear+Silu+Linear得到最终t ( B ∗ 1408 B*1408 B1408)
  • 对于输入x(就是上一步的x_t),通过PatchEmbed(就是VIT前面对图像进行patch),得到x( B ∗ 4096 ∗ 1408 , 4096 是 64 ∗ 64 B*4096*1408,4096是64*64 B4096140840966464
  • 对于text_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B2562048),添加一个AttentionPool模块,就是对于输入在256维度上,进行mean,当成query,然后将输入和query concat一起得到257维,作为key和value,(其中query,key,value都添加位置编码)做multi_head_attention,得到最终输出extra_vec( B ∗ 1024 B*1024 B1024
  • 对于extra_vec 通过Linear+Silu+Linear得到( B ∗ 1408 B*1408 B1408),然后与通过time step得到的t相加,得到c( B ∗ 1408 B*1408 B1408,作为所有extra_vectors)

3.2.1 进入Dit Block

一共40个block,前面0到18个block的生成输入,中间19,20作为middle block,剩余的block会增加一个前面19个block输出的结果作为skip

3.2.1.1 前面0到18共19个block
  • 前面一共19个block的过程,输入x( B ∗ 4096 ∗ 1408 B*4096*1408 B40961408),c( B ∗ 1408 B*1408 B1408),text_states( B ∗ 333 ∗ 1024 B*333*1024 B3331024),位置编码freqs_cis_img (cos_cis_img, sin_cis_img,分别都是 B ∗ 4096 ∗ 88 B*4096*88 B409688
HunYuanDiTBlock((norm1): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)(attn1): FlashSelfMHAModified((Wqkv): Linear(in_features=1408, out_features=4224, bias=True)(q_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(k_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(inner_attn): FlashSelfAttention((drop): Dropout(p=0.0, inplace=False))(out_proj): Linear(in_features=1408, out_features=1408, bias=True)(proj_drop): Dropout(p=0.0, inplace=False))(norm2): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=1408, out_features=6144, bias=True)(act): GELU(approximate='tanh')(drop1): Dropout(p=0, inplace=False)(norm): Identity()(fc2): Linear(in_features=6144, out_features=1408, bias=True)(drop2): Dropout(p=0, inplace=False))(default_modulation): Sequential((0): FP32_SiLU()(1): Linear(in_features=1408, out_features=1408, bias=True))(attn2): FlashCrossMHAModified((q_proj): Linear(in_features=1408, out_features=1408, bias=True)(kv_proj): Linear(in_features=1024, out_features=2816, bias=True)(q_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(k_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(inner_attn): FlashCrossAttention((drop): Dropout(p=0.0, inplace=False))(out_proj): Linear(in_features=1408, out_features=1408, bias=True)(proj_drop): Dropout(p=0.0, inplace=False))(norm3): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)
)
  • 对于c 通过default_modulation,得到shift_msa( B ∗ 4096 ∗ 1408 B*4096*1408 B40961408),与经过norm1之后的x进行相加作为attn1的输入(就是Flash Self Attention)
  • 将attn1的输出与原始的x进行残差相加,在经过norm3,与text_states一起作为attn2的输入(就是Flash Cross Attention)
  • 在将经过残差相加之后的x与attn2的输出在进行残差相加,作为输入,走FFN,即先经过norm2,在经过mlp,之后与输入残差相加
3.2.1.2 第19和20 middle block
  • 中间第19 和 20 两个block作为middle block,方式和上面一样
3.2.1.3 后面21到39共19个block
  • 从第21个block开始,增加一个输入,例如第21个block,会将第18个block的输出作为输入
  (skip_norm): FP32_Layernorm((2816,), eps=1e-06, elementwise_affine=True)(skip_linear): Linear(in_features=2816, out_features=1408, bias=True)
  • 就是对于新的输入skip,将skip与x进行concat之后,经过skip norm,然后在经过skip linear,得到输出x,剩余步骤与前面一样

3.2.2 最后FInal layer处理

  • 输入x和c,x是上面所有dit block的输出,c是上面的extra_vectors;对于c先进行SILU+Linear,得到( B ∗ 2816 B*2816 B2816),并彩分成shift 和 scale(分别为 B ∗ 1408 B*1408 B1408),最终通过x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1),然后通过Linear,得到最终输出x( B ∗ 4096 ∗ 32 B*4096*32 B409632),然后通过转换得到输出imgs ( B ∗ 8 ∗ 128 ∗ 128 B*8*128*128 B8128128

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/444706.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

E系列I/O模块在锂电装备制造系统的应用

为了满足电池生产线对稳定性和生产效率的严苛要求,ZLG致远电子推出高速I/O应用方案,它不仅稳定可靠,而且速度快,能够迅速响应生产需求。 锂电池的生产工艺较为复杂,大致分为三个主要阶段:极片制作、电芯制作…

单点登录Apereo CAS 7.1客户端集成教程

从上一篇部署并成功运行CAS服务端后,我们已经能通过默认的账号密码进行登录。 上篇地址:单点登录Apereo CAS 7.1安装配置教程-CSDN博客 本篇我们将开始对客户端进行集成。 CAS中的客户端,就是指我们实际开发的各个需要登录认证的应用。现在,跟着笔者的步伐,一起探索如何…

springmvc直接访问 上下文路径 302 后路径更改并跳转源码解析

【问题现状】 application.yml 配置如下属性: server:servlet:context-path: /learning直接访问:http://localhost:8888/learning 路径时,会返回302的响应状态;并跳转路径:http://localhost:8888/learning/ (原路径后…

Docker Overlay2 空间优化

目录 分析优化数据路径规划日志大小限制overlay2 大小限制清理冗余数据 总结 分析 overlay2 目录占用磁盘空间较大的原因通常与 Docker 容器和镜像的存储机制以及它们的长期累积相关,其实我之前在 Docker 原理那里已经提到过了。 通常时以下几种原因导致&#xff…

☕️从小工到专家的 Java 进阶之旅:全新的HttpClient,现代高效的网络通信利器

你好,我是看山。 本文收录在 《从小工到专家的 Java 进阶之旅》 系列专栏。日拱一卒,功不唐捐。 在 Java 开发领域,网络通信一直是至关重要的部分。从早期的网络编程方式到如今,Java 在 HTTP 客户端方面经历了不断的演进。 其中&…

【C语言】函数栈帧的创建和销毁

文章目录 前言函数栈帧相关寄存器相关汇编指令内存函数栈帧的创建销毁过程 前言 为了更好的了解函数里面变量是如何创建,为什么创建的变量是随机值和函数怎么传参和顺序是怎样的、以及实参和形参的关系,还要函数之间的调用、返回和销毁的过程。我们今天…

Comfyui 学习笔记5

1.图像处理小工具,沿某个轴反转Image Flip 2. reactor换脸 3. 通过某人的多张照片进行训练 训练的模型会保存在 models/reactor/face/下面,使用时直接load就好 4. 为一个mask 更加模糊 羽化 5. 指定位置替换,个人感觉这种方式进行换脸的融…

Pura 70系列和Pocket 2已支持升级尝鲜鸿蒙NEXT,报名教程在这里

相信不少关注鸿蒙 NEXT 的人都知道,10月8日起,华为开启了鸿蒙 NEXT 系统的公测,但有不少人不知道的是,除了公测的 Mate 60 和 Mate X5 两个系列的机型,还有两个系列的手机其实也可以提前升级体验鸿蒙 NEXT 系统。 Pur…

从数据管理到功能优化:Vue+TS 项目实用技巧分享

引言 在项目开发过程中,优化用户界面和完善数据处理逻辑是提升用户体验的重要环节。本篇文章将带你一步步实现从修改项目图标、添加数据、优化日期显示,到新增自定义字段、调整按钮样式以及自定义按钮跳转等功能。这些操作不仅提升了项目的可视化效果&am…

统一流程引擎如何具体实现对多系统业务流程的整合?

在信息化时代,企业和组织通常会使用多个业务系统来满足不同的业务需求。然而,这些分散的业务系统往往会导致业务流程的碎片化,降低工作效率。统一流程引擎的出现为解决这一问题提供了有效的途径。它能够整合多系统的业务流程,实现…

LeetCode 3310. 移除可疑的方法

LeetCode 3310. 移除可疑的方法 你正在维护一个项目,该项目有 n 个方法,编号从 0 到 n - 1。 给你两个整数 n 和 k,以及一个二维整数数组 invocations,其中 invocations[i] [ai, bi] 表示方法 ai 调用了方法 bi。 已知如果方法 k…

云栖实录 | MaxCompute 迈向下一代的智能云数仓

本文根据2024云栖大会实录整理而成,演讲信息如下: 演讲人: 张治国 | 阿里云智能集团研究员、阿里云 MaxCompute 负责人 谢德军|阿里云智能集团资深技术专家 于得水|阿里云智能集团资深技术专家 谌鹏飞&#xff5c…

List子接口

1.特点:有序,有下标,元素可以重复 2.方法:包含Collection中的所有方法,还包括自己的独有的方法(API中查找) 还有ListIterator(迭代器),功能更强大。 包含更多…

Basic Pentesting_ 2靶机渗透

项目地址 plain https://download.vulnhub.com/basicpentesting/basic_pentesting_2.tar.gz 修改静态ip 开机按e 输入rw signie init/bin/bash ctrlx 进入编辑这个文件 vi /etc/network/interfaces修改网卡为ens33 保存退出 实验过程 开启靶机虚拟机 ![](https://img-bl…

C++ -内存管理

博客主页:【夜泉_ly】 本文专栏:【C】 欢迎点赞👍收藏⭐关注❤️ C -内存管理 C/C -内存管理的深入探讨1. 数据存储分类1.1 局部数据1.2 静态数据1.3 常量数据1.4 动态申请的数据 2. 内存区域划分2.1 栈区2.2 堆区2.3 静态区/数据段2.4 常量区…

HTML5--裸体回顾

免责声明:本文仅做分享~ 详情请参考以下: HTML 系列教程 (w3school.com.cn) 菜鸟教程 - 学的不仅是技术,更是梦想! --本文是光秃秃的空壳. 标题标签 段落标签 换行和水平线 文本格式化标签 (一般用左边的&#xff…

抽象工厂模式(Abstract Factory Pattern)

抽象工厂模式(Abstract Factory Pattern)是一种创建型设计模式,它能创建一系列相关的对象,而无需指定其具体类,另一种说法是围绕一个超级工厂创建其他工厂。该超级工厂又称为其他工厂的工厂。它提供了一种创建对象的最…

如何让信息学奥赛学习“边玩边学”?——趣味编程让枯燥学习变得有趣

信息学奥赛(NOI)作为一项高水平的编程竞赛,内容涉及到大量的算法、数据结构和复杂的逻辑思维,对学生的要求非常高。然而,面对枯燥的知识点和高难度的题目,很多学生在备赛过程中容易感到乏味甚至放弃。那么&…

YOLO11改进|SPPF篇|引入SPPFCSPC金字塔结构

目录 一、【SPPFCSPC】金字塔结构1.1【SPPFCSPC】金字塔结构介绍1.2【SPPFCSPC】核心代码 二、添加【SPPFCSPC】金字塔结构2.1STEP12.2STEP22.3STEP32.4STEP4 三、yaml文件与运行3.1yaml文件3.2运行成功截图 一、【SPPFCSPC】金字塔结构 1.1【SPPFCSPC】金字塔结构介绍 下图是…

重学SpringBoot3-集成Redis(一)之基础功能

更多SpringBoot3内容请关注我的专栏:《SpringBoot3》 期待您的点赞👍收藏⭐评论✍ 重学SpringBoot3-集成Redis(一)之基础功能 1. 项目初始化2. 配置 Redis3. 配置 Redis 序列化4. 操作 Redis 工具类5. 编写 REST 控制器6. 测试 AP…