Open-Sora代码详细解读(1):解读DiT结构

Diffusion Models专栏文章汇总:入门与实战

前言:目前开源的DiT视频生成模型不是很多,Open-Sora是开发者生态最好的一个,涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-Sora的代码出发,深入解读背后的原理。

目录

DiT相比于Unet的关键改进点

Token化方法

因果3D卷积

Adaptive Layer Norm (adaLN) block 

完整DiT Block 设计


DiT相比于Unet的关键改进点

虽然Transformer架构已经在诸多自然语言处理和计算机视觉任务中展现出卓越的scalable能力,但目前主导扩散模型架构的仍是UNet。

采用DiT架构替换UNet主要需要探索以下几个关键问题:

  1. Token化处理。Transformer的输入为一维序列,形式为𝑅𝑇×𝑑RT×d(忽略batch维度),而LDM的latent表征𝑧∈𝑅𝐻𝑓×𝑊𝑓×𝐶z∈RfH​×fW​×C为spatial张量。因此,需要设计合适的Token化方法将二维latent映射为一维序列。
  2. 条件信息嵌入。sable diffusion火出圈的一个关键在于它能够根据用户的文本指令生成高质量的图像。这里面的核心在于需要将文本特征嵌入到扩散模型中协同生成。并且扩散模型的每一个生成还需要融入time-embedding来引入时间步的信息。因此,若要用Transformer架构取代Unet需要系统研究Transformer架构的条件嵌入

Token化方法

假定原始图片𝑥∈𝑅256×256×3,经过auto-encoder后得到latent表征𝑧∈𝑅32×32×4。首先DiT 用ViT中patch化的方式将隐表征𝑧转化为token序列,随后给序列添加位置编码。图中展示了patch化的过程。patch_size p是一个超参数。

刚才是DiT原始论文的描述,在视频里用了一个PatchEmbed3D 执行Token化:

class PatchEmbed3D(nn.Module):"""Video to Patch Embedding.Args:patch_size (int): Patch token size. Default: (2,4,4).in_chans (int): Number of input video channels. Default: 3.embed_dim (int): Number of linear projection output channels. Default: 96.norm_layer (nn.Module, optional): Normalization layer. Default: None"""def __init__(self,patch_size=(2, 4, 4),in_chans=3,embed_dim=96,norm_layer=None,flatten=True,):super().__init__()self.patch_size = patch_sizeself.flatten = flattenself.in_chans = in_chansself.embed_dim = embed_dimself.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)if norm_layer is not None:self.norm = norm_layer(embed_dim)else:self.norm = Nonedef forward(self, x):"""Forward function."""# padding_, _, D, H, W = x.size()if W % self.patch_size[2] != 0:x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))if H % self.patch_size[1] != 0:x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))if D % self.patch_size[0] != 0:x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))x = self.proj(x)  # (B C T H W)if self.norm is not None:D, Wh, Ww = x.size(2), x.size(3), x.size(4)x = x.flatten(2).transpose(1, 2)x = self.norm(x)x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)if self.flatten:x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNCreturn x

先把视频的长宽和时间长都填充成偶数,然后用一个3D卷积,把时间、空间都进一步压缩,Channel从4膨胀到96,然后把时空都压缩到一起,即:

x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC

因果3D卷积

刚才Token化用的是普通的3D卷积,其他有些代码里用了因果3D卷积,因果3D卷积在视频任务里非常常用:

因果3D卷积(Causal 3D Convolution)是一种特殊的3D卷积,它在处理具有时间维度的数据(如视频)时保持因果性。这意味着在生成当前时间点的输出时,它只依赖于当前和之前的时间点,而不依赖于未来的时间点。卷积核在时间维度上滑动,它也只会接触到当前和过去的帧。这在序列建模和时间序列预测等任务中非常重要,因为它们需要保证模型输出的因果关系。

与传统的3D卷积相比,因果3D卷积在时间维度上增加了填充(padding),以确保输出的时间长度与输入相同。这种填充通常是在时间维度的开始处添加,而不是在两端添加,这样可以保证在预测当前帧时不会使用到后续帧的信息。通过在时间轴的正方向上(即未来的方向)添加适当的零填充来实现这一点。

下面是EasyAnimate的实现代码:

