每日Attention学习22——Inverted Residual RWKV

模块出处

[arXiv 25] [link] [code] RWKV-UNet: Improving UNet with Long-Range Cooperation for Effective Medical Image Segmentation


模块名称

Inverted Residual RWKV (IR-RWKV)


模块作用

用于vision的RWKV结构


模块结构

在这里插入图片描述


模块代码

注:cpp扩展请参考作者原仓库

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from timm.layers.activations import *
from functools import partial
from timm.layers import DropPath, create_act_layer, LayerType
from typing import Callable, Dict, Optional, Type
from torch.utils.cpp_extension import load

# Maximum sequence length T supported by the compiled WKV CUDA kernel;
# baked into the kernel at compile time via -DTmax below.
T_MAX = 1024
# Default for in-place activation layers used throughout this module.
inplace = True
# JIT-compile the WKV CUDA kernel (see the authors' original repository for
# the cuda/wkv_op.cpp and cuda/wkv_cuda.cu sources).
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
                verbose=True,
                extra_cuda_cflags=['-res-usage', '--maxrregcount 60',
                                   '--use_fast_math', '-O3', '-Xptxas -O3',
                                   f'-DTmax={T_MAX}'])


def get_norm(norm_layer='in_1d'):
    """Return a normalization-layer constructor for the given short name.

    Raises KeyError for unknown names.
    """
    eps = 1e-6
    norm_dict = {
        'none': nn.Identity,
        'in_1d': partial(nn.InstanceNorm1d, eps=eps),
        'in_2d': partial(nn.InstanceNorm2d, eps=eps),
        'in_3d': partial(nn.InstanceNorm3d, eps=eps),
        'bn_1d': partial(nn.BatchNorm1d, eps=eps),
        'bn_2d': partial(nn.BatchNorm2d, eps=eps),
        # 'bn_2d': partial(nn.SyncBatchNorm, eps=eps),
        'bn_3d': partial(nn.BatchNorm3d, eps=eps),
        'gn': partial(nn.GroupNorm, eps=eps),
        'ln_1d': partial(nn.LayerNorm, eps=eps),
        # 'ln_2d': partial(LayerNorm2d, eps=eps),
    }
    return norm_dict[norm_layer]


def get_act(act_layer='relu'):
    """Return an activation-layer constructor for the given short name.

    Non-nn entries (Sigmoid, Swish, ...) come from timm.layers.activations.
    Raises KeyError for unknown names.
    """
    act_dict = {
        'none': nn.Identity,
        'sigmoid': Sigmoid,
        'swish': Swish,
        'mish': Mish,
        'hsigmoid': HardSigmoid,
        'hswish': HardSwish,
        'hmish': HardMish,
        'tanh': Tanh,
        'relu': nn.ReLU,
        'relu6': nn.ReLU6,
        'prelu': PReLU,
        'gelu': GELU,
        'silu': nn.SiLU,
    }
    return act_dict[act_layer]


class ConvNormAct(nn.Module):
    """Conv2d -> norm -> activation, with an optional DropPath residual skip.

    The skip is only active when requested AND dim_in == dim_out.
    """

    def __init__(self, dim_in, dim_out, kernel_size, stride=1, dilation=1,
                 groups=1, bias=False, skip=False, norm_layer='bn_2d',
                 act_layer='relu', inplace=True, drop_path_rate=0.):
        super(ConvNormAct, self).__init__()
        self.has_skip = skip and dim_in == dim_out
        # "Same"-style padding for the given kernel/stride combination.
        padding = math.ceil((kernel_size - stride) / 2)
        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride, padding,
                              dilation, groups, bias)
        self.norm = get_norm(norm_layer)(dim_out)
        self.act = get_act(act_layer)(inplace=inplace)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut
        return x


class SE(nn.Module):
    """Squeeze-and-Excitation channel attention (timm-style)."""

    def __init__(
            self,
            in_chs: int,
            rd_ratio: float = 0.25,
            rd_channels: Optional[int] = None,
            act_layer: LayerType = nn.ReLU,
            gate_layer: LayerType = nn.Sigmoid,
            force_act_layer: Optional[LayerType] = None,
            rd_round_fn: Optional[Callable] = None,
    ):
        super(SE, self).__init__()
        if rd_channels is None:
            rd_round_fn = rd_round_fn or round
            rd_channels = rd_round_fn(in_chs * rd_ratio)
        act_layer = force_act_layer or act_layer
        self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True)
        self.act1 = create_act_layer(act_layer, inplace=True)
        self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        # Global average pool -> bottleneck MLP -> per-channel gate.
        x_se = x.mean((2, 3), keepdim=True)
        x_se = self.conv_reduce(x_se)
        x_se = self.act1(x_se)
        x_se = self.conv_expand(x_se)
        return x * self.gate(x_se)


