U-Net++原理与实现(含Pytorch和TensorFlow源码)

U-Net++原理与实现

    • 引言
    • 1. U-Net简介
      • 1.1 编码器(Encoder)
      • 1.2 解码器(Decoder)
      • 1.3 跳跃连接(Skip Connections)
    • 2. U-Net++详解
      • 2.1 密集跳跃连接
      • 2.2 嵌套和多尺度特征融合
      • 2.3 参数效率和性能
      • 2.4 Pytorch代码
      • 2.5 TensorFlow代码
    • 3. 对比分析
      • 3.1 分割性能比较
      • 3.2 参数量和计算开销
    • 结论
    • 参考文献

引言

在图像处理和计算机视觉领域,图像分割是一个至关重要的任务。分割技术被广泛应用于医学图像分析、自动驾驶、卫星图像处理等诸多领域。U-Net 及其改进版本 U-Net++ 是当前流行的图像分割神经网络结构,因其高效性和精确性而备受关注。本文旨在介绍 U-Net 和 U-Net++ 的基本原理,详细对比这两种网络结构,并探讨 U-Net++ 在实际应用中的优势。

1. U-Net简介

U-Net 是一种用于生物医学图像分割的卷积神经网络,由 Olaf Ronneberger 等人在 2015 年提出。其结构主要由编码器、解码器和跳跃连接组成。
在这里插入图片描述

1.1 编码器(Encoder)

编码器通过一系列卷积层和池化层逐步提取图像的高层次特征,同时减小特征图的空间尺寸。每个卷积层包含两个3x3卷积操作,接着是一个2x2最大池化操作。

Y = MaxPool ( σ ( W ∗ X + b ) ) Y = \text{MaxPool}(\sigma(W * X + b)) Y=MaxPool(σ(WX+b))

其中, X X X 是输入特征图, W W W b b b 分别是卷积核权重和偏置, σ \sigma σ 是激活函数,通常为 ReLU。

1.2 解码器(Decoder)

解码器通过上采样操作逐步恢复特征图的空间尺寸,并与对应编码器层的特征图进行融合。每个上采样层包含一个2x2反卷积操作,随后接两个3x3卷积操作。

Y = σ ( W ∗ UpSample ( X ) + b ) Y = \sigma(W * \text{UpSample}(X) + b) Y=σ(WUpSample(X)+b)

1.3 跳跃连接(Skip Connections)

跳跃连接将编码器每一层的特征图直接传递给解码器对应层,帮助网络更好地捕捉细节信息和上下文特征。

Y decoder = Concat ( Y encoder , Y decoder ) Y_{\text{decoder}} = \text{Concat}(Y_{\text{encoder}}, Y_{\text{decoder}}) Ydecoder=Concat(Yencoder,Ydecoder)

2. U-Net++详解

U-Net++ 由 Zhou 等人在 2018 年提出,是对经典 U-Net 的改进,主要在增强特征传递和多尺度特征融合方面进行了优化。
在这里插入图片描述
图 :(a) U-Net++ 由一个编码器和解码器组成,它们通过一系列嵌套的密集卷积块相连。U-Net++ 的核心思想是在融合之前缩小编码器和解码器之间的特征图的语义差距。例如,通过使用具有三个卷积层的密集卷积块来缩小 (X0,0, X1,3) 之间的语义差距。在图形摘要中,黑色表示原始的 U-Net,绿色和蓝色显示跳过路径上的密集卷积块,红色表示深度监督。红色、绿色和蓝色部分区分了 U-Net++ 与 U-Net。(b) U-Net++ 中第一个跳过路径的详细分析。© 如果采用深度监督训练,则可以在推理时对 U-Net++ 进行剪枝。

2.1 密集跳跃连接

U-Net++ 引入了密集的跳跃连接,在每一级的编码器和解码器之间,以及每个子 U-Net 结构内部进行连接,增强了特征的传递和利用效率。

Y i , j = σ ( W i , j ∗ [ Y i − 1 , j , Y i , j − 1 ] + b i , j ) Y_{i,j} = \sigma(W_{i,j} * [Y_{i-1,j}, Y_{i,j-1}] + b_{i,j}) Yi,j=σ(Wi,j[Yi1,j,Yi,j1]+bi,j)