class CausalConv3d(nn.Conv3d):def __init__(self,in_channels: int,out_channels: int,kernel_size=3, # : int | tuple[int, int, int], stride=1, # : int | tuple[int, int, int] = 1,padding=1, # : int | tuple[int, int, int],  # TODO: change it to 0.dilation=1, # :  int | tuple[int, int, int] = 1,**kwargs,):kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."stride = stride if isinstance(stride, tuple) else (stride,) * 3assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."t_ks, h_ks, w_ks = kernel_size_, h_stride, w_stride = stridet_dilation, h_dilation, w_dilation = dilationt_pad = (t_ks - 1) * t_dilation# TODO: align with SDif padding is None:h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)elif isinstance(padding, int):h_pad = w_pad = paddingelse:assert NotImplementedErrorself.temporal_padding = t_padself.temporal_padding_origin = math.ceil(((t_ks - 1) * w_dilation + (1 - w_stride)) / 2)self.padding_flag = 0super().__init__(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,stride=stride,dilation=dilation,padding=(0, h_pad, w_pad),**kwargs,)def forward(self, x: torch.Tensor) -> torch.Tensor:# x: (B, C, T, H, W)if self.padding_flag == 0:x = F.pad(x,pad=(0, 0, 0, 0, self.temporal_padding, 0),mode="replicate",     # TODO: check if this is necessary)else:x = F.pad(x,pad=(0, 0, 0, 0, self.temporal_padding_origin, self.temporal_padding_origin),)return super().forward(x)def set_padding_one_frame(self):def _set_padding_one_frame(name, module):if hasattr(module, 'padding_flag'):print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))module.padding_flag = 1for sub_name, sub_mod in module.named_children():_set_padding_one_frame(sub_name, sub_mod)for name, module in self.named_children():_set_padding_one_frame(name, module)def set_padding_more_frame(self):def _set_padding_more_frame(name, module):if hasattr(module, 'padding_flag'):print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))module.padding_flag = 2for sub_name, sub_mod in module.named_children():_set_padding_more_frame(sub_name, sub_mod)for name, module in self.named_children():_set_padding_more_frame(name, module)

Adaptive Layer Norm (adaLN) block 

这是DiT里面最核心的设计之一,adaptive normalization layer(adaLN),将transformer block的layer norm替换为adaLN。简单来说就是,原本的将原本layer norm用于仿射变换的scale parameter 𝛾和shift parameter 𝛽 用condition embedding来替代。

原始的Layer Norm设计:

class LayerNorm:def __init__(self, feature_dim, epsilon=1e-6):self.epsilon = epsilonself.gamma = np.random.rand(feature_dim)  # scale parametersself.beta = np.random.rand(feature_dim)  # shift parametrsdef __call__(self, x: np.ndarray) -> np.ndarray:"""Args:x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)return:x_layer_norm (np.ndarray): shape: (batch_size, sequence_length, feature_dim)"""_mean = np.mean(x, axis=-1, keepdims=True)_std = np.var(x, axis=-1, keepdims=True)x_layer_norm = self.gamma * (x - _mean / (_std + self.epsilon)) + self.betareturn x_layer_norm

DiT中的adaLN设计:

class DiTBlock(nn.Module):"""A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning."""def __init__(self,hidden_size,num_heads,mlp_ratio=4.0,enable_flash_attn=False,enable_layernorm_kernel=False,):super().__init__()self.hidden_size = hidden_sizeself.num_heads = num_headsself.enable_flash_attn = enable_flash_attnmlp_hidden_dim = int(hidden_size * mlp_ratio)self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)self.attn = Attention(hidden_size,num_heads=num_heads,qkv_bias=True,enable_flash_attn=enable_flash_attn,)self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))def forward(self, x, c):shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1, x, shift_msa, scale_msa))x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2, x, shift_mlp, scale_mlp))return x

完整DiT Block 设计

好了,到这里已经是把主要的DiT构建出来了,接下来把DiT结构堆积28层,构成了现在的DiT结构:

@MODELS.register_module()
class DiT(nn.Module):"""Diffusion model with a Transformer backbone."""def __init__(self,input_size=(16, 32, 32),in_channels=4,patch_size=(1, 2, 2),hidden_size=1152,depth=28,num_heads=16,mlp_ratio=4.0,class_dropout_prob=0.1,learn_sigma=True,condition="text",no_temporal_pos_emb=False,caption_channels=512,model_max_length=77,dtype=torch.float32,enable_flash_attn=False,enable_layernorm_kernel=False,enable_sequence_parallelism=False,):super().__init__()self.learn_sigma = learn_sigmaself.in_channels = in_channelsself.out_channels = in_channels * 2 if learn_sigma else in_channelsself.hidden_size = hidden_sizeself.patch_size = patch_sizeself.input_size = input_sizenum_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])self.num_patches = num_patchesself.num_temporal = input_size[0] // patch_size[0]self.num_spatial = num_patches // self.num_temporalself.num_heads = num_headsself.dtype = dtypeself.use_text_encoder = not condition.startswith("label")if enable_flash_attn:assert dtype in [torch.float16,torch.bfloat16,], f"Flash attention only supports float16 and bfloat16, but got {self.dtype}"self.no_temporal_pos_emb = no_temporal_pos_embself.mlp_ratio = mlp_ratioself.depth = depthassert enable_sequence_parallelism is False, "Sequence parallelism is not supported in DiT"self.register_buffer("pos_embed_spatial", self.get_spatial_pos_embed())self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())self.x_embedder = PatchEmbed3D(patch_size, in_channels, embed_dim=hidden_size)if not self.use_text_encoder:num_classes = int(condition.split("_")[-1])self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)else:self.y_embedder = CaptionEmbedder(in_channels=caption_channels,hidden_size=hidden_size,uncond_prob=class_dropout_prob,act_layer=approx_gelu,token_num=1,  # pooled token)self.t_embedder = TimestepEmbedder(hidden_size)self.blocks = nn.ModuleList([DiTBlock(hidden_size,num_heads,mlp_ratio=mlp_ratio,enable_flash_attn=enable_flash_attn,enable_layernorm_kernel=enable_layernorm_kernel,)for _ in range(depth)])self.final_layer = FinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)self.initialize_weights()self.enable_flash_attn = enable_flash_attnself.enable_layernorm_kernel = enable_layernorm_kerneldef get_spatial_pos_embed(self):pos_embed = get_2d_sincos_pos_embed(self.hidden_size,self.input_size[1] // self.patch_size[1],)pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)return pos_embeddef get_temporal_pos_embed(self):pos_embed = get_1d_sincos_pos_embed(self.hidden_size,self.input_size[0] // self.patch_size[0],)pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)return pos_embeddef unpatchify(self, x):c = self.out_channelst, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]pt, ph, pw = self.patch_sizex = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))x = rearrange(x, "n t h w r p q c -> n c t r h p w q")imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))return imgsdef forward(self, x, t, y):"""Forward pass of DiT.x: (B, C, T, H, W) tensor of inputst: (B,) tensor of diffusion timestepsy: list of text"""# origin inputs should be float32, cast to specified dtypex = x.to(self.dtype)if self.use_text_encoder:y = y.to(self.dtype)# embeddingx = self.x_embedder(x)  # (B, N, D)x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)x = x + self.pos_embed_spatialif not self.no_temporal_pos_emb:x = rearrange(x, "b t s d -> b s t d")x = x + self.pos_embed_temporalx = rearrange(x, "b s t d -> b (t s) d")else:x = rearrange(x, "b t s d -> b (t s) d")t = self.t_embedder(t, dtype=x.dtype)  # (N, D)y = self.y_embedder(y, self.training)  # (N, D)if self.use_text_encoder:y = y.squeeze(1).squeeze(1)condition = t + y# blocksfor _, block in enumerate(self.blocks):c = conditionx = auto_grad_checkpoint(block, x, c)  # (B, N, D)# final processx = self.final_layer(x, condition)  # (B, N, num_patches * out_channels)x = self.unpatchify(x)  # (B, out_channels, T, H, W)# cast to float32 for better accuracyx = x.to(torch.float32)return xdef initialize_weights(self):# Initialize transformer layers:def _basic_init(module):if isinstance(module, nn.Linear):if module.weight.requires_grad_:torch.nn.init.xavier_uniform_(module.weight)if module.bias is not None:nn.init.constant_(module.bias, 0)self.apply(_basic_init)# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):w = self.x_embedder.proj.weight.datann.init.xavier_uniform_(w.view([w.shape[0], -1]))nn.init.constant_(self.x_embedder.proj.bias, 0)# Initialize timestep embedding MLP:nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)# Zero-out adaLN modulation layers in DiT blocks:for block in self.blocks:nn.init.constant_(block.adaLN_modulation[-1].weight, 0)nn.init.constant_(block.adaLN_modulation[-1].bias, 0)# Zero-out output layers:nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)nn.init.constant_(self.final_layer.linear.weight, 0)nn.init.constant_(self.final_layer.linear.bias, 0)# Zero-out text embedding layers:if self.use_text_encoder:nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/418835.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【MySQL】MySQL基础