def q_shift(input, shift_pixel=1, gamma=1/4, patch_resolution=None):
    """Quad-directional token shift over a (B, N, C) token sequence.

    Reshapes tokens to the 2D patch grid, shifts four channel groups
    (each a `gamma` fraction of C) right/left/down/up by `shift_pixel`,
    leaves remaining channels untouched, and flattens back to (B, N, C).
    """
    assert gamma <= 1/4
    B, N, C = input.shape
    input = input.transpose(1, 2).reshape(B, C, patch_resolution[0], patch_resolution[1])
    B, C, H, W = input.shape
    output = torch.zeros_like(input)
    # Group 1: shift right; group 2: shift left; group 3: shift down; group 4: shift up.
    output[:, 0:int(C*gamma), :, shift_pixel:W] = input[:, 0:int(C*gamma), :, 0:W-shift_pixel]
    output[:, int(C*gamma):int(C*gamma*2), :, 0:W-shift_pixel] = input[:, int(C*gamma):int(C*gamma*2), :, shift_pixel:W]
    output[:, int(C*gamma*2):int(C*gamma*3), shift_pixel:H, :] = input[:, int(C*gamma*2):int(C*gamma*3), 0:H-shift_pixel, :]
    output[:, int(C*gamma*3):int(C*gamma*4), 0:H-shift_pixel, :] = input[:, int(C*gamma*3):int(C*gamma*4), shift_pixel:H, :]
    # Channels beyond the four shifted groups pass through unchanged.
    output[:, int(C*gamma*4):, ...] = input[:, int(C*gamma*4):, ...]
    return output.flatten(2).transpose(1, 2)


def RUN_CUDA(B, T, C, w, u, k, v):
    """Invoke the WKV autograd function with all tensors moved to CUDA."""
    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())


