代码解读 | Hybrid Transformers for Music Source Separation[05]

一、背景

        0、Hybrid Transformer 论文解读

        1、代码复现|Demucs Music Source Separation_demucs架构原理-CSDN博客

        2、Hybrid Transformer 各个模块对应的代码具体在工程的哪个地方

        3、Hybrid Transformer 各个模块的底层到底是个啥(初步感受)?

        4、Hybrid Transformer 各个模块处理后,数据的维度大小是咋变换的?

        5、Hybrid Transformer 拆解STFT模块


        从模块上划分,Hybrid Transformer Demucs 共包含 (STFT模块、时域编码模块、频域编码模块、Cross-Domain Transformer Encoder模块、时域解码模块、频域解码模块、ISTFT模块)7个模块。

        本篇目标:拆解频域编码模块的底层

        时域编码和频域编码原理类似(后续不再拆解时域编码模块)。

二、频域编码模块


class HEncLayer(nn.Module):def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True,rewrite=True):"""Encoder layer. This used both by the time and the frequency branch.Args:chin: number of input channels.chout: number of output channels.norm_groups: number of groups for group norm.empty: used to make a layer with just the first conv. this is usedbefore merging the time and freq. branches.freq: this is acting on frequencies.dconv: insert DConv residual branches.norm: use GroupNorm.context: context size for the 1x1 conv.dconv_kw: list of kwargs for the DConv class.pad: pad the input. Padding is done so that the output size isalways the input size / stride.rewrite: add 1x1 conv at the end of the layer."""super().__init__()norm_fn = lambda d: nn.Identity()  # noqaif norm:norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqaif pad:pad = kernel_size // 4else:pad = 0klass = nn.Conv1dself.freq = freqself.kernel_size = kernel_sizeself.stride = strideself.empty = emptyself.norm = normself.pad = padif freq:kernel_size = [kernel_size, 1]stride = [stride, 1]pad = [pad, 0]klass = nn.Conv2dself.conv = klass(chin, chout, kernel_size, stride, pad)if self.empty:returnself.norm1 = norm_fn(chout)self.rewrite = Noneif rewrite:self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)self.norm2 = norm_fn(2 * chout)self.dconv = Noneif dconv:self.dconv = DConv(chout, **dconv_kw)def forward(self, x, inject=None):"""`inject` is used to inject the result from the time branch into the frequency branch,when both have the same stride."""if not self.freq and x.dim() == 4:B, C, Fr, T = x.shapex = x.view(B, -1, T)if not self.freq:le = x.shape[-1]if not le % self.stride == 0:x = F.pad(x, (0, self.stride - (le % self.stride)))y = self.conv(x)if self.empty:return yif inject is not None:assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)if inject.dim() == 3 and y.dim() == 4:inject = inject[:, :, None]y = y + injecty = F.gelu(self.norm1(y))if self.dconv:if self.freq:B, C, Fr, T = y.shapey = y.permute(0, 2, 1, 3).reshape(-1, C, T)y = self.dconv(y)if self.freq:y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)if self.rewrite:z = self.norm2(self.rewrite(y))z = F.glu(z, dim=1)else:z = yreturn z

        核心代码如上所示。

        使用print函数打印出各个关键节点的信息,可以得到频域编解码模块的全景图。

        编码层:Conv2d+Norm1+GELU,  Norm1:Identity()

        残差连接:(Conv1d+GroupNorm+GELU +Conv1d+GroupNorm+GLU+LayerScale())

        +(Conv2d+Norm2+GLU),Norm2:Identity() ,备注:Identity可以理解成直通

