YOLOv8-seg 分割代码详解(一)Predict

前言

  本文从 U-Net 入手熟悉分割的简单方法,再看 YOLOv8 的方法。主要梳理 YOLOv8 的网络结构,以及 Predict 过程的后处理方法。

U-Net 代码地址:https://github.com/milesial/Pytorch-UNet
YOLOv8 代码地址:https://github.com/ultralytics/ultralytics
YOLOv8 官方文档:https://docs.ultralytics.com/

1. U-Net

1.1 网络结构

在这里插入图片描述

CBR
Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU(inplace=True)

1.2 转置卷积

torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', device=None, dtype=None
)

H o u t = ( H i n − 1 ) × stride [ 0 ] − 2 × padding [ 0 ] + dilation [ 0 ] × ( kernel_size [ 0 ] − 1 ) + output_padding [ 0 ] + 1 H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1 Hout=(Hin1)×stride[0]2×padding[0]+dilation[0]×(kernel_size[0]1)+output_padding[0]+1
W o u t = ( W i n − 1 ) × stride [ 1 ] − 2 × padding [ 1 ] + dilation [ 1 ] × ( kernel_size [ 1 ] − 1 ) + output_padding [ 1 ] + 1 W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1] \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1 Wout=(Win1)×stride[1]2×padding[1]+dilation[1]×(kernel_size[1]1)+output_padding[1]+1

  • stride
    控制原图像素之间的填充量,数值为 stride − 1 \text{stride}-1 stride1
  • kernel_size,padding
    kernel_size 为转置卷积核大小,并且和 padding 一同决定原图四周填充量,数值为 kernel_size − padding − 1 \text{kernel\_size}-\text{padding}-1 kernel_sizepadding1
  • dilation
    控制卷积核采样点的间距(空洞卷积),默认为1,即最普通的卷积

注:转置卷积在卷积时的 stride 固定为1,output_padding 固定为0;而参数中设置的 stride、padding 用于控制卷积之前对输入的填充

kernel_size = 2 , stride = 2 , padding = 0 , H i n = 640 , W i n = 640 \text{kernel\_size}=2,\text{stride}=2,\text{padding}=0,H_{in}=640,W_{in}=640 kernel_size=2,stride=2,padding=0,Hin=640,Win=640 为例

  1. 像素间填充 2-1=1, 640 × 640 → 1279 × 1279 640\times640\to1279\times1279 640×6401279×1279
  2. 四周填充 2-0-1=1, 1279 × 1279 → 1281 × 1281 1279\times1279\to1281\times1281 1279×12791281×1281
  3. 2x2卷积, 1281 × 1281 → 1280 × 1280 1281\times1281\to1280\times1280 1281×12811280×1280

查看不同卷积的可视化:https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

1.3 Loss

单分类
loss = BCEWithLogitsLoss(P, Y) + dice_loss(sigmoid(P), Y)
多分类
loss = CrossEntropyLoss(P, Y) + dice_loss(softmax(P), one_hot(Y))

(1)BCEWithLogitsLoss
对于每个样本 l = − [ y l o g σ ( x ) + ( 1 − y ) l o g ( 1 − σ ( x ) ) ] l=-[ylog\sigma (x)+(1-y)log(1-\sigma (x))] l=[ylogσ(x)+(1y)log(1σ(x))],最后求均值

(2)dice_loss

l = 1 − sum ( 2 × P × Y ) sum ( P ) + sum ( Y ) l = 1-\frac{\text{sum}(2\times P\times Y)}{\text{sum}(P)+\text{sum}(Y)} l=1sum(P)+sum(Y)sum(2×P×Y)

这里的 Y Y Y 作为标签是固定的, P P P 通过让目标区域值靠近1提高分子值,背景区域靠近0降低分母值,即 P → Y P\to Y PY,从而降低loss

1.4 Predict

单分类
mask = sigmoid(P) > threshold
多分类
mask = P.argmax(dim=1)

2. YOLOv8-seg

2.1 网络结构

  结构图中数据按 yolov8m-seg 的 predict 过程绘制,输入图像为 1280x720,预处理时通过 LetterBox 对图像进行保长宽比缩放和 padding,使其长宽都能被最大下采样倍率32整除。在 train 过程中,输入大小统一为 640x640。

主干

在这里插入图片描述

