pytorch复现_UNet

什么是UNet
U-Net由收缩路径和扩张路径组成。收缩路径是一系列卷积层和汇集层,其中要素地图的分辨率逐渐降低。扩展路径是一系列上采样层和卷积层,其中特征地图的分辨率逐渐增加。
在扩展路径中的每一步,来自收缩路径的对应特征地图与当前特征地图级联。
在这里插入图片描述
主干结构解析
左边为特征提取网络(编码器),右边为特征融合网络(解码器)

高分辨率—编码—低分辨率—解码—高分辨率

特征提取网络
高分辨率—编码—低分辨率

前半部分是编码, 它的作用是特征提取(获取局部特征,并做图片级分类),得到抽象语义特征

由两个3x3的卷积层(RELU)再加上一个2x2的maxpooling层组成一个下采样的模块,一共经过4次这样的操作

特征融合网络
低分辨率—解码—高分辨率

利用前面编码的抽象特征来恢复到原图尺寸的过程, 最终得到分割结果(掩码图片)

代码:

import torch.nn as nn
import torch# 编码器(论文中称之为收缩路径)的基本单元
def contracting_block(in_channels, out_channels):block = torch.nn.Sequential(# 这里的卷积操作没有使用padding,所以每次卷积后图像的尺寸都会减少2个像素大小nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels, out_channels=out_channels),nn.BatchNorm2d(out_channels),nn.ReLU(),nn.Conv2d(kernel_size=(3, 3), in_channels=out_channels, out_channels=out_channels),nn.BatchNorm2d(out_channels),nn.ReLU())return block# 解码器(论文中称之为扩张路径)的基本单元
class expansive_block(nn.Module):def __init__(self, in_channels, mid_channels, out_channels):super(expansive_block, self).__init__()# 每进行一次反卷积,通道数减半,尺寸扩大2倍self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=(3, 3), stride=2, padding=1,output_padding=1)self.block = nn.Sequential(# 这里的卷积操作没有使用padding,所以每次卷积后图像的尺寸都会减少2个像素大小nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels, out_channels=mid_channels),nn.BatchNorm2d(mid_channels),nn.ReLU(),nn.Conv2d(kernel_size=(3, 3), in_channels=mid_channels, out_channels=out_channels),nn.BatchNorm2d(out_channels),nn.ReLU())def forward(self, e, d):d = self.up(d)# concat# e是来自编码器部分的特征图,d是来自解码器部分的特征图,它们的形状都是[B,C,H,W]diffY = e.size()[2] - d.size()[2]diffX = e.size()[3] - d.size()[3]# 裁剪时,先计算e与d在高和宽方向的差距diffY和diffX,然后对e高方向进行裁剪,具体方法是两边分别裁剪diffY的一半,# 最后对e宽方向进行裁剪,具体方法是两边分别裁剪diffX的一半,# 具体的裁剪过程见下图一e = e[:, :, diffY // 2:e.size()[2] - diffY // 2, diffX // 2:e.size()[3] - diffX // 2]cat = torch.cat([e, d], dim=1)  # 在特征通道上进行拼接out = self.block(cat)return out# 最后的输出卷积层
def final_block(in_channels, out_channels):block = nn.Conv2d(kernel_size=(1, 1), in_channels=in_channels, out_channels=out_channels)return blockclass UNet(nn.Module):def __init__(self, in_channel, out_channel):super(UNet, self).__init__()# 编码器 (Encode)self.conv_encode1 = contracting_block(in_channels=in_channel, out_channels=64)self.conv_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv_encode2 = contracting_block(in_channels=64, out_channels=128)self.conv_pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv_encode3 = contracting_block(in_channels=128, out_channels=256)self.conv_pool3 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv_encode4 = contracting_block(in_channels=256, out_channels=512)self.conv_pool4 = nn.MaxPool2d(kernel_size=2, stride=2)# 编码器与解码器之间的过渡部分(Bottleneck)self.bottleneck = nn.Sequential(nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=1024),nn.BatchNorm2d(1024),nn.ReLU(),nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024),nn.BatchNorm2d(1024),nn.ReLU())# 解码器(Decode)self.conv_decode4 = expansive_block(1024, 512, 512)self.conv_decode3 = expansive_block(512, 256, 256)self.conv_decode2 = expansive_block(256, 128, 128)self.conv_decode1 = expansive_block(128, 64, 64)self.final_layer = final_block(64, out_channel)def forward(self, x):# Encodeencode_block1 = self.conv_encode1(x)encode_pool1 = self.conv_pool1(encode_block1)encode_block2 = self.conv_encode2(encode_pool1)encode_pool2 = self.conv_pool2(encode_block2)encode_block3 = self.conv_encode3(encode_pool2)encode_pool3 = self.conv_pool3(encode_block3)encode_block4 = self.conv_encode4(encode_pool3)encode_pool4 = self.conv_pool4(encode_block4)# Bottleneckbottleneck = self.bottleneck(encode_pool4)# Decodedecode_block4 = self.conv_decode4(encode_block4, bottleneck)decode_block3 = self.conv_decode3(encode_block3, decode_block4)decode_block2 = self.conv_decode2(encode_block2, decode_block3)decode_block1 = self.conv_decode1(encode_block1, decode_block2)final_layer = self.final_layer(decode_block1)return final_layerif __name__ == '__main__':image = torch.rand((1, 3, 572, 572))unet = UNet(in_channel=3, out_channel=2)mask = unet(image)print(mask.shape)#输出结果:torch.Size([1, 2, 388, 388])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/186301.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

css设置浏览器表单自动填充时的背景

浏览器自动填充表单内容,会自动设置背景色。对于一般的用户,也许不会觉得有什么,但对于要求比较严格的用户,就会“指手画脚”。这里,我们通过css属性来设置浏览器填充背景的过渡时间,使用户看不到过渡后的背…

win10下.net framework 3.5 | net framework 4 无法安装解决方案

.net缺失解决方案 win10 .net framework 3.5组策略设置方案一方案二 win10 .net framework 4 参考文章 win10 .net framework 3.5 组策略设置 方案一 搜索组策略,依次展开“计算机配置”、“管理模板”,然后选择“系统”,找到指定可选组件…

Leetcode2246. 相邻字符不同的最长路径

Every day a Leetcode 题目来源:2246. 相邻字符不同的最长路径 解法1:树形 DP 如果没有相邻节点的限制,那么本题求的就是树的直径上的点的个数,见于Leetcode543. 二叉树的直径。 考虑用树形 DP 求直径。 枚举子树 x 的所有子…

如何让群晖Audio Station公开共享的本地音频公网可访问?

文章目录 1. 本教程使用环境:2. 制作音频分享链接3. 制作永久固定音频分享链接: 之前文章我详细介绍了如何在公网环境下使用pc和移动端访问群晖Audio Station: 公网访问群晖audiostation听歌 - cpolar 极点云 群晖套件不仅能读写本地文件&a…

Go基础知识全面总结

文章目录 go基本数据类型bool类型数值型字符字符串 数据类型的转换运算符和表达式1. 算数运算符2.关系运算符3. 逻辑运算符4. 位运算符5. 赋值运算符6. 其他运算符运算符优先级转义符 go基本数据类型 bool类型 布尔型的值只可以是常量 true 或者 false。⼀个简单的例⼦&#…

竞赛选题 深度学习猫狗分类 - python opencv cnn

文章目录 0 前言1 课题背景2 使用CNN进行猫狗分类3 数据集处理4 神经网络的编写5 Tensorflow计算图的构建6 模型的训练和测试7 预测效果8 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 **基于深度学习猫狗分类 ** 该项目较为新颖&a…

Python基础教程之十六:Python multidict示例–将单个键映射到字典中的多个值

1.什么是multidict词典> 在python中,“ multidict ”一词用于指代字典,在字典中可以将单个键映射到多个值。例如 多重结构 multidictWithList {key1 : [1, 2, 3],key2 : [4, 5]}multidictWithSet {key1 : {1, 2, 3},key2 : {4, 5}}1. list如果要…

内核移植笔记 Cortex-M移植

常用寄存器 PRIMASK寄存器 为1位宽的中断屏蔽寄存器。在置位时,它会阻止不可屏蔽中断(NMI)和HardFault异常之外的所有异常(包括中断)。 实际上,它是将当前异常优先级提升为0,这也是可编程异常/…

uniapp使用vue3和ts开发小程序自定义tab栏,实现自定义凸出tabbar效果

要实现自定义的tabbar效果,可以使用自定义tab覆盖主tab来实现,当程序启动或者从后台显示在前台时隐藏自带的tab来实现。自定义一个tab组件,然后在里面实现自定义的逻辑。 组件中所使用的组件api可以看:Tabbar 底部导航栏 | uView…

Centos7下搭建H3C log服务器

rsyslogH3C 安装rsyslog服务器 关闭防火墙 systemctl stop firewalld && systemctl disable firewalld关闭selinux sed -i s/enforcing/disabled/ /etc/selinux/config && setenforce 0centos7服务器,通过yum安装rsyslog yum -y install rsysl…

【uniapp】六格验证码输入框实现

效果图 代码实现 <view><view class"tips">已发送验证码至<text class"tips-phone">{{ phoneNumber }}</text></view><view class"code-input-wrap"><input class"code-input" v-model"…

AI:75-基于生成对抗网络的虚拟现实场景增强

🚀 本文选自专栏:AI领域专栏 从基础到实践,深入了解算法、案例和最新趋势。无论你是初学者还是经验丰富的数据科学家,通过案例和项目实践,掌握核心概念和实用技能。每篇案例都包含代码实例,详细讲解供大家学习。 📌📌📌在这个漫长的过程,中途遇到了不少问题,但是…

[量化投资-学习笔记008]Python+TDengine从零开始搭建量化分析平台-CCI和ATR

目录 1. 指标简介CCIATR 2. 程序编写题外话 1. 指标简介 将这两个指标放在一起&#xff0c;一方面是因为这两个指标都属于摆动指数&#xff0c;可以反应市场的活跃度。 另一方面是因为CCI和ATR与之前提到的EMA,MACD,布林带的三个指标的计算基础不同。之前的三个指标都是以收盘…

坐标系转换(仅作记载)

一.极坐标转换为普通坐标系 参考&#xff1a;极坐标方程与直角坐标方程的互化 - 知乎 (zhihu.com) 公式&#xff1a;&#xff08;无需考虑象限引起的正负问题&#xff09; 普通坐标系转换为极坐标系 参考&#xff1a; 极坐标怎么与直角坐标系相互转化&#xff1f; - 知乎 (zh…

Docker本地镜像发布到阿里云或私有库

本地镜像发布到阿里云流程 &#xff1a; 1.自己生成个要传的镜像 2.将本地镜像推送到阿里云: 阿里云开发者平台:开放云原生应用-云原生&#xff08;Cloud Native&#xff09;-云原生介绍 - 阿里云 2.1.创建仓库镜像&#xff1a; 2.1.1 选择控制台&#xff0c;进入容器镜像服…

如何在Linux上部署1Panel运维管理面板并远程访问内网进行操作

文章目录 前言1. Linux 安装1Panel2. 安装cpolar内网穿透3. 配置1Panel公网访问地址4. 公网远程访问1Panel管理界面5. 固定1Panel公网地址 前言 1Panel 是一个现代化、开源的 Linux 服务器运维管理面板。高效管理,通过 Web 端轻松管理 Linux 服务器&#xff0c;包括主机监控、…

广和通5G模组FM650助力阿里云打造无影魔方Pro

随着云基础设施的完善及云电脑体验的不断优化&#xff0c;越来越多的个人和企业选择无影云电脑进行办公。基于云原生的云网端技术架构&#xff0c;无影云电脑相比传统PC&#xff0c;具有弹性、安全、保障个人数据等产品优势。 10月31日&#xff0c;阿里云在杭州云栖大会上宣布…

RSA 2048位算法的主要参数N,E,P,Q,DP,DQ,Qinv,D分别是什么意思 哪个是通常所说的公钥与私钥 -安全行业基础篇5

非对称加密算法RSA 在RSA 2048位算法中&#xff0c;常见的参数N、E、P、Q、DP、DQ、Qinv和D代表以下含义&#xff1a; N&#xff08;Modulus&#xff09;&#xff1a;模数&#xff0c;是两个大素数P和Q的乘积。N的长度决定了RSA算法的安全性。 E&#xff08;Public Exponent&a…

原神游戏干货分享:探索璃月的宝箱秘密,提高游戏资源获取效率!

《原神》是一款备受玩家喜爱的开放世界冒险游戏&#xff0c;而在游戏中获取资源是提升角色实力的重要途径。在这篇实用干货分享中&#xff0c;我们将介绍一些探索璃月地区的宝箱秘密&#xff0c;帮助你提高游戏资源获取的效率。 首先&#xff0c;璃月地区的宝箱分为普通宝箱和精…