其中, Y i , j Y_{i,j} Yi,j 表示第 i i i 层第 j j j 个子网络的输出。

2.2 嵌套和多尺度特征融合

通过嵌套的 U 形结构,U-Net++ 实现了多尺度特征融合,有效提升了网络对不同尺度细节的捕捉能力。

Y i , j = σ ( W i , j ∗ [ Y i − 1 , j , Y i , j − 1 , . . . , Y i , j − n ] + b i , j ) Y_{i,j} = \sigma(W_{i,j} * [Y_{i-1,j}, Y_{i,j-1}, ..., Y_{i,j-n}] + b_{i,j}) Yi,j=σ(Wi,j[Yi1,j,Yi,j1,...,Yi,jn]+bi,j)

2.3 参数效率和性能

尽管增加了连接和结构,U-Net++ 通过合理设计控制参数量,保持了高效率和良好的性能,适用于医学图像等复杂场景。

2.4 Pytorch代码

import torch
import torch.nn as nnclass ConvBlock(nn.Module):def __init__(self, in_channels, out_channels):super(ConvBlock, self).__init__()self.conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, 3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))def forward(self, x):return self.conv(x)class UNetPlusPlus(nn.Module):def __init__(self, in_channels=3, out_channels=1, filters=[32, 64, 128, 256, 512]):super(UNetPlusPlus, self).__init__()self.pool = nn.MaxPool2d(2, 2)self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)self.conv0_0 = ConvBlock(in_channels, filters[0])self.conv1_0 = ConvBlock(filters[0], filters[1])self.conv2_0 = ConvBlock(filters[1], filters[2])self.conv3_0 = ConvBlock(filters[2], filters[3])self.conv4_0 = ConvBlock(filters[3], filters[4])self.conv0_1 = ConvBlock(filters[0] + filters[1], filters[0])self.conv1_1 = ConvBlock(filters[1] + filters[2], filters[1])self.conv2_1 = ConvBlock(filters[2] + filters[3], filters[2])self.conv3_1 = ConvBlock(filters[3] + filters[4], filters[3])self.conv0_2 = ConvBlock(filters[0]*2 + filters[1], filters[0])self.conv1_2 = ConvBlock(filters[1]*2 + filters[2], filters[1])self.conv2_2 = ConvBlock(filters[2]*2 + filters[3], filters[2])self.conv0_3 = ConvBlock(filters[0]*3 + filters[1], filters[0])self.conv1_3 = ConvBlock(filters[1]*3 + filters[2], filters[1])self.conv0_4 = ConvBlock(filters[0]*4 + filters[1], filters[0])self.final = nn.Conv2d(filters[0], out_channels, kernel_size=1)def forward(self, x):x0_0 = self.conv0_0(x)x1_0 = self.conv1_0(self.pool(x0_0))x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))x2_0 = self.conv2_0(self.pool(x1_0))x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))x3_0 = self.conv3_0(self.pool(x2_0))x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))x4_0 = self.conv4_0(self.pool(x3_0))x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))output = self.final(x0_4)return output# 创建模型实例
model = UNetPlusPlus(in_channels=3, out_channels=1)

2.5 TensorFlow代码