class WKV(torch.autograd.Function):
    """Autograd wrapper around the compiled WKV CUDA kernel.

    Computation runs in fp32; half/bfloat16 inputs are cast in and the
    outputs/gradients are cast back to the input dtype.
    """

    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        half_mode = (w.dtype == torch.half)
        bf_mode = (w.dtype == torch.bfloat16)
        ctx.save_for_backward(w, u, k, v)
        w = w.float().contiguous()
        u = u.float().contiguous()
        k = k.float().contiguous()
        v = v.float().contiguous()
        y = torch.empty((B, T, C), device='cuda',
                        memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        if half_mode:
            y = y.half()
        elif bf_mode:
            y = y.bfloat16()
        return y

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        # Per-sample gradients for w/u; summed over the batch below.
        gw = torch.zeros((B, C), device='cuda').contiguous()
        gu = torch.zeros((B, C), device='cuda').contiguous()
        gk = torch.zeros((B, T, C), device='cuda').contiguous()
        gv = torch.zeros((B, T, C), device='cuda').contiguous()
        half_mode = (w.dtype == torch.half)
        bf_mode = (w.dtype == torch.bfloat16)
        wkv_cuda.backward(B, T, C,
                          w.float().contiguous(),
                          u.float().contiguous(),
                          k.float().contiguous(),
                          v.float().contiguous(),
                          gy.float().contiguous(),
                          gw, gu, gk, gv)
        if half_mode:
            gw = torch.sum(gw.half(), dim=0)
            gu = torch.sum(gu.half(), dim=0)
            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
        elif bf_mode:
            gw = torch.sum(gw.bfloat16(), dim=0)
            gu = torch.sum(gu.bfloat16(), dim=0)
            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
        else:
            gw = torch.sum(gw, dim=0)
            gu = torch.sum(gu, dim=0)
            return (None, None, None, gw, gu, gk, gv)


class VRWKV_SpatialMix(nn.Module):
    """Vision-RWKV spatial-mix block: token shift + R/K/V projections + WKV."""

    def __init__(self, n_embd, channel_gamma=1/4, shift_pixel=1):
        super().__init__()
        self.n_embd = n_embd
        attn_sz = n_embd
        self._init_weights()
        self.shift_pixel = shift_pixel
        if shift_pixel > 0:
            self.channel_gamma = channel_gamma
        else:
            # No shift: drop the mixing parameters created in _init_weights().
            self.spatial_mix_k = None
            self.spatial_mix_v = None
            self.spatial_mix_r = None
        self.key = nn.Linear(n_embd, attn_sz, bias=False)
        self.value = nn.Linear(n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(n_embd, attn_sz, bias=False)
        self.key_norm = nn.LayerNorm(n_embd)
        self.output = nn.Linear(attn_sz, n_embd, bias=False)
        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    def _init_weights(self):
        self.spatial_decay = nn.Parameter(torch.zeros(self.n_embd))
        self.spatial_first = nn.Parameter(torch.zeros(self.n_embd))
        self.spatial_mix_k = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)
        self.spatial_mix_v = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)
        self.spatial_mix_r = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)

    def jit_func(self, x, patch_resolution):
        # Mix x with the shifted tokens to produce xk, xv, xr.
        B, T, C = x.size()
        # Use xk, xv, xr to produce k, v, r.
        if self.shift_pixel > 0:
            xx = q_shift(x, self.shift_pixel, self.channel_gamma, patch_resolution)
            xk = x * self.spatial_mix_k + xx * (1 - self.spatial_mix_k)
            xv = x * self.spatial_mix_v + xx * (1 - self.spatial_mix_v)
            xr = x * self.spatial_mix_r + xx * (1 - self.spatial_mix_r)
        else:
            xk = x
            xv = x
            xr = x
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        sr = torch.sigmoid(r)
        return sr, k, v

    def forward(self, x, patch_resolution=None):
        B, T, C = x.size()
        sr, k, v = self.jit_func(x, patch_resolution)
        # Decay/first-token parameters are scaled by sequence length T.
        x = RUN_CUDA(B, T, C, self.spatial_decay / T, self.spatial_first / T, k, v)
        x = self.key_norm(x)
        x = sr * x
        x = self.output(x)
        return x


class iR_RWKV(nn.Module):
    """Inverted Residual RWKV block (RWKV-UNet).

    1x1 expansion conv -> optional VRWKV spatial mix on flattened tokens ->
    depth-wise local conv (+SE) -> 1x1 projection, with a residual skip when
    dim_in == dim_out and stride == 1.
    """

    def __init__(self, dim_in, dim_out, norm_in=True, has_skip=True,
                 exp_ratio=1.0, norm_layer='bn_2d', act_layer='relu', dw_ks=3,
                 stride=1, dilation=1, se_ratio=0.0, attn_s=True,
                 drop_path=0., drop=0., img_size=224, channel_gamma=1/4,
                 shift_pixel=1):
        super().__init__()
        self.norm = get_norm(norm_layer)(dim_in) if norm_in else nn.Identity()
        dim_mid = int(dim_in * exp_ratio)
        self.ln1 = nn.LayerNorm(dim_mid)
        self.conv = ConvNormAct(dim_in, dim_mid, kernel_size=1)
        self.has_skip = (dim_in == dim_out and stride == 1) and has_skip
        if attn_s == True:
            self.att = VRWKV_SpatialMix(dim_mid, channel_gamma, shift_pixel)
        self.se = SE(dim_mid, rd_ratio=se_ratio, act_layer=get_act(act_layer)) if se_ratio > 0.0 else nn.Identity()
        self.proj_drop = nn.Dropout(drop)
        self.proj = ConvNormAct(dim_mid, dim_out, kernel_size=1,
                                norm_layer='none', act_layer='none', inplace=inplace)
        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
        self.attn_s = attn_s
        self.conv_local = ConvNormAct(dim_mid, dim_mid, kernel_size=dw_ks,
                                      stride=stride, dilation=dilation,
                                      groups=dim_mid, norm_layer='bn_2d',
                                      act_layer='silu', inplace=inplace)

    def forward(self, x):
        shortcut = x
        x = self.norm(x)
        x = self.conv(x)
        if self.attn_s:
            B, hidden, H, W = x.size()
            patch_resolution = (H, W)
            x = x.view(B, hidden, -1)  # (B, hidden, H*W) = (B, C, N)
            x = x.permute(0, 2, 1)
            x = x + self.drop_path(self.ln1(self.att(x, patch_resolution)))
            # Reshape from (B, n_patch, hidden) back to (B, hidden, h, w);
            # assumes the patch grid is square (n_patch = h * w with h == w).
            B, n_patch, hidden = x.size()
            h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
            x = x.permute(0, 2, 1)
            x = x.contiguous().view(B, hidden, h, w)
        x = x + self.se(self.conv_local(x)) if self.has_skip else self.se(self.conv_local(x))
        x = self.proj_drop(x)
        x = self.proj(x)
        x = (shortcut + self.drop_path(x)) if self.has_skip else x
        return x


