【代码】Swan-Transformer 代码详解(待完成)

1. 局部注意力  Window Attention (W-MSA Module)

class WindowAttention(nn.Module):r""" Window based multi-head self attention (W-MSA) module with relative position bias.It supports both of shifted and non-shifted window.Args:dim (int): Number of input channels.window_size (tuple[int]): The height and width of the window.num_heads (int): Number of attention heads.qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: Trueattn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0proj_drop (float, optional): Dropout ratio of output. Default: 0.0"""def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):super().__init__()self.dim = dimself.window_size = window_size  # [Mh, Mw]print(self.window_size)self.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5# define a parameter table of relative position biasself.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # [2*Mh-1 * 2*Mw-1, nH]# get pair-wise relative position index for each token inside the windowcoords_h = torch.arange(self.window_size[0])coords_w = torch.arange(self.window_size[1])coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # [2, Mh, Mw]coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw]# [2, Mh*Mw, 1] - [2, 1, Mh*Mw]relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2]relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0relative_coords[:, :, 1] += self.window_size[1] - 1relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]self.register_buffer("relative_position_index", relative_position_index)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)self.softmax = nn.Softmax(dim=-1)def forward(self, x, mask: Optional[torch.Tensor] = None):"""Args:x: input features with shape of (num_windows*B, Mh*Mw, C)mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None"""# [batch_size*num_windows, Mh*Mw, total_embed_dim]B_, N, C = x.shape# qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim]# reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head]# permute: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)# [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)# transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]# @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]q = q * self.scaleattn = (q @ k.transpose(-2, -1))# relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH]relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # [nH, Mh*Mw, Mh*Mw]attn = attn + relative_position_bias.unsqueeze(0)if mask is not None:# mask: [nW, Mh*Mw, Mh*Mw]nW = mask.shape[0]  # num_windows# attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]# mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw]attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)else:attn = self.softmax(attn)attn = self.attn_drop(attn)# @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]# transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]# reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim]x = (attn @ v).transpose(1, 2).reshape(B_, N, C)x = self.proj(x)x = self.proj_drop(x)return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/405673.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

用户画像实时标签数据处理流程图

背景 在用户画像中,有一类实时标签,我们既要它能够实时的对外提供数据统计,也要保存到大数据组件中用于后续的对数,圈选的逻辑,本文就看一下用户画像的实时标签的数据流转图 实时标签数据流转图 首先我们肯定是要使…

独立站PrestaShop安装

独立站PrestaShop安装 独立站PrestaShop安装系统需求下载PrestaShop浏览器下载命令行下载 解压PrestaShop创建数据库移动PrestaShop源码到web目录composer安装依赖包nginx配置访问域名进入安装页面选择语言许可协议系统兼容性店铺信息Content of your store系统配置数据库店铺安…

书生大模型学习笔记9 - LMDeploy 量化部署

LMDeploy 量化部署 InternLM 2.5 20b量化前部署W4A16 模型量化量化模型部署streamlit web InternLM 2.5 20b量化前部署 lmdeploy serve api_server \/root/learning/InternLM/XTuner/merged_20b \--model-format hf \--quant-policy 0 \--cache-max-entry-count 0.01\--server…

数据结构与算法——图

1、为什么要有图 1)前面我们学习了线性表和树 2)线性表局限于一个直接前驱和一个直接后继的关系 3)树也只能有一个直接前驱就是父节点 4)当我们需要表示多对多的关系时,这里我们就用到了图 图是一种数据结构&#xf…

支持2.4G频秒变符合GB42590的标准的飞行器【无人机GB42590发射端】

使用方法: 放在飞机 上,按键那一面需要朝上对着天空(因为GPS陶瓷天线在按键面),支持基本ID,向量和系统包,电池容量240mAH充电1小时,使用时间大概2小时。 1.长按3秒开关机 2.开机红灯慢闪,只发射基本ID数据…

JavaScript_7_练习:随机抽奖案例

效果图 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>练习&#xff1a;随机抽奖案例</tit…

【后续更新】python搜集上海二手房数据