import tensorflow as tf
from tensorflow.keras import layers, Modeldef conv_block(inputs, filters):x = layers.Conv2D(filters, 3, padding='same')(inputs)x = layers.BatchNormalization()(x)x = layers.ReLU()(x)x = layers.Conv2D(filters, 3, padding='same')(x)x = layers.BatchNormalization()(x)x = layers.ReLU()(x)return xdef UNetPlusPlus(input_shape=(256, 256, 3), num_classes=1):inputs = layers.Input(shape=input_shape)# Encoder (Downsampling)conv0_0 = conv_block(inputs, 32)pool1 = layers.MaxPooling2D(pool_size=(2, 2))(conv0_0)conv1_0 = conv_block(pool1, 64)pool2 = layers.MaxPooling2D(pool_size=(2, 2))(conv1_0)conv2_0 = conv_block(pool2, 128)pool3 = layers.MaxPooling2D(pool_size=(2, 2))(conv2_0)conv3_0 = conv_block(pool3, 256)pool4 = layers.MaxPooling2D(pool_size=(2, 2))(conv3_0)conv4_0 = conv_block(pool4, 512)# Decoder (Upsampling)up1_0 = layers.UpSampling2D(size=(2, 2))(conv4_0)up1_0 = layers.concatenate([up1_0, conv3_0])conv3_1 = conv_block(up1_0, 256)up2_0 = layers.UpSampling2D(size=(2, 2))(conv3_0)up2_0 = layers.concatenate([up2_0, conv2_0])conv2_1 = conv_block(up2_0, 128)up2_1 = layers.UpSampling2D(size=(2, 2))(conv3_1)up2_1 = layers.concatenate([up2_1, conv2_0, conv2_1])conv2_2 = conv_block(up2_1, 128)up3_0 = layers.UpSampling2D(size=(2, 2))(conv2_0)up3_0 = layers.concatenate([up3_0, conv1_0])conv1_1 = conv_block(up3_0, 64)up3_1 = layers.UpSampling2D(size=(2, 2))(conv2_1)up3_1 = layers.concatenate([up3_1, conv1_0, conv1_1])conv1_2 = conv_block(up3_1, 64)up3_2 = layers.UpSampling2D(size=(2, 2))(conv2_2)up3_2 = layers.concatenate([up3_2, conv1_0, conv1_1, conv1_2])conv1_3 = conv_block(up3_2, 64)up4_0 = layers.UpSampling2D(size=(2, 2))(conv1_0)up4_0 = layers.concatenate([up4_0, conv0_0])conv0_1 = conv_block(up4_0, 32)up4_1 = layers.UpSampling2D(size=(2, 2))(conv1_1)up4_1 = layers.concatenate([up4_1, conv0_0, conv0_1])conv0_2 = conv_block(up4_1, 32)up4_2 = layers.UpSampling2D(size=(2, 2))(conv1_2)up4_2 = layers.concatenate([up4_2, conv0_0, conv0_1, conv0_2])conv0_3 = conv_block(up4_2, 32)up4_3 = layers.UpSampling2D(size=(2, 2))(conv1_3)up4_3 = layers.concatenate([up4_3, conv0_0, conv0_1, conv0_2, conv0_3])conv0_4 = conv_block(up4_3, 32)outputs = layers.Conv2D(num_classes, 1, activation='sigmoid')(conv0_4)model = Model(inputs=inputs, outputs=outputs)return model# 创建模型实例
model = UNetPlusPlus(input_shape=(256, 256, 3), num_classes=1)

3. 对比分析

在这里插入图片描述

3.1 分割性能比较

下表对比了 U-Net 和 U-Net++ 在不同数据集上的分割性能。

数据集U-Net 精度U-Net++ 精度
医学图像数据集85%90%
卫星图像数据集80%88%
自动驾驶数据集82%89%

3.2 参数量和计算开销

下表比较了 U-Net 和 U-Net++ 在网络结构复杂度、参数数量和计算资源消耗上的差异。

指标U-NetU-Net++
参数数量31M37M
计算复杂度62 GFLOPs75 GFLOPs
推理时间20 ms/张25 ms/张

结论

U-Net++ 作为 U-Net 结构的进化版,通过密集跳跃连接和多尺度特征融合显著提高了图像分割性能,尤其在细节捕捉和特征传递方面表现优异。尽管其参数量和计算开销有所增加,但在实际应用中,U-Net++ 的优势明显,值得在高精度图像分割任务中推广使用。

参考文献

[1] U-Net: Convolutional Networks for Biomedical Image Segmentation:U-Net
[2] UNet++: A Nested U-Net Architecture for Medical Image Segmentation:U-Net++


本人诚接各种数据处理、机器学习、深度学习、图像处理、时间序列预测分析等方向的算法/项目私人订制,技术在线,价格优惠。如有需要欢迎私信博主!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/392292.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【C++ STL】vector

