分割模型TransNetR的pytorch代码学习笔记

这个模型在U-net的基础上融合了Transformer模块和残差网络的原理。

论文地址:https://arxiv.org/pdf/2303.07428.pdf

具体的网络结构如下:

网络的原理还是比较简单的,

编码分支用的是预训练的resnet模块,解码分支则重新设计了。

解码器分支的模块结构示意图如下:

可以看出来,就是Transformer模块和残差连接相加,然后再经过一个residual模块处理。

1,用pytorch实现时,首先要把这个解码器模块实现出来:

class residual_transformer_block(nn.Module):def __init__(self, in_c, out_c, patch_size=4, num_heads=4, num_layers=2, dim=None):super().__init__()self.ps = patch_sizeself.c1 = Conv2D(in_c, out_c)encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads)self.te = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)self.c2 = Conv2D(out_c, out_c, kernel_size=1, padding=0, act=False)self.c3 = Conv2D(in_c, out_c, kernel_size=1, padding=0, act=False)self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)self.r1 = residual_block(out_c, out_c)def forward(self, inputs):x = self.c1(inputs)b, c, h, w = x.shapenum_patches = (h*w)//(self.ps**2)x = torch.reshape(x, (b, (self.ps**2)*c, num_patches))x = self.te(x)x = torch.reshape(x, (b, c, h, w))x = self.c2(x)s = self.c3(inputs)x = self.relu(x + s)x = self.r1(x)return x

其中我们需要注意的是这里构建Transformer块的方法,也就是下面两句:

encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads)
self.te = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

首先,第一句是用nn.TransformerEncoderLayer定义了一个Transformer层,并存储在encoder_layer变量中。

nn.TransformerEncoderLayer的参数包括:d_model(输入特征的维度大小),nhead(自注意力机制中注意力头数量),dim_feedforward(前馈网络的隐藏层维度大小),dropout(dropout比例),apply(用于在编码器层及其子层上应用函数,例如初始化或者权重共享等功能)。

第二句则是将多个Transformer层堆叠在一起,构建一个Transformer编码器块。

nn.TransformerEncoder的参数包括:encoder_layer(用于构建模块的每个Transformer层),num_layer(堆叠的层数),norm(执行的标准化方法),apply(同上)。

2,在上面的解码器模块中,还有一个residual block需要额外实现,如下:

class residual_block(nn.Module):def __init__(self, in_c, out_c):super().__init__()self.conv = nn.Sequential(nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),nn.BatchNorm2d(out_c),nn.LeakyReLU(negative_slope=0.1, inplace=True),nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),nn.BatchNorm2d(out_c))self.shortcut = nn.Sequential(nn.Conv2d(in_c, out_c, kernel_size=1, padding=0),nn.BatchNorm2d(out_c))self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)def forward(self, inputs):x = self.conv(inputs)s = self.shortcut(inputs)return self.relu(x + s)

这个代码就是简单的残差卷积模块,不赘述。

3,重要的模块实现完了,接下来就是按照UNet的形状拼装网络,代码如下:

class Model(nn.Module):def __init__(self):super().__init__()""" Encoder """backbone = resnet50()self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)self.layer1 = nn.Sequential(backbone.maxpool, backbone.layer1)self.layer2 = backbone.layer2self.layer3 = backbone.layer3self.layer4 = backbone.layer4self.e1 = Conv2D(64, 64, kernel_size=1, padding=0)self.e2 = Conv2D(256, 64, kernel_size=1, padding=0)self.e3 = Conv2D(512, 64, kernel_size=1, padding=0)self.e4 = Conv2D(1024, 64, kernel_size=1, padding=0)""" Decoder """self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)self.r1 = residual_transformer_block(64+64, 64, dim=64)self.r2 = residual_transformer_block(64+64, 64, dim=256)self.r3 = residual_block(64+64, 64)""" Classifier """self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)def forward(self, inputs):""" Encoder """x0 = inputsx1 = self.layer0(x0)    ## [-1, 64, h/2, w/2]x2 = self.layer1(x1)    ## [-1, 256, h/4, w/4]x3 = self.layer2(x2)    ## [-1, 512, h/8, w/8]x4 = self.layer3(x3)    ## [-1, 1024, h/16, w/16]e1 = self.e1(x1)e2 = self.e2(x2)e3 = self.e3(x3)e4 = self.e4(x4)""" Decoder """x = self.up(e4)x = torch.cat([x, e3], axis=1)x = self.r1(x)x = self.up(x)x = torch.cat([x, e2], axis=1)x = self.r2(x)x = self.up(x)x = torch.cat([x, e1], axis=1)x = self.r3(x)x = self.up(x)""" Classifier """outputs = self.outputs(x)return outputs