CBS
Conv2d(3, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
BatchNorm2d(48, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
SiLU(inplace=True)

C2f 模块
在这里插入图片描述
SPPF
在这里插入图片描述

Segment-head

分割
在这里插入图片描述
检测
在这里插入图片描述
注:DLF层中的卷积层参数是固定的,在这里是 torch.arange(16)

Anchor
  Anchor坐标是把特征图看做一个网格,每个像素边长为1,把每个格子的中心点坐标取出来。以 x0 (h=48,w=80) 为例,左上角坐标为 (0.5,0.5),右下角点为 (79.5,47.5)
  DLF的输出对应目标框左上角坐标和右下角坐标到Anchor坐标的距离,与Anchor融合并乘上对应的下采样倍率得到 dbox

lt, rb = dfl(box).chunk(2, dim=1)
x1y1 = anchor_points - lt
x2y2 = anchor_points + rb

2.2 预测

模型推理输出

Y:	[b,4,5040]
mc: [b,32,5040]
p:	[b,32,96,160]Y为检测结果,4对应检测框坐标
mc为分割结果,32对应分割的特征向量,通过和p做矩阵乘法可以转化成mask形式模型最终推理的输出preds包含两项
(1)torch.cat(y, mc], 1), 即检测和分割的结果, shape:[b,37,5040]
(2)包含3项的元组a. [x0, x1, x2], 即detect层的中间输出b. mc 	[b,32,5040]c. p	[b,32,96,160]

NMS

p = nms((20,37,5040), conf=0.25, iou=0.7, agnostic=False, max_det=300, nc=1)(1) 分类得分阈值筛选 class_scores > conf=0.25
[5040,37] --> [n1,37]
(2) 提取类别
[n1,37] --> [n1,38] (xyxy, cls_score, cls, 32)
(3) 若此时box数量大于 max_nms=30000, 选取 cls_score 较大的30000(4) 调库 torchvision.ops.nms(boxes, scores, iou_thres), 选取前 max_det=300[n1,38] --> [n2,38]
(5) nms-merge, 默认跳过

mask

masks = process_mask(protos,					模型输出p [b,32,96,160]masks_in=pred[:, 6:], 	nms结果的mask部分bboxes=pred[:, :4], 	nms结果的box部分shape=img.shape[2:], 	输入图像大小(384,640)upsample=True
)def process_mask(protos, masks_in, bboxes, shape, upsample=False):c, mh, mw = protos.shapeih, iw = shape"""矩阵乘法+sigmoid得到mask"""masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)"""比例变换"""downsampled_bboxes = bboxes.clone()downsampled_bboxes[:, 0] *= mw / iwdownsampled_bboxes[:, 2] *= mw / iwdownsampled_bboxes[:, 3] *= mh / ihdownsampled_bboxes[:, 1] *= mh / ih"""裁减掉box范围以外的值"""masks = crop_mask(masks, downsampled_bboxes)  # CHWif upsample:masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0]  # CHW"""按阈值0.5转为二值图mask"""return masks.gt_(0.5)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/184243.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

B站双11,联手天猫暴涨2亿消费新势力

一直以来,手持高活跃、高粘性用户群体的B站是行业用来观察年轻人消费习惯的重要平台。以至于用户群体的不断壮大带动了B站的商业价值。如今B站的商业舞台越来越大,不断地向外界招手,欢迎更多品牌积极加入到这个千万年轻人聚集的内容社区。 2…

大数据疫情分析及可视化系统 计算机竞赛

文章目录 0 前言2 开发简介3 数据集4 实现技术4.1 系统架构4.2 开发环境4.3 疫情地图4.3.1 填充图(Choropleth maps)4.3.2 气泡图 4.4 全国疫情实时追踪4.6 其他页面 5 关键代码最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 大数据疫…

Web Worker:JS多线程的伪解药?

前言 在前端开发领域,JavaScript 的单线程限制一直是一个难以忽视的挑战。当谈到解决JavaScript的单线程限制时,HTML5引入的Web Worker被普遍认为是一剂解药💊。同时,业界中大量的文章也是聚焦于讨论web worker的神奇力量。然而&…

Android内存回收机制、GC算法及内存问题分析解决

Android内存回收机制、GC算法及内存问题分析解决 在Android开发中,Java内存回收和垃圾收集(GC)机制是确保应用程序高效运行的关键部分。针对不同对象存活率,Android平台采用了引用计数算法和可达性分析法来判定对象的可回收性&am…

RTC实时时钟——DS1302

DS1302目录 一、DS1302简介引脚定义与推荐电路 二、芯片手册1.操作寄存器的定义2.时序定义dc1302.cds1302.h 三、蓝桥杯实践 一、DS1302简介 RTC(Real Time Clock):实时时钟,是一种集成电路,通常称为时钟芯片。现在流行的串行时钟电路很多,如…

华为李鹏:到 2025 年智能算力需求将达到目前水平的 100 倍

在第十四届全球移动宽带论坛上,华为高级副总裁、运营商 BG 总裁李鹏表示,大模型为代表的 AI 应用发展带来对智能算力的爆发式需求。 李鹏在题为《加速 5G 商业正循环,拥抱更繁荣的 5.5G》的讲话中表示,「5G 已经走在商业成功的正确…