if __name__ == '__main__':
    x = torch.randn([1, 64, 11, 11]).cuda()
    ir_rwkv = iR_RWKV(dim_in=64, dim_out=64).cuda()
    out = ir_rwkv(x)
    print(out.shape)  # [1, 64, 11, 11]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/15293.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vscode预览插件

在左侧列表拓展里搜索 Live Preview 安装，然后在html页面点击右键找到show Preview 结果如下图 然后就可以进行代码开发并实时预览了

【04】Java+若依+vue.js技术栈实现钱包积分管理系统项目-若依框架二次开发准备工作-以及建立初步后端目录菜单列-优雅草卓伊凡商业项目实战

【04】Java若依vue.js技术栈实现钱包积分管理系统项目-若依框架二次开发准备工作-以及建立初步后端目录菜单列-优雅草卓伊凡商业项目实战 项目背景 本项目经费43000元&#xff0c;需求文档如下&#xff0c;工期25天&#xff0c;目前已经过了8天&#xff0c;时间不多了&#x…

【DeepSeek】DeepSeek概述 | 本地部署deepseek

目录 1 -> 概述 1.1 -> 技术特点 1.2 -> 模型发布 1.3 -> 应用领域 1.4 -> 优势与影响 2 -> 本地部署 2.1 -> 安装ollama 2.2 -> 部署deepseek-r1模型 1 -> 概述 DeepSeek是由中国的深度求索公司开发的一系列人工智能模型&#xff0c;以其…

Windows下AMD显卡在本地运行大语言模型(deepseek-r1)

Windows下AMD显卡在本地运行大语言模型 本人电脑配置第一步先在官网确认自己的 AMD 显卡是否支持 ROCm下载Ollama安装程序模型下载位置更改下载 ROCmLibs先确认自己显卡的gfx型号下载解压 替换替换rocblas.dll替换library文件夹下的所有 重启Ollama下载模型运行效果 本人电脑配…

使用Pytorch训练一个图像分类器

一、准备数据集 一般来说&#xff0c;当你不得不与图像、文本或者视频资料打交道时&#xff0c;会选择使用python的标准库将原始数据加载转化成numpy数组&#xff0c;甚至可以继续转换成torch.*Tensor。 对图片而言&#xff0c;可以使用Pillow库和OpenCV库对视频而言&#xf…

DeepSeek之Api的使用(将DeepSeek的api集成到程序中)

一、DeepSeek API 的收费模式 前言&#xff1a;使用DeepSeek的api是收费的 免费版&#xff1a; 可能提供有限的免费额度&#xff08;如每月一定次数的 API 调用&#xff09;&#xff0c;适合个人开发者或小规模项目。 付费版&#xff1a; 超出免费额度后&#xff0c;可能需要按…

git fetch和git pull 的区别

git pull 实际上就是 fetch merge 的缩写, git pull 唯一关注的是提交最终合并到哪里（也就是为 git fetch 所提供的 destination 参数） git fetch 从远程仓库下载本地仓库中缺失的提交记录,并更新远程分支指针 git pull抓取更新再合并到本地分支,相当于…

信息科技伦理与道德3-2:智能决策

2.2 智能推荐 推荐算法介绍 推荐系统&#xff1a;猜你喜欢 https://blog.csdn.net/search_129_hr/article/details/120468187 推荐系统–矩阵分解 https://blog.csdn.net/search_129_hr/article/details/121598087 案例一&#xff1a;YouTube推荐算法向儿童推荐不适宜视频 …

[LVGL] 在VC_MFC中移植LVGL