其中,x1,x2,x3,x4就是编码器模块,用的都是resnet50的预训练模块。

其中r1,r2,r3,r4则是解码器的模块,就是上面实现的模块。

而e1,e2,e3,e4则是在skip connection前给编码器的输出做1x1卷积,作用大体上就是减少计算量。

完整代码:https://github.com/DebeshJha/TransNetR/blob/main/model.py#L45

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/272983.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

抖音素材网站去哪下载?给你推荐六个抖音自媒体网站

各位抖音视频创作达人们,是否在苦苦寻觅那些能够点燃观众热情,让视频内容跃然屏上的素材宝库呢?此刻,你们的寻觅之旅将迎来终点!我将向你们隆重推荐10个精心挑选的视频素材库,它们定能让你们的抖音视频如同…

【微服务】SpringBoot整合Resilience4j使用详解

目录 一、前言 二、熔断器出现背景 2.1 几个核心概念 2.1.1 熔断 2.1.2 限流 2.1.3 降级 2.2 为什么会出现熔断器 2.3 断路器介绍 2.3.1 断路器原理 三、Resilience4j介绍 3.1 Resilience4j概述 3.1.1 Resilience4j是什么 3.1.2 Resilience4j功能特性 3.2 Resilie…

微服务自动化管理初步认识与使用

目录 一、ETCD 1.1、ETCD简介 对于实施工程师: 1.2、特点 1.3. 使用场景 1.4、 关键字 1.5 工作原理 二、ETCD的安装 2.1、下载路径 2.2、介绍 2.3、具体操作 安装服务端 安装etcd客户端 测试 三、ETCD使用 3.1、前奏具体操作 3.2、 常用操作 一、ET…

利用GPT开发应用001:GPT基础知识及LLM发展

文章目录 一、惊艳的GPT二、大语言模型LLMs三、自然语言处理NLP四、大语言模型LLM发展 一、惊艳的GPT 想象一下,您可以与计算机的交流速度与与朋友交流一样快。那会是什么样子?您可以创建哪些应用程序?这正是OpenAI正在助力构建的世界&#x…

ELFK 分布式日志收集系统

ELFK的组成: Elasticsearch: 它是一个分布式的搜索和分析引擎,它可以用来存储和索引大量的日志数据,并提供强大的搜索和分析功能。 (java语言开发,)logstash: 是一个用于日志收集,处理和传输的…

Linux系统下使用C++推流多路视频流

先看拉取的视频流效果: 代码如下: 一开始打算使用python写多路视频推流,但在ubuntu系统上搞了好久就是搞不定openh264导致的错误,然后改用c了,代码如下,我这里推了两路视频流,一路是网络摄像头&…

2024护网面试题精选(二)完

0x02. 内网渗透篇 00- 内网渗透的流程 拿到跳板后,先探测一波内网存活主机,用net user /domian命令查看跳板机是否在域 内,探测存活主机、提权、提取hash、进行横向移动,定位dc位置,查看是否有能直接提权域 管的漏洞…

pytorch什么是梯度

目录 1.导数、偏微分、梯度1.1 导数1.2 偏微分1.3 梯度 2. 通过梯度求极小值3. learning rate3. 局部最小值4. Saddle point鞍点 1.导数、偏微分、梯度 1.1 导数 对于yx 2 2 2 的导数,描述了y随x值变化的一个变化趋势,导数是个标量反应的是变化的程度&…

【HTML】HTML基础7.1(无序列表)

目录 标签 属性 效果 注意 标签 <ul> <li>列表里要装的东西</li> <li>列表里要装的东西</li> <li>列表里要装的东西</li> </ul> 属性 type&#xff1a; circle空心圆disc实心圆square方框 效果 circle空心圆效果…

vi/vim编辑器