C# OpenCvSharp 去除字母后面的杂线

效果 项目 代码 using OpenCvSharp; using System; using System.Drawing; using System.Windows.Forms;namespace OpenCvSharp_Demo {public partial class frmMain : Form{public frmMain(){InitializeComponent();}string image_path "";private void Form1_Loa…

三国志14信息查询小程序(历史武将信息一览)制作更新过程05-后台接口的编写及调用

1,创建ASP.NET Web API项目 生成完毕,项目结构如下: 运行看一下: 2,后台接口编写 (1)在Models文件夹中新建一个sandata.cs文件(就是上篇中武将信息表的model文件) u…

伦敦金开户需要多少资金,有开户条件吗?

伦敦金(London Gold)是黄金市场中备受瞩目的投资种类之一,无论是专业投资者还是新手,都对伦敦金感兴趣。但关于开户需要多少资金,以及是否有特定的开户条件,这些问题可能会让一些新手投资者感到困惑。 首先…

notepad++搜索结果窗口不见了

1、使用notepad打开一个文件文件 2、ctrlf,打开搜索窗口,随便搜索一个内容 3、按F7,然后AltF7 切换焦点到Find result. 会有一个小窗口出现,内容是:还原,移动,大小等 4,点移动,使…

[答疑]校长出轨主任流程的业务建模

DDD领域驱动设计批评文集 做强化自测题获得“软件方法建模师”称号 《软件方法》各章合集 艳阳天 2023-10-27 19:45 我有点迷糊。校长出轨主任在酒店被拍到,不属于学校的业务流程,但闹出这种事对学校有很大影响。如果学校想用一个系统抓风纪&#xff…

论文阅读—— BiFormer(cvpr2023)

论文:https://arxiv.org/abs/2303.08810 github:GitHub - rayleizhu/BiFormer: [CVPR 2023] Official code release of our paper "BiFormer: Vision Transformer with Bi-Level Routing Attention" 一、介绍 1、要解决的问题:t…

OpenLayers入门,OpenLayers加载离线xyz瓦片地图并显示离线鹰眼控件

专栏目录: OpenLayers入门教程汇总目录 前言 本章介绍如何使用OpenLayers加载离线xyz瓦片地图图层,并显示离线xyz瓦片的鹰眼控件。 本章是综合案例,涉及到两块内容,一个是离线瓦片地图加载,二个是鹰眼控件,拆分的参考文章如下: OpenLayers入门,OpenLayers地图鹰眼控…

Java面试题(每天10题)-------连载(26)

目录 多线程篇 1、什么是FutureTask? 2、什么是同步容器和并发容器的实现? 3、什么是多线程的上下文切换? 4、ThreadLocal的设计理念与作用? 5、ThreadPool(线程池)用法与优势? 6、Concur…

智能文件改名:高效复制并删除冗余,简化文件管理“

在繁杂的电脑文件世界中,如何高效地管理文件成为了许多人的难题。为了解决这一难题,我们推出了一款智能文件改名工具,它能够轻松复制文件并删除目标文件夹中的冗余文件,让您的文件管理更加高效便捷。 第一步,我们要打…

【网络协议】聊聊HTTPDNS如何工作的

传统 DNS 存在哪些问题? 域名缓存问题 我们知道CND会进行域名解析,但是由于本地会进行缓存对应的域名-ip地址,所以可能出现过期数据的情况。 域名转发问题 出口 NAT 问题 域名更新问题 解析延迟问题 因为在解析DNS的时候,需要进行…

Flink--Data Source 介绍

Data Source 简介 Flink 做为一款流式计算框架,它可用来做批处理,即处理静态的数据集、历史的数据集;也可以用来做流处理,即实时的处理些实时数据流,实时的产生数据流结果,只要数据源源不断的过来&#xff…

Spring的缓存机制-循环依赖

群公告 Java每日大厂面试题: 1、Spring 是如何解决循环依赖? 答案:三级缓存,简单来说,A创建过程中需要B,于是A将自己放到三级缓存里面,去实例化B,B实例化的时候发现需要…

【AICFD案例教程】进气歧管分析

AICFD是由天洑软件自主研发的通用智能热流体仿真软件,用于高效解决能源动力、船舶海洋、电子设备和车辆运载等领域复杂的流动和传热问题。软件涵盖了从建模、仿真到结果处理完整仿真分析流程,帮助工业企业建立设计、仿真和优化相结合的一体化流程&#x…

CSS时间线样式

css实现时间线样式,效果如下图: 一、CSS代码 .timeline {padding-left: 5px} .timeline-item { position: relative;padding-bottom: 20px;} .timeline-axis {position: absolute;left: -5px;top: 0;z-index: 10;width: 20px;height: 20px;line-he…