yolov8obb角度预测原理解析

预测头

ultralytics/nn/modules/head.py

class OBB(Detect):"""YOLOv8 OBB detection head for detection with rotation models."""def __init__(self, nc=80, ne=1, ch=()):"""Initialize OBB with number of classes `nc` and layer channels `ch`."""super().__init__(nc, ch)self.ne = ne  # number of extra parametersc4 = max(ch[0] // 4, self.ne)self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)def forward(self, x):"""Concatenates and returns predicted bounding boxes and class probabilities."""bs = x[0].shape[0]  # batch sizeangle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2)  # OBB theta logits# NOTE: set `angle` as an attribute so that `decode_bboxes` could use it.angle = (angle.sigmoid() - 0.25) * math.pi  # [-pi/4, 3pi/4]# angle = angle.sigmoid() * math.pi / 2  # [0, pi/2]if not self.training:self.angle = anglex = Detect.forward(self, x)if self.training:return x, angle# return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))return torch.cat([x, angle], 1).permute(0, 2, 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))

forward 输入值
在这里插入图片描述

self.cv4网路结构

ModuleList((0): Sequential((0): Conv((conv): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): Conv((conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(2): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1)))(1): Sequential((0): Conv((conv): Conv2d(128, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): Conv((conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(2): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1)))(2): Sequential((0): Conv((conv): Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): Conv((conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(2): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1)))

angle维度14,1,8400

损失函数

pred_angle = pred_angle.permute(0, 2, 1).contiguous()
维度变为14 8400 1

将预测结果转为bboxes

pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)  # xyxy, (b, h*w, 4)

计算回归损失

loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask)

这里的bbox_loss指的是:

self.bbox_loss = RotatedBboxLoss(self.reg_max - 1, use_dfl=self.use_dfl).to(self.device)

接来下看RotatedBboxLoss

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):"""IoU loss."""weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum# DFL lossif self.use_dfl:target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.reg_max)loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weightloss_dfl = loss_dfl.sum() / target_scores_sumelse:loss_dfl = torch.tensor(0.0).to(pred_dist.device)return loss_iou, loss_dfl

两个旋转矩形如何计算IOU:

