一 整体架构
该模型是以SD为基础的文生图模型,具体扩散模型原理参考https://zhouyifan.net/2023/07/07/20230330-diffusion-model/,代码地址https://github.com/Tencent/HunyuanDiT,这里介绍 Full-parameter Training
二 输入数据处理
这里主要包括图像和文本数据输入处理
2.1 图像处理
这里代码参考 hydit/data_loader/arrow_load_stream.py,生成1024*1024的图片,对于输入图片进行random_crop,之后包括随机水平翻转,转tensor,以及Normalize(减均值0.5, 除以标准差0.5,为什么是这个,是因为通过PIL Image读图之后转到tensor范围是0-1之间,不是opencv读出来像素值在0-255之间),得到最终image( B ∗ 3 ∗ 1024 ∗ 1024 B*3*1024*1024 B∗3∗1024∗1024)
2.2 文本处理
输入的文本,通过BertTokenizer,进行映射,同时补齐长度到77,不够的补0,同时生成相应的attention_mask;同时还有T5TokenizerFast,对于T5的输入,会随机小于uncond_p_t5(目前给出的设置uncond_p_t5=5),输入为空,否则为文本输入,补齐长度256,同时生成相应的attention_mask
2.3 图像编码
对于输入图像,采用VAE encoder 进行编码,生成隐空间特征latents( B ∗ 4 ∗ 128 ∗ 128 B*4*128*128 B∗4∗128∗128,就是输入8倍下采样,计算过程latents = vae.encode(image).latent_dist.sample().mul_(vae_scaling_factor),具体VAE相关后续补充)
2.4 文本编码
包括两个部分,一个是CLIP的text编码,采用bert layer,生成encoder_hidden_states( B ∗ 77 ∗ 1024 B*77*1024 B∗77∗1024);第二部分是mT5的text编码,生成encoder_hidden_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B∗256∗2048)
2.5 位置编码
这里是采用根据预设的分辨率,提前生成好的位置编码,这里采用ROPE,生成cos_cis_img, sin_cis_img (分别都是 4096 ∗ 88 4096*88 4096∗88)
最终生成图像编码latents,文本编码(encoder_hidden_states以及对应的attention_mask,encoder_hidden_states_t5以及对应的attention_mask),以及位置编码cos_cis_img, sin_cis_img
三 DIT模型
3.1 add noise过程
- 根据上一步的输出latents,作为x_start,随机选取一个time step,根据q_sample,得到增加噪声之后的输出x_t(具体公式参考如下,x0对应x_start,xt对应x_t)
3.2 HunYuanDiT模型训练过程
- 对于输入的文本编码,包括text_states( B ∗ 77 ∗ 1024 B*77*1024 B∗77∗1024),text_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B∗256∗2048)以及相应的attention_mask,对于text_states_t5通过Linear+Silu+Linear,转成 B ∗ 256 ∗ 1024 B*256*1024 B∗256∗1024,然后对着两个进行concat,得到text_states( B ∗ 333 ∗ 1024 B*333*1024 B∗333∗1024),对于attention_mask也concat得到clip_t5_mask( B ∗ 333 B*333 B∗333);这里会生成一个可学习的text_embedding_padding特征( B ∗ 333 ∗ 1024 B*333*1024 B∗333∗1024),对于clip_t5_mask中通过补0得到的特征全部替换成text_embedding_padding特征
- 对于输入time step 先走timestep_embedding(就是sinusoidal编码),然后通过Linear+Silu+Linear得到最终t ( B ∗ 1408 B*1408 B∗1408)
- 对于输入x(就是上一步的x_t),通过PatchEmbed(就是VIT前面对图像进行patch),得到x( B ∗ 4096 ∗ 1408 , 4096 是 64 ∗ 64 B*4096*1408,4096是64*64 B∗4096∗1408,4096是64∗64)
- 对于text_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B∗256∗2048),添加一个AttentionPool模块,就是对于输入在256维度上,进行mean,当成query,然后将输入和query concat一起得到257维,作为key和value,(其中query,key,value都添加位置编码)做multi_head_attention,得到最终输出extra_vec( B ∗ 1024 B*1024 B∗1024)
- 对于extra_vec 通过Linear+Silu+Linear得到( B ∗ 1408 B*1408 B∗1408),然后与通过time step得到的t相加,得到c( B ∗ 1408 B*1408 B∗1408,作为所有extra_vectors)
3.2.1 进入Dit Block
一共40个block,前面0到18个block的生成输入,中间19,20作为middle block,剩余的block会增加一个前面19个block输出的结果作为skip
3.2.1.1 前面0到18共19个block
- 前面一共19个block的过程,输入x( B ∗ 4096 ∗ 1408 B*4096*1408 B∗4096∗1408),c( B ∗ 1408 B*1408 B∗1408),text_states( B ∗ 333 ∗ 1024 B*333*1024 B∗333∗1024),位置编码freqs_cis_img (cos_cis_img, sin_cis_img,分别都是 B ∗ 4096 ∗ 88 B*4096*88 B∗4096∗88)
HunYuanDiTBlock((norm1): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)(attn1): FlashSelfMHAModified((Wqkv): Linear(in_features=1408, out_features=4224, bias=True)(q_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(k_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(inner_attn): FlashSelfAttention((drop): Dropout(p=0.0, inplace=False))(out_proj): Linear(in_features=1408, out_features=1408, bias=True)(proj_drop): Dropout(p=0.0, inplace=False))(norm2): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=1408, out_features=6144, bias=True)(act): GELU(approximate='tanh')(drop1): Dropout(p=0, inplace=False)(norm): Identity()(fc2): Linear(in_features=6144, out_features=1408, bias=True)(drop2): Dropout(p=0, inplace=False))(default_modulation): Sequential((0): FP32_SiLU()(1): Linear(in_features=1408, out_features=1408, bias=True))(attn2): FlashCrossMHAModified((q_proj): Linear(in_features=1408, out_features=1408, bias=True)(kv_proj): Linear(in_features=1024, out_features=2816, bias=True)(q_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(k_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(inner_attn): FlashCrossAttention((drop): Dropout(p=0.0, inplace=False))(out_proj): Linear(in_features=1408, out_features=1408, bias=True)(proj_drop): Dropout(p=0.0, inplace=False))(norm3): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)
)
- 对于c 通过default_modulation,得到shift_msa( B ∗ 4096 ∗ 1408 B*4096*1408 B∗4096∗1408),与经过norm1之后的x进行相加作为attn1的输入(就是Flash Self Attention)
- 将attn1的输出与原始的x进行残差相加,在经过norm3,与text_states一起作为attn2的输入(就是Flash Cross Attention)
- 在将经过残差相加之后的x与attn2的输出在进行残差相加,作为输入,走FFN,即先经过norm2,在经过mlp,之后与输入残差相加
3.2.1.2 第19和20 middle block
- 中间第19 和 20 两个block作为middle block,方式和上面一样
3.2.1.3 后面21到39共19个block
- 从第21个block开始,增加一个输入,例如第21个block,会将第18个block的输出作为输入
(skip_norm): FP32_Layernorm((2816,), eps=1e-06, elementwise_affine=True)(skip_linear): Linear(in_features=2816, out_features=1408, bias=True)
- 就是对于新的输入skip,将skip与x进行concat之后,经过skip norm,然后在经过skip linear,得到输出x,剩余步骤与前面一样
3.2.2 最后FInal layer处理
- 输入x和c,x是上面所有dit block的输出,c是上面的extra_vectors;对于c先进行SILU+Linear,得到( B ∗ 2816 B*2816 B∗2816),并彩分成shift 和 scale(分别为 B ∗ 1408 B*1408 B∗1408),最终通过x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1),然后通过Linear,得到最终输出x( B ∗ 4096 ∗ 32 B*4096*32 B∗4096∗32),然后通过转换得到输出imgs ( B ∗ 8 ∗ 128 ∗ 128 B*8*128*128 B∗8∗128∗128)