源码如下: import asyncio import aiohttp from lxml import etree import logging import datetime import openpyxlwb = openpyxl.Workbook() sheet = wb.active sheet.append([房源, 房子信息, 所在区域, 单价, 关注人数和发布时间, 标签]) logging.basicConfig(level=log…

GD32双路CAN踩坑记录

GD32双路CAN踩坑记录 目录 GD32双路CAN踩坑记录1 问题描述2 原因分析3 解决办法4 CAN配置参考代码 1 问题描述 GD32的CAN1无法进入接收中断&#xff0c;收不到数据。 注&#xff1a;MCU使用的是GD32E50x&#xff0c;其他型号不确定是否一样&#xff0c;本文只以GD32E50x举例说…

CTF中的换表类Crypto题目

目录 [安洵杯 2019]JustBase[SWPUCTF 2021 新生赛]traditional字符替换解密 [BJDCTF 2020]base??字符替换 --》 base64解密 [安洵杯 2019]JustBase VGhlIGdlbxvZ#kgbYgdGhlIEVhcnRoJ#Mgc#VyZmFjZSBpcyBkb!pbmF)ZWQgYnkgdGhlIHBhcnRpY#VsYXIgcHJvcGVydGllcyBvZiB#YXRlci$gUHJ…

计量自动化终端上行通信规约

物理层 TCP 和 UDP 的传输接口 该类接口的登录链接和心跳检测采用链路测试服务&#xff0c;链路测试周期可设定。 参见 TCP/IP 协议规范。 串行通信传输接口 字节传输按异步方式进行&#xff0c;它包含 8 个数据位、1 个起始位“0”、1 个偶校验位 P 和 1 个停止位“1”。 …

测试流程自动化实践!

测试流程自动化的最佳实践涉及多个方面&#xff0c;旨在提高测试效率、确保测试质量&#xff0c;并降低测试成本。以下是一些关键的实践方法&#xff1a; 1. 明确测试目标 确定测试范围&#xff1a;在开始自动化测试之前&#xff0c;需要明确哪些功能、模块或场景需要被测试。…

共享内存及网络通信

共享内存 ------ 最高效的进程间通信 一个内核预留的空间&#xff0c;两进程绑定同一块共享空间 避免了用户空间 到 内核空间的数据拷贝 IPC 操作流程 key值 > 申请 >读写 >关闭 >卸载 1,产生key值 函数ftok key_t ftok(const char *pathname, int proj_id);…

Python 设置Excel工作表页边距、纸张大小/方向、打印区域、缩放比例

在使用Excel进行数据分析或报告制作时&#xff0c;页面设置是确保最终输出效果专业、美观的关键步骤。合理的页面设置不仅能够优化打印效果&#xff0c;还能提升数据的可读性。本文将详细介绍如何使用Python操作Excel中的各项页面设置功能。 目录 Python 设置Excel工作表页边…

调用大模型API-文心一言

一、准备工作 进入百度智能云千帆大模型平台&#xff0c;点击应用接入-创建应用&#xff1b;按提默认完成创建 二、开始使用 单轮调用 进入API列表 - ModelBuilder以第一个ERNIE-4.0-8K为例&#xff0c;选择“HTTP请求调用”&#xff0c;把第一步创建应用的 应用API Key、应…

汽车服务管理系统 _od8kr

TOC springboot580汽车服务管理系统 _od8kr--论文 系统概述 该系统由个人管理员和员工管理&#xff0c;用户三部分组成。其中&#xff1a;用户进入系统首页可以实现首页&#xff0c;热销汽车&#xff0c;汽车配件&#xff0c;汽车资讯&#xff0c;后台管理&#xff0c;在线客…

微服务的基本理解和使用

目录​​​​​​​ 一、微服务基础知识 1、系统架构的演变 &#xff08;1&#xff09;单体应用架构 &#xff08;2&#xff09;垂直应用架构 &#xff08;3&#xff09;分布式SOA架构 &#xff08;4&#xff09;微服务架构 &#xff08;5&#xff09;SOA与微服务的关系…

什么是UDP?

UDP是工作在OSI&#xff08;开放系统互连&#xff0c;Open Systems Interconnection&#xff09;模型中传输层的协议。它使用IP作为底层协议&#xff0c;是为应用程序提供一种以最少的协议机制向其他程序发送消息的协议。其主要特点是无连接&#xff0c;不保证可靠传输和面向报…

[机器学习]--线性回归算法

线性回归算法原理 线性关系在生活中有很多案例: 摄氏度和华氏度的转化: F C ⋅ 9 5 32 F C \cdot\frac{9}{5}32 FC⋅59​32学科最终成绩的计算: 最终成绩 0.3 \times 平时成绩 0.7 \times 期末成绩 线性回归(Linear regression)就是利用回归函数对一个或多个自变量…

Qt中英文支持

目的 就是想让QT编的软件支持中英文。 情况 1、首先配置项目的pro文件&#xff1a; 这样就会生成相应的翻译配置文件&#xff0c;当前是&#xff1a; translate1_cn.ts&#xff1a;中文的配置文件&#xff0c;因为一般默认就是中文&#xff0c;所以一般中文的翻译文件是不需…

小程序商城被盗刷,使用SCDN安全加速有用吗?

在电子商务蓬勃发展的今天&#xff0c;小程序商城因其便捷性和灵活性成为商家和消费者的新宠。然而&#xff0c;随着其普及&#xff0c;小程序商城的安全问题也日益凸显&#xff0c;尤其是盗刷现象频发&#xff0c;给商家和用户带来了巨大损失。面对这一挑战&#xff0c;是否可…