目录 什么是数据库主流数据库基本使用MySQL的安装连接服务器服务器、数据库、表关系使用案例数据逻辑存储 MySQL的架构SQL分类什么是存储引擎 什么是数据库 mysql它是数据库服务的客户端mysqld它是数据库服务的服务器端mysql本质:基于C(mysql&#xff09…

linux系统中,计算两个文件的相对路径

realpath --relative-to/home/itheima/smartnic/smartinc/blocks/ruby/seanet_diamond/tb/parser/test_parser_top /home/itheima/smartnic/smartinc/corundum/fpga/lib/eth/lib/axis/rtl/axis_fifo.v 检验方式就是直接在当前路径下,把输出的路径复制一份&#xff0…

Nginx跨域运行案例:云台控制http请求,通过 http server 代理转发功能,实现跨域运行。(基于大华摄像头WEB无插件开发包)

文章目录 引言I 跨域运行案例开发资源测试/生产环境,Nginx代理转发,实现跨域运行本机开发运行II nginx的location指令Nginx配置中, 获取自定义请求header头Nginx 配置中,获取URL参数引言 背景:全景监控 需求:感知站点由于云台相关操作为 http 请求,http 请求受浏览器…

Redis-主从集群

主从架构 单节点Redis的并发能力是有上限的,要进一步提高Redis的并发能力,就需要搭建主从集群,实现读写分离。 主从数据同步原理 全量同步 主从第一次建立连接时,会执行全量同步,将master节点的所有数据都拷贝给sla…

34465A-61/2 数字万用表(六位半)

34465A-61/2 数字万用表(六位半) 文章目录 34465A-61/2 数字万用表(六位半)前言一、测DC/AC电压二、测DC/AC电流四、测电阻五、测电容六、测二极管七、保存截图流程前言 1、6位半数字万用表通常具有200,000个计数器,可以显示最大为199999的数值。相比普通数字万用表,6位半…

注册安全分析报告:熊猫频道

前言 由于网站注册入口容易被黑客攻击,存在如下安全问题: 暴力破解密码,造成用户信息泄露短信盗刷的安全问题,影响业务及导致用户投诉带来经济损失,尤其是后付费客户,风险巨大,造成亏损无底洞…

【笔记】Java | 三目运算符和Math函数的比较

实际效果 比较两数并赋值&#xff0c;如下两种方法的耗时不会有差异。 result Math.min(result, subLen);result result < subLen ? result : subLen; 源码解析 因为源码Math.min的源码本质就算三目运算符的比较&#xff0c;所以执行结果是一样的。 三目运算符简介 概…

怎么强制撤销excel工作表保护?

经常不是用的Excel文件设置了工作表保护&#xff0c;偶尔打开文件的时候想要编辑文件&#xff0c;但是发现忘记了密码&#xff0c;那么这种情况&#xff0c;我们怎么强制撤销excel工作表保护&#xff1f;今天分享两种解决方法。 方法一、 将excel文件转换为其他文件格式&…

新品上市丨科学级新款制冷相机sM4040A/sM4040B

sM4040B科学级显微制冷相机 特性 sM4040B搭载了 GSENSE4040BSI 3.2 英寸图像传感器&#xff0c;针对传感器固有的热噪声&#xff0c;专门设计了高效制冷模块&#xff0c;使得相机传感器的工作温度比环境温度低达 35-40 度。针对制冷相机常见的低温结雾现象设计了防结雾机制&a…

二百五十九、Java——采集Kafka数据,解析成一条条数据,写入另一Kafka中(一般JSON)

一、目的 由于部分数据类型频率为1s&#xff0c;从而数据规模特别大&#xff0c;因此完整的JSON放在Hive中解析起来&#xff0c;尤其是在单机环境下&#xff0c;效率特别慢&#xff0c;无法满足业务需求。 而Flume的拦截器并不能很好的转换数据&#xff0c;因为只能采用Java方…

鸿蒙自动化发布测试版本app

创建API客户端 API客户端是AppGallery Connect用于管理用户访问AppGallery Connect API的身份凭据&#xff0c;您可以给不同角色创建不同的API客户端&#xff0c;使不同角色可以访问对应权限的AppGallery Connect API。在访问某个API前&#xff0c;必须创建有权访问该API的API…

UE5.3_跟一个插件—Socket.IO Client

网上看到这个插件,挺好! 项目目前也没有忙到不可开交,索性跟着测一下吧: 商城可见,售价72.61人民币! 但是,git上有仓库哦,免费!! 跟着链接先准备起来: Documentation: GitHub - getnamo/SocketIOClient-Unreal: Socket.IO client plugin for the Unreal Engin…

数据仓库理论知识

1、数据仓库的概念 数据仓库&#xff08;英文&#xff1a;Date Warehouse&#xff0c;简称数仓、DW&#xff09;&#xff0c;是一个用于数据存储、分析、报告的数据系统。数据仓库的建设目的是面向分析的集成化数据环境&#xff0c;其数据来源于不同的外部系统&#…

【H2O2|全栈】Markdown | Md 笔记到底如何使用?【前端 · HTML前置知识】

Markdown的一些杂谈 目录 Markdown的一些杂谈 前言 准备工作 认识.Md文件 为什么使用Md&#xff1f; 怎么使用Md&#xff1f; ​编辑 怎么看别人给我的Md文件&#xff1f; Md文件命令 切换模式 粗体、倾斜、下划线、删除线和荧光标记 分级标题 水平线 引用 无序…

缓存类型以及读写策略

缓存&#xff08;Cache&#xff09;是一种高效的数据存储技术&#xff0c;旨在提高数据访问速度。 它将频繁访问或最近使用的数据临时存储在更快速但较小的存储介质&#xff08;如内存&#xff09;中&#xff0c;以减少从较慢的存储设备&#xff08;如硬盘或远程服务器&#x…

4G模块、WIFI模块、NBIOT模块通过AT指令连接华为云物联网服务器(MQTT协议)

MQTT协议概述 MQTT&#xff08;Message Queuing Telemetry Transport&#xff09;是一种轻量级的消息传输协议&#xff0c;它被设计用来提供一对多的消息分发和应用之间的通讯&#xff0c;尤其适用于远程位置的设备和高延迟或低带宽的网络。MQTT协议基于客户端-服务器架构&…

iOS——方法交换Method Swizzing

什么是方法交换 Method Swizzing是发生在运行时的&#xff0c;主要用于在运行时将两个Method进行交换&#xff0c;我们可以将Method Swizzling代码写到任何地方&#xff0c;但是只有在这段Method Swilzzling代码执行完毕之后互换才起作用。 利用Objective-C Runtimee的动态绑定…

网络编程学习:TCP/IP协议

TCP/IP协议简介 TCP/IP协议包含了一系列的协议&#xff0c;也叫TCP/IP协议族&#xff08;TCP/IP Protocol Suite&#xff0c;或TCP/IP Protocols&#xff09;&#xff0c;简称TCP/IP。 分层结构 为了能够实现不同类型的计算机和不同类型的操作系统之间进行通信&#xff0c;引…

Zookeeper基本原理

1.什么是Zookeeper? Zookeeper是一个开源的分布式协调服务器框架&#xff0c;由Apache软件基金会开发&#xff0c;专为分布式系统设计。它主要用于在分布式环境中管理和协调多个节点之间的配置信息、状态数据和元数据。 Zookeeper采用了观察者模式的设计理念&#xff0c;其核心…

在vscode中用virtual env的方法

vscode是非常常用的软件开发工具。我们也非常了解如何使用vscode开发python的基本方法。当然&#xff0c;vscode可以开发基本所有编程语言。真的是又大又全又好用。 那么为什么要在vscode里面使用virtual env呢&#xff1f;因为python开发会遇到包管理的问题。而virtual env可…