vi/vim编辑器 vi的特点与运用场景vi的使用简易执行一个案例按键说明第一部分&#xff1a;命令模式的按键说明(光标移动、复制粘贴、查找替换)第二部分&#xff1a;命令模式切换到输入模式的可以按键第三部分&#xff1a;命令模式切换到底线命令模式的可用按键 命令行模式的保存…

【fastllm】学习框架,本地运行,速度还可以,可以成功运行chatglm2模型

1&#xff0c;关于 fastllm 项目 https://www.bilibili.com/video/BV1fx421k7Mz/?vd_source4b290247452adda4e56d84b659b0c8a2 【fastllm】学习框架&#xff0c;本地运行&#xff0c;速度还可以&#xff0c;可以成功运行chatglm2模型 https://github.com/ztxz16/fastllm &am…

ai学习前瞻-python环境搭建

python环境搭建 Python环境搭建1. python的安装环境2. MiniConda安装3. pycharm安装4. Jupyter 工具安装5. conda搭建虚拟环境6. 安装python模块pip安装conda安装 7. 关联虚拟环境运行项目 Python环境搭建 1. python的安装环境 ​ python环境安装有4中方式。 从上图可以了解…

YOLO语义分割标注文件txt还原到图像中

最近做图像分割任务过程中&#xff0c;使用labelme对图像进行标注&#xff0c;得到的数据文件是json&#xff0c;转换为YOLO训练所需的txt格式后&#xff0c;想对标注文件进行检验&#xff0c;即将txt标注文件还原到原图像中&#xff0c;下面是代码&#xff1a; import cv2 im…

C++指针(五)完结篇

个人主页&#xff1a;PingdiGuo_guo 收录专栏&#xff1a;C干货专栏 前言 相关文章&#xff1a;C指针&#xff08;一&#xff09;、C指针&#xff08;二&#xff09;、C指针&#xff08;三&#xff09;、C指针&#xff08;四&#xff09;万字图文详解&#xff01; 本篇博客是介…

交易平台开发:构建安全/高效/用户友好的在线交易生态圈

在数字化浪潮的推动下&#xff0c;农产品现货大宗商品撮合交易平台已成为连接全球买家与卖家的核心枢纽。随着电子商务的飞速发展&#xff0c;一个安全、高效、用户友好的交易平台对于促进交易、提升用户体验和增加用户黏性至关重要。本文将深入探讨交易平台开发的关键要素&…

git学习(创建项目提交代码)

操作步骤如下 git init //初始化git remote add origin https://gitee.com/aydvvs.git //建立连接git remote -v //查看git add . //添加到暂存区git push 返送到暂存区git status // 查看提交代码git commit -m初次提交git push -u origin "master"//提交远程分支 …

Pytorch学习 day09(简单神经网络模型的搭建)

简单神经网络模型的搭建 针对CIFAR 10数据集的神经网络模型结构如下图&#xff1a; 由于上图的结构没有给出具体的padding、stride的值&#xff0c;所以我们需要根据以下公式&#xff0c;手动推算&#xff1a; 注意&#xff1a;当stride太大时&#xff0c;padding也会变得很大…

视频推拉流EasyDSS平台直播通道重连无法转推的原因排查与解决

视频推拉流EasyDSS视频直播点播平台&#xff0c;集视频直播、点播、转码、管理、录像、检索、时移回看等功能于一体&#xff0c;可提供音视频采集、视频推拉流、播放H.265编码视频、存储、分发等视频能力服务。 用户使用EasyDSS平台对直播通道进行转推&#xff0c;发现只要关闭…

AOP切面编程,以及自定义注解实现切面

AOP切面编程 通知类型表达式重用表达式切面优先级使用注解开发&#xff0c;加上注解实现某些功能 简介 动态代理分为JDK动态代理和cglib动态代理当目标类有接口的情况使用JDK动态代理和cglib动态代理&#xff0c;没有接口时只能使用cglib动态代理JDK动态代理动态生成的代理类…

【滑动窗口】力扣239.滑动窗口最大值

前面的文章我们练习数十道 动态规划 的题目。相信小伙伴们对于动态规划的题目已经写的 得心应手 了。 还没看过的小伙伴赶快关注一下&#xff0c;学习如何 秒杀动态规划 吧&#xff01; 接下来我们开启一个新的篇章 —— 「滑动窗口」。 滑动窗口 滑动窗口 是一种基于 双指…