def probiou(obb1, obb2, CIoU=False, eps=1e-7):"""Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.Args:obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.obb2 (torch.Tensor): A tensor of shape (N, 5) representing predicted obbs, with xywhr format.eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.Returns:(torch.Tensor): A tensor of shape (N, ) representing obb similarities."""x1, y1 = obb1[..., :2].split(1, dim=-1)x2, y2 = obb2[..., :2].split(1, dim=-1)a1, b1, c1 = _get_covariance_matrix(obb1)a2, b2, c2 = _get_covariance_matrix(obb2)t1 = (((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.25t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5t3 = (((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))/ (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)+ eps).log() * 0.5bd = (t1 + t2 + t3).clamp(eps, 100.0)hd = (1.0 - (-bd).exp() + eps).sqrt()iou = 1 - hdif CIoU:  # only include the wh aspect ratio partw1, h1 = obb1[..., 2:4].split(1, dim=-1)w2, h2 = obb2[..., 2:4].split(1, dim=-1)v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)with torch.no_grad():alpha = v / (v - iou + (1 + eps))return iou - v * alpha  # CIoUreturn iou

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/366812.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

1.k8s:架构,组件,基础概念

目录 一、k8s了解 1.什么是k8s 2.为什么要k8s (1)部署方式演变 (2)k8s作用 (3)Mesos,Swarm,K8S三大平台对比 二、k8s架构、组件 1.k8s架构 2.k8s基础组件 3.k8s附加组件 …

深入理解ThreadLocal原理

以下内容首发于我的个人网站,来这里看更舒适:https://riun.xyz/work/9898775 ThreadLocal是一种用于实现线程局部变量的机制,它允许每个线程有自己独立的变量,从而达到了线程数据隔离的目的。 基于JDK8 使用 通常在项目中是这样…

JAVA连接FastGPT实现流式请求SSE效果

FastGPT 是一个基于 LLM 大语言模型的知识库问答系统,提供开箱即用的数据处理、模型调用等能力。同时可以通过 Flow 可视化进行工作流编排,从而实现复杂的问答场景! 一、先看效果 真正实流式请求,SSE效果,SSE解释&am…

全球首款商用,AI为视频自动配音配乐产品上线

近日,海外推出了一款名为Resona V2A的产品,这是全球首款商用视频转音频 (V2A) 技术产品。这项突破性技术利用AI,仅凭视频数据即可自动生成高质量、与上下文相关的音频,包括声音设计、音效、拟音和环境音,为电影制作人、…

Ubuntu20.04安装Prometheus监控系统

环境准备: 服务器名称内网IP公网IPPrometheus服务器192.168.0.23047.119.21.167Grafana服务器192.168.0.23147.119.22.8被监控服务器192.168.0.23247.119.22.82 更改主机名方便辨认 hostnamectl set-hostname prometheus hostnamectl set-hostname grafana hostn…

Springboot下使用Redis管道(pipeline)进行批量操作

之前有业务场景需要批量插入数据到Redis中,做的过程中也有一些感悟,因此记录下来,以防忘记。下面的内容会涉及到 分别使用for、管道处理批量操作,比较其所花费时间。 分别使用RedisCallback、SessionCallback进行Redis pipeline …

【BES2500x系列 -- RTX5操作系统】深入探索CMSIS-RTOS RTX -- 同步与通信篇 -- 消息队列和邮箱处理 --(四)

💌 所属专栏:【BES2500x系列】 😀 作  者:我是夜阑的狗🐶 🚀 个人简介:一个正在努力学技术的CV工程师,专注基础和实战分享 ,欢迎咨询! &#x1f49…

【Node-RED 4.0.2】4.0版本新增特性(官方版)

二、重要功能 *1.时间戳格式改进 过去,node-red 只提供了 最原始的 timestamp 的格式(1970-01-01 ~ now) 但是现在,额外增加了 2 种格式: ISO 8601 -A COMMON FORMAT(YYYY-MM-DDTHH:mm:ss:sssZ&#xff…

Linux环境安装配置nginx服务流程

Linux环境的Centos、麒麟、统信操作系统安装配置nginx服务流程操作: 1、官网下载 下载地址 或者通过命令下载 wget http://nginx.org/download/nginx-1.20.2.tar.gz 2、上传到指定的服务器并解压 tar -zxvf nginx-1.20.1.tar.gzcd nginx-1.20.1 3、编译并安装到…

阿里Nacos下载、安装(保姆篇)

文章目录 Nacos下载版本选择Nacos安装Windows常见问题解决 更多相关内容可查看 Nacos下载 Nacos官方下载地址:https://github.com/alibaba/nacos/releases 码云拉取(如果国外较慢或者拉取超时可以试一下国内地址) //国外 git clone https:…

[数据集][目标检测]桥梁检测数据集VOC+YOLO格式1116张1类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):1116 标注数量(xml文件个数):1116 标注数量(txt文件个数):1116 标注…

【RabbitMQ实战】Springboot 整合RabbitMQ组件,多种编码示例,带你实践 看完这一篇就够了

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、对RabbitMQ管理界面深入了解1、在这个界面里面我们可以做些什么? 二、编码练习(1)使用direct exchange(直连型交换机)&a…

snowflake 不再是个数据仓库公司了

标题先上结论,为啥这么认为,且听接下来道来。 snowflake 非常成功,开创了云数仓先河,至今在数仓架构上也是相对比较先进的,国内一堆模仿的公司,传统上我们会认为 snowflake 肯定是一家数据仓库公司。不过最…

3D Gaussian Splatting代码中的forward和backward两个文件代码解读

3dgs代码前向传播部分 先来讨论一下glm,因为定义变量的时候用到了这个。 glm的解释 glm 是指 OpenGL Mathematics,这是一个针对图形编程的数学库。它的全称是 OpenGL Mathematics (GLM),主要用于 OpenGL 的开发。这个库是基于 C 的模板库&…

什么是CC攻击,如何防止网站被CC攻击的方法

前言 “CC攻击的原理就是攻击者控制某些主机不停地发大量数据包给对方服务器造成服务器资源耗尽,一直到宕机崩溃。” 什么是CC攻击? CC攻击前身是一个名为Fatboy的攻击程序,而之所以后来人们会称之为CC,也叫HTTP-FLOOD&#xff…

【AI提升】如何使用大模型:本机离线和FastAPI服务调用

大模型本身提供的功能,类似于windows中的一个exe小工具,我们可以本机离线调用然后完成具体的功能,但是别的机器需要访问这个exe是不可行的。常见的做法就是用web容器封装起来,提供一个http接口,然后接口在后端调用这个…

lodash.js 工具库

lodash 是什么? Lodash是一个流行的JavaScript实用工具库,提供了许多高效、高兼容性的工具函数,能够方便地处理集合、字符串、数值、函数等多种数据类型,大大提高工作效率。 lodash官网 文档参见:Lodash Documentation lodash 在Vue中怎么使用? 1、首先安装 lodash np…

JDK动态代理-AOP编程

AOPTest.java,相当于main函数,经过代理工厂出来的Hello类对象就不一样了,这是Proxy.newProxyInstance返回的对象,会hello.addUser会替换为invoke函数,比如这里的hello.addUser("sun", "13434");会…

Python 作业题1 (猜数字)

题目 你要根据线索猜出一个三位数。游戏会根据你的猜测给出以下提示之一:如果你猜对一位数字但数字位置不对,则会提示“Pico”;如果你同时猜对了一位数字及其位置,则会提示“Fermi”;如果你猜测的数字及其位置都不对&…

无人机生态环境监测、图像处理与GIS数据分析综合实践技术应用

构建“天空地”一体化监测体系是新形势下生态、环境、水文、农业、林业、气象等资源环境领域的重大需求,无人机生态环境监测在一体化监测体系中扮演着极其重要的角色。通过无人机航空遥感技术可以实现对地表空间要素的立体观测,获取丰富多样的地理空间数…