文章目录 vector1. vector的接口1.1 默认成员函数1.2 容量操作1.3 访问操作1.4 修改操作1.5 vector与常见的数据结构的对比 2. vector的模拟实现2.1 类的定义2.2 默认成员函数迭代器的分类 2.3 容量接口memcpy 浅拷贝问题内存增长机制reserve和resize的区别 2.4 修改接口迭代器…

老照片修复软件分享3款!码住一些实用的方法!

在数字时代,老照片不仅是时间的印记,更是我们珍贵的记忆载体。然而,随着时间的流逝,这些照片往往会变得模糊、褪色甚至破损。幸运的是,现代科技的发展为我们提供了多种老照片修复软件,让我们能够轻松恢复这…

Flux:Midjourney的新图像模型挑战者

--->更多内容&#xff0c;请移步“鲁班秘笈”&#xff01;&#xff01;<--- Black Forest Labs是一家由前Stability.ai开发人员创立的AI初创公司&#xff0c;旨在为图像和视频创建尖端的生成式 AI 模型。这家初创公司声称&#xff0c;其第一个模型系列Flux.1为文本到图像…

现代前端架构介绍(第二部分):如何将功能架构分为三层

远离JavaScript疲劳和框架大战&#xff0c;了解真正重要的东西 在这个系列的前一部分 《App是如何由不同的构建块构成的》中&#xff0c;我们揭示了现代Web应用是由不同的构建块组成的&#xff0c;每个构建块都承担着特定的角色&#xff0c;如核心、功能等。在这篇文章中&#…

重塑汽车制造未来:3D插图技术大师,零误差高效驱动新时代

在当今快速革新的汽车制造领域&#xff0c;高效、精准的产品设计与制造流程已成为众多车企破浪前行的核心引擎。但随着市场竞争的日益激烈&#xff0c;在产品设计与制造中&#xff0c;传统二维CAD设计的局限性越发明显——设计周期长、沟通成本高、错误频发及资源利用低效等问题…

联想M7615DNA打印机复印证件太黑的解决方法及个人建议

打印机在使用过程中&#xff0c;可能会出现复印的文字或图片太黑的问题&#xff0c;这会影响到打印或复印的效果。下面我们来了解一下这种情况的原因和解决方法&#xff1b;以下所述操作仅供大家参考&#xff0c;如有不足请大家提出宝贵意见&#xff1b; 证件包括&#xff1a;…

【MySQL】索引——索引的实现、B+ vs B、聚簇索引 VS 非聚簇索引、索引操作、创建索引、查询索引、删除索引

文章目录 MySQL5. 索引的实现5.1 B vs B5.2 聚簇索引 VS 非聚簇索引 6. 索引操作6.1 创建主键索引6.2 创建唯一索引6.3 创建普通索引6.4 创建全文索引6.5 查询索引6.6 删除索引 MySQL 5. 索引的实现 因为MySQL和磁盘交互的基本单位为Page&#xff08;页&#xff09;。 MySQL 中…

C# 串口通信(通过serialPort控件发送及接收数据)

连接串口 界面设计打开串口发送数据通过文件发送发送数据 接收数据 首先可以在 工具箱中搜索serialport&#xff0c;将控件拖到你的Winfrom窗口。 界面设计 打开串口 private void Connect_Click(object sender, EventArgs e){serialPort1.PortName comboBox2.Text;//端口名s…

CAS单点登录

1.相同顶级域名的单点登录SSO 相同顶级域名的单点登录:SSO:SINGLE SIGN ON 单点登录可以通过基于用户会话的共享&#xff1b;分为两种&#xff0c;第一种&#xff1a;相同顶级域名&#xff1b; 原理是分布式会话完成的&#xff1b;关键是顶级域名的cookie值是可以共享的 比如…

【C#】ThreadPool的使用

1.Thread的使用 Thread的使用参考&#xff1a;【C#】Thread的使用 2.ThreadPool的使用 .NET Framework 和 .NET Core 提供了 System.Threading.ThreadPool 类来帮助开发者以一种高效的方式管理线程。ThreadPool 是一个线程池&#xff0c;它能够根据需要动态地分配和回收线程…