#上图均是自己读完代码绘制的。相信自己也可以。
#具体的,编码层1-4的Conv2d分别是:
Conv2d(4, 48, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(48, 96, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(96, 192, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(192, 384, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
#残差连接1
DConv((layers): ModuleList((0): Sequential((0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 6, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 96, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 6, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 96, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1))#残差连接2
DConv((layers): ModuleList((0): Sequential((0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 12, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 192, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 12, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 192, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(96, 192, kernel_size=(1, 1), stride=(1, 1))#残差连接3
DConv((layers): ModuleList((0): Sequential((0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 24, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 384, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 24, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 384, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(192, 384, kernel_size=(1, 1), stride=(1, 1))#残差连接4
DConv((layers): ModuleList((0): Sequential((0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 48, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 768, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 48, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 768, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1))

        关于,各个卷积模块输出数据的shape计算,可以读这篇文章。

        没有所谓天生的大佬,如果有那么我愿称他/她为圣人。我相信,能读到这儿的都会成为大佬~。Believe yourself,one day,you will be somebody.


         感谢阅读,最近开始写公众号(分享好用的AI工具),欢迎大家一起见证我的成长(桂圆学AI)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/348974.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

DDei在线设计器-配置主题风格

DDeiCore-主题 DDei-Core插件提供了默认主题和黑色主题。 如需了解详细的API教程以及参数说明,请参考DDei文档 默认主题 黑色主题 使用指南 引入 import { DDeiCoreThemeBlack } from "ddei-editor";使用并修改设置 extensions: [......//通过配置&am…

君子签帮助物流组织打造线上签约平台,助力简化成本,高效运转

各类物流组织日常业务可能涉及“企业入驻、快递、整车运输、货运、仓储、供应链等”多种类型,各个环节都存在大量的文件/单据签署,网点、客户、司机、收货人遍布全国各地,复杂的签署需求,以及庞大的签字、用印需求,让各…

记录pytest中场景执行的token异常处理问题

前言中写了一个conftest钩子函数用于处理重复调用token的方法,http://t.csdnimg.cn/N4rCK,每个用例单独执行都很正常,但是批量执行时一直报错,token缓存处理也不生效。 所有的用例都报获取不到token,方法改了又改&…

虚拟化 之一 详解 jailhouse 架构及原理、软硬件要求、源码文件、基本组件

Jailhouse 是一个基于 Linux 实现的针对创建工业级应用程序的小型 Hypervisor,是由西门子公司的 Jan Kiszka 于 2013 年开发的,并得到了官方 Linux 内核的支持,在开源社区中获得了知名度和吸引力。 Jailhouse Jailhouse 是一种轻量级的虚拟化…

【复旦邱锡鹏教授《神经网络与深度学习公开课》笔记】感知器

感知器是一种非常早期的线性分类模型,作为一种简单的神经网络模型被提出。感知器是一种模拟生物神经元行为的机器,有与生物神经元相对应的部件,如权重(突触)、偏置(阈值)及激活函数(…

颠覆与创新:探寻Facebook未来的发展路径

Facebook,这个曾经引领社交网络革命的巨头,在如今竞争激烈的科技市场中,正面临着前所未有的挑战和机遇。如何在不断变化的数字世界中保持竞争力,成为业界领先者,这是摆在Facebook面前的重要课题。本文将探寻Facebook未…

STM32开发过程中碰到的问题总结 - 1

文章目录 前言1. 怎么生成keil下可以使用的文件和gcc下编译使用的makefile2. STM32的时钟树3.怎么查看keil5下的编译工具链用的是哪个4. Arm编译工具链和GCC编译工具链有什么区别吗?5. 怎么查看Linux虚拟机是x86的还是aarch646. 怎么下载gcc-arm的编译工具链7.怎么修…

跟着AI学AI_07张量、数组、矩阵

说明这三个概念不是一个范畴的东西,但是很容易混淆,因此放到一起进行说明。 张量(Tensor) 张量是一个多维数组的通用概念,用于表示具有任意维度的数值数据。在数学和计算机科学中,张量是广泛用于表示数据的…

Anime Girls Pack

动漫女孩包 35个动画(就地)支持人形。 8情绪。 角色列表:原艾艾琪惠美子惠理文子星薰和子佳子奈子理子凛老师小樱老师津雨僵尸女孩01 下载:​​Unity资源商店链接资源下载链接 效果图:

字符串排序-第13届蓝桥杯省赛Python真题精选

[导读]:超平老师的Scratch蓝桥杯真题解读系列在推出之后,受到了广大老师和家长的好评,非常感谢各位的认可和厚爱。作为回馈,超平老师计划推出《Python蓝桥杯真题解析100讲》,这是解读系列的第82讲。 字符串排序&#…

阿里云域名解析

阿里云域名控制台:https://dc.console.aliyun.com/next/index#/domain-list/all

table组件,前端如何使用table组件,打印数组数据,后端传输的数据应该如何打印

一、如何使用table,将数组数据打印出来 后端传来的数据,很大概率是一个List数组,我们必须用一个table组件,来打印这些数据。 table标签的介绍 在HTML中,table是常用组件之一,主要用来打印数组信息。 它的…

互联网应用主流框架整合之SpringMVC基础组件开发

多种传参方式 在前一篇文章互联网应用主流框架整合之SpringMVC初始化及各组件工作原理中讨论了最简单的参数传递,而实际情况要复杂的多,比如REST风格,它往往会将参数写入请求路径中,而不是以HTTP请求参数传递;比如查询…

[AI资讯·0612] AI测试高考物理题,最高准确率100%,OpenAI与苹果合作,将ChatGPT融入系统中,大模型在物理领域应用潜力显现

AI资讯 国产AI大战高考物理,第1题全对,第2题开始放飞终于放大招了,2024WWDC,苹果开启AI反击战苹果一夜重塑iPhone!GPT-4o加持Siri,AI深入所有APPOpenAI确认苹果集成ChatGPT 还任命了两位新高管GPT-4搞不定…

大数据可视化电子沙盘:前端技术的全新演绎

随着大数据时代的到来,数据可视化成为了一个重要的技术趋势。数据可视化不仅可以让复杂的数据变得更加直观易懂,还能帮助我们更好地分析和理解数据。在本文中,我们将深入探讨一种基于HTML/CSS/Echarts等技术的大数据可视化电子沙盘&#xff0…

多层tablayout+ViewPager,NestedScrollView+ViewPager+RecyclerView,嵌套吸顶滑动冲突

先看实现的UI效果 其实就是仿BOSS的页面效果,第二层tab下的viewpager滑到最右边再右滑,就操作第一层viewpager滑动。页面上滑时把第一层tab和vp里的banner都推出界面,让第二层tab吸顶。 滑上去第二个tab块卡在顶部,如图 我混乱…

Unity 从0开始编写一个技能编辑器_02_Buff系统的生命周期

工作也有一年了,对技能编辑器也有了一些自己的看法,从刚接触时的惊讶,到大量工作时觉得有一些设计的冗余,在到特殊需求的修改,运行效率低时的优化,技能编辑器在我眼中已经不再是神圣不可攀的存在的&#xf…

redis 06 集群

1.节点,这里是把节点加到集群中的操作,跟主从结构不同 这里是在服务端使用命令: 例子: 2.启动节点 节点服务器 首先,先是服务器节点自身有一个属性来判断是不是可以使用节点功能 一般加入集群中的节点还是用r…

VMware安装ubuntu22.4虚拟机超详细图文教程

一 、下载镜像 下载地址:Index of /ubuntu-releases/22.04.4/ | 清华大学开源软件镜像站 | Tsinghua Open Source Mirror 二、创建虚拟机 打开VMware点击左上角文件,创建新的虚拟机,打开后如下图: 下一步,镜像文件就是…

使用代理IP常见问题有哪些?

代理IP在互联网数据收集和业务开展中发挥着重要作用,它充当用户客户端和网站服务器之间的“屏障”,可以保护用户的真实IP地址,并允许用户通过不同的IP地址进行操作。然而,在使用代理IP的过程中,用户经常会遇到一些问题…