前言&#xff1a; 0. 在MFC中开发LVGL的优点是可以用多个Window界面做辅助扩展【类似GUIguider】 1.本文基于VC2022-MFC单文档框架移植lvgl8 2. gitee上下载lvgl8.3 源码&#xff0c;并将其文件夹改名为lvgl lvgl: LVGL 是一个开源图形库&#xff0c;提供您创建具有易于使用…

[RabbitMQ] RabbitMQ常见面试题

&#x1f338;个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 &#x1f3f5;️热门专栏: &#x1f9ca; Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 &#x1f355; Collection与…

《qt easy3d中添加孔洞填充》

《qt easy3d中添加孔洞填充》 效果展示一、创建流程二、核心代码效果展示 参考链接Easy3D开发——点云孔洞填充 一、创建流程 创建动作,并转到槽函数,并将动作放置菜单栏,可以参考前文 其中,槽函数on_actionHoleFill_triggered实现如下:

Git(分布式版本控制系统)系统学习笔记【并利用腾讯云的CODING和Windows上的Git工具来实操】

Git的概要介绍 1️⃣ Git 是什么&#xff1f; Git 是一个 分布式版本控制系统&#xff08;DVCS&#xff09;&#xff0c;用于跟踪代码的变更、协作开发和管理项目历史。 由 Linus Torvalds&#xff08;Linux 之父&#xff09;在 2005 年开发&#xff0c;主要用于 代码管理。…

基于SpringBoot的校园社交平台

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏&#xff1a;…

R语言LCMM多维度潜在类别模型流行病学研究:LCA、MM方法分析纵向数据

全文代码数据&#xff1a;https://tecdat.cn/?p39710 在数据分析领域&#xff0c;当我们面对一组数据时&#xff0c;通常会有已知的分组情况&#xff0c;比如不同的治疗组、性别组或种族组等&#xff08;点击文末“阅读原文”获取完整代码数据&#xff09;。 然而&#xff0c;…

mysql 主从配置

MySQL 主从复制是指在 MySQL 数据库系统中&#xff0c;主服务器&#xff08;Master&#xff09;将数据更新操作&#xff08;如 INSERT、UPDATE、DELETE&#xff09;复制到从服务器&#xff08;Slave&#xff09;。主从复制实现了数据的同步复制&#xff0c;使得从服务器可以保持…

DeepSeek为何能爆火

摘要&#xff1a;近年来&#xff0c;DeepSeek作为一款新兴的社交媒体应用&#xff0c;迅速在年轻人群体中走红&#xff0c;引发了广泛关注。本文旨在探讨DeepSeek为何能在短时间内爆火&#xff0c;从而为我国社交媒体的发展提供参考。首先&#xff0c;通过文献分析&#xff0c;…

黑马React保姆级(PPT+笔记)

一、react基础 1.进程 2、优势 封装成一个库&#xff0c;组件化开发更加方便 跨平台主要是react native等可以来写移动端如android&#xff0c;ios等 丰富生态&#xff1a;可以在很多浏览器用 3、市场 4、搭建脚手架 npx create-react-app react-basic npm start后仍然可能…

STM32 CUBE Can调试

STM32 CUBE Can调试 1、CAN配置2、时钟配置3、手动添加4、回调函数5、启动函数和发送函数6、使用方法(采用消息队列来做缓存)7、数据不多在发送函数中获取空邮箱发送&#xff0c;否则循环等待空邮箱 1、CAN配置 2、时钟配置 3、手动添加 需要注意的是STM32CUBE配置的代码需要再…

DeepSeek从入门到精通:全面掌握AI大模型的核心能力

文章目录 一、DeepSeek是什么&#xff1f;性能对齐OpenAI-o1正式版 二、Deepseek可以做什么&#xff1f;能力图谱文本生成自然语言理解与分析编程与代码相关常规绘图 三、如何使用DeepSeek&#xff1f;四、DeepSeek从入门到精通推理模型推理大模型非推理大模型 快思慢想&#x…

【vscode+latex】实现overleaf本地高效编译

overleaf本地高效编译 1. 配置本地latex环境2. vscode插件与配置3. 使用 之前觉得用overleaf在线写论文很方便&#xff0c;特别是有辅助生成latex格式公式的网页&#xff0c;不需要在word上一个一个手打调格式。 然而&#xff0c;最近在写一篇论文的时候&#xff0c;由于这篇论…