【Kubernetes】Deployment 的清理策略

Deployment 的清理策略 在 Deployment 中配置 spec.revisionHistoryLimit 字段&#xff0c;可以指定其 清理策略。该字段用于指定 Deployment 保留旧 ReplicaSet 的个数&#xff0c;即更新 Pod 前的版本个数。该字段的默认值是 10。 创建 revisionhistory-demo.yaml 文件&…

上升探索WebKit的奥秘:打造高效、兼容的现代网页应用

嘿&#xff0c;朋友们&#xff01;想象一下&#xff0c;你正在浏览一个超级炫酷的网站&#xff0c;页面加载飞快&#xff0c;布局完美适应你的设备&#xff0c;动画流畅得就像你在看一场好莱坞大片。这一切的背后&#xff0c;有一个神秘的英雄——WebKit。今天&#xff0c;我们…

学习笔记--算法(双指针)3

快乐数 . - 力扣&#xff08;LeetCode&#xff09; 题目 编写一个算法来判断一个数 n 是不是快乐数。 「快乐数」 定义为&#xff1a; 对于一个正整数&#xff0c;每一次将该数替换为它每个位置上的数字的平方和。然后重复这个过程直到这个数变为 1&#xff0c;也可能是 无…

如何在IDEA上使用JDBC编程【保姆级教程】

目录 前言 什么是JDBC编程 本质 使用JDBC编程的优势 JDBC流程 如何在IEDA上使用JDBC JDBC编程 1.创建并初始化数据源 2.与数据库服务器建立连接 3.创建PreparedStatement对象编写sql语句 4.执行SQL语句并处理结果集 executeUpdate executeQuery 5.释放资源 前言 在…

二叉树中的深搜

目录 二叉树中的深搜&#xff1a; 一、计算布尔二叉树的值 1.题目链接&#xff1a;2331. 计算布尔二叉树的值 2.题目描述&#xff1a; 3.解法&#xff08;递归&#xff09; &#x1f352;算法思路&#xff1a; &#x1f352;算法流程&#xff1a; &#x1f352;算法代码…

C# Unity 面向对象补全计划 泛型

本文仅作学习笔记与交流&#xff0c;不作任何商业用途&#xff0c;作者能力有限&#xff0c;如有不足还请斧正 1.什么是泛型 泛型&#xff08;Generics&#xff09;是C#中的一个强大特性&#xff0c;允许你编写可以适用于多种数据类型的可重用代码&#xff0c;而不需要重复编写…

CSP-J复赛-模拟题4

1.区间覆盖问题&#xff1a; 题目描述 给定一个长度为n的序列1,2,...,a1​,a2​,...,an​。你可以对该序列执行区间覆盖操作&#xff0c;即将区间[l,r]中的数字,1,...,al​,al1​,...,ar​全部修改成同一个数字。 现在有T次操作&#xff0c;每次操作由l,r,p,k四个值组成&am…

GD32 SPI 通信协议

1.0 SPI 简介 SPI是一种串行通信接口&#xff0c;相对于IIC而言SPI需要的信号线的个数多一点&#xff0c;时钟的信号是主机产生的。 MOSI&#xff1a;主机发送&#xff0c;从机接收 MISO&#xff1a;主机接收&#xff0c;从机发送 CS&#xff1a;表示的是片选信号 都是单向…

C# Unity 面向对象补全计划 泛型约束

本文仅作学习笔记与交流&#xff0c;不作任何商业用途&#xff0c;作者能力有限&#xff0c;如有不足还请斧正 1.泛型约束了什么 在C#中&#xff0c;泛型约束用于限制泛型类型参数的类型 可以在泛型类型或方法的声明中使用 where 关键字来指定这些约束 2.约束栗子 基类约束…

LearnOpenGL-入门章节学习笔记

LearnOpenGL-入门章节学习笔记 简介一、核心模式与立即渲染模式二、扩展三、状态机四、对象 创建窗口一、Main函数——实例化窗口二、Callback Function 回调函数三、processInput 函数 创建三角形一、顶点输入二、顶点着色器三、编译着色器四、片段着色器五、着色器程序六、链…