YOLO obb全流程

内容:xanylabeling 数据标注工具;pytorch(python);yolo-obb 模型

一、数据集

1、数据集工具xanylabeling的安装

(详细配置与使用方法参考:X-Anylabeling自动标注软件安装使用教程含conda环境搭建安装(详细教程)_x-anylabeling安装教程-CSDN博客)

2、xanylabeling的使用

        直接图像中标注矩形框(在选中该矩形框后,使用 Z\X\C\V四个按钮设置旋转角度,围绕中心点)。

        打开图片目录并对图片操作保存后,在同目录下出现同名的json文件,需要使用json转化成yolo格式的txt文件。

3、数据集搭建

        保持图片名称与txt名称一致;整个数据集格式如下,images 存放图片文件,labels的存放txt文件(txt内容为:0 0.530889 0.337543 0.235365 0.369421 0.224447 0.224803 0.519971 0.192925,九个分别表示 类别 第一顶点x和y 第二顶点x和y 第三顶点x和y 第四顶点x和y);

二、模型

1、下载yolo-obb

        ultralytics/ultralytics: Ultralytics YOLO11 🚀

2、修改配置文件

        cfg/datasets文件夹下设置yaml文件:(有些时候相对路径出现问题,所以本图直接使用绝对路径)

        cfg/models文件夹下找到对应版本,修改obb的yaml文件,修改nc(也可能不用修改,程序运行中自动获取datasets的yaml设置类别),若是需要修改模型结构,也在这个文件

3、终端运行

        下载代码后,出现两个层级的ultralytics文件夹,第一层表示代码根目录,包括docker等部署方法的内容,第二层ultralytics才是我们所需要的代码部分。

        将数据集放在第一层ultralytics内后,打开终端运行,直接使用 yolo obb train data=cfg\datasets\dota8.yaml model=cfg\models\v8\yolov8s-obb.yaml epochs=100 batch=4 device=0 imgsz=640 这句命令运行,data表示数据文件,model表示加载模型;

        提示:1、运行过程中,出现提示未查询到数据集,自动下载dota8,表示数据文件没有配置好。

                   2、出现下载yolo11n.pt表示正常运行,需要借助yolo11n进行提高能力

三、模型使用

from ultralytics import YOLO
# 加载训练好的模型
model = YOLO("cfg/models/v8/yolov8s-obb.yaml")
model.load("best.pt")   # 训练后的模型
# 图片地址
image_path = r"..\dataset\test\images"
import os
for img in os.listdir(image_path):# if img == "35.jpg":img_path = os.path.join(image_path,img)# 在图片上运行推理results = model(image_path) # 推理图片r =  results[0]image = r.plot()  # 获取图片以及检测框等数据# 显示结果图像cv2.imshow('Image with Center Points', image )cv2.waitKey(0)cv2.destroyAllWindows()

四、修改模型(可选)

        混合注意力为例

1、修改cfg/models/v8/yolov8-obb.yaml

         在对应版本的yaml的文件中添加模型(也可以删除或变更)

        添加混合注意力层次(-1表示上层输出通道作为本层输入通道、1表示模块数量、CBAM表示模块名称、[1024,7]表示输入的参数,1024是本层输出通道数,7表示卷积核尺寸)。

2、修改nn文件夹

        第一处:一般而言在尝试用的模块(CBAM\DWCONV等卷积)在nn/conv.py内存在,若在nn文件夹下的其他py均未定义需要使用模块,那么需要自行添加(根据其功能在对应文件内添加,或者直接新增文件)

        添加模块:使用class定义该模块,参数与yaml文件保持一致

        第二处:在 nn/_init_.py 文件内添加模块导入导出(图表示该模块位于conv文件,以此类推)

        第三处:在/task.py 文件中 from ultralytics.nn.modules import ()内添加CBAM模块;在base_modules = frozenset()中添加CBAM;在解析模块中添加参数解析,将第一个参数设置为输出通道和输入通道,第二个参数设置为卷积核大小

 

 3、直接训练

        修改完毕后继续训练,在训练输出内容中查看模型结构是否得到应用。

        验证方式与之前一致。

五、具体任务(参考)

        此次任务是需要进行检测旋转框,并根据检测内容分析物体方向(我使用yoloobb+opencv的方法),任务目标:检测巧克力的位置并判断其视觉方向。

        思路

索取yolo检测结果

results = model(image_path) # 推理图片
r =  results[0]
image = r.plot()  # 获取图片以及检测框等数据
boxes = r.obb.xyxyxyxy.detach() # 获取四个角点坐标

 根据检测的四个点进行判断长度和中心点

points = []
for box in boxes:zreo_point = box[0]first_point = box[1]second_point = box[2]three_point = box[3]loss0 = torch.sum(torch.pow(zreo_point-first_point, 2))loss1 = torch.sum(torch.pow(first_point-second_point, 2))if loss0<=loss1:points.append((zreo_point+first_point)/2)points.append((second_point+three_point)/2)else:   points.append((first_point+second_point)/2)points.append((zreo_point+three_point)/2)

画出宽中心点连接中心的两个线段,根据图像色彩判断方向(二值化的腐蚀与膨胀),保留目标线段并画在检测结果的图像数据中。

def get_line_colors(image, point1, point2):# 获取坐标和中心点x1, y1 = point1x2, y2 = point2x_center = int((x1 + x2)/2)y_center = int((y1 + y2)/2)center_point = (x_center, y_center)# 缩放点位 让宽中心点向物体中心点靠近vector1 = np.array(point1) - np.array(center_point)vector2 = np.array(point2) - np.array(center_point)# 缩放向量scaled_vector1 = vector1 * 3/4 scaled_vector2 = vector2 * 3/4# 计算新的终点坐标point1 = np.array(center_point) + scaled_vector1point2 = np.array(center_point) + scaled_vector2# 获取新的坐标和中心点x1, y1 = point1x2, y2 = point2x_center = int((x1 + x2)/2)y_center = int((y1 + y2)/2)# 使用 OpenCV 的 line iterator 来获取沿线的像素,判断像素的变化得到视觉的方向line1_iterator = cv2.line(np.zeros_like(image), (int(x1), int(y1)), (int(x_center), int(y_center)), color=1, thickness=2)line2_iterator = cv2.line(np.zeros_like(image), (int(x2), int(y2)), (int(x_center), int(y_center)), color=1, thickness=2)colors1 = []colors2 = []# 图片处理mage = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)h, s, v = cv2.split(mage)# 增强S通道(饱和度)s_enhanced = cv2.equalizeHist(s)enhanced_image = cv2.merge((h, s_enhanced, v))BGR_image = cv2.cvtColor(enhanced_image, cv2.COLOR_HSV2BGR)GRAY_image = cv2.cvtColor(BGR_image, cv2.COLOR_BGR2GRAY)_, bin_image = cv2.threshold(GRAY_image, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)kernel = np.ones((7, 7), np.uint8)  # 创建一个5x5的方形结构元素bin_image = cv2.dilate(bin_image, kernel, iterations=5)bin_image = cv2.erode(bin_image, kernel, iterations=9)for point in zip(*np.where(line1_iterator)):colors1.append(bin_image[point[0], point[1]])for point in zip(*np.where(line2_iterator)):colors2.append(bin_image[point[0], point[1]])sum_color1 = np.average(colors1)sum_color2 = np.average(colors2) if sum_color1>sum_color2:cv2.arrowedLine(image, (int(x_center), int(y_center)), (int(x1), int(y1)), (255, 0, 0), 2)  # 标记连接点的线段else:cv2.arrowedLine(image, (int(x_center), int(y_center)), (int(x2), int(y2)), (255, 0, 0), 2)  # 标记连接点的线段return image

完整代码如下所示:

import numpy as np
import cv2 
import torch
from ultralytics import YOLO# 获取boxes的框信息
def point_list(boxes):points = []for box in boxes:zreo_point = box[0]first_point = box[1]second_point = box[2]three_point = box[3]loss0 = torch.sum(torch.pow(zreo_point-first_point, 2))loss1 = torch.sum(torch.pow(first_point-second_point, 2))if loss0<=loss1:points.append((zreo_point+first_point)/2)points.append((second_point+three_point)/2)else:   points.append((first_point+second_point)/2)points.append((zreo_point+three_point)/2)return points
# 根据两个线段判断正反情况
def get_line_colors(image, point1, point2):# 获取坐标和中心点x1, y1 = point1x2, y2 = point2x_center = int((x1 + x2)/2)y_center = int((y1 + y2)/2)center_point = (x_center, y_center)# 缩放点位vector1 = np.array(point1) - np.array(center_point)vector2 = np.array(point2) - np.array(center_point)# 缩放向量scaled_vector1 = vector1 * 3/4scaled_vector2 = vector2 * 3/4# 计算新的终点坐标point1 = np.array(center_point) + scaled_vector1point2 = np.array(center_point) + scaled_vector2# 获取新的坐标和中心点x1, y1 = point1x2, y2 = point2x_center = int((x1 + x2)/2)y_center = int((y1 + y2)/2)# 使用 OpenCV 的 line iterator 来获取沿线的像素line1_iterator = cv2.line(np.zeros_like(image), (int(x1), int(y1)), (int(x_center), int(y_center)), color=1, thickness=2)line2_iterator = cv2.line(np.zeros_like(image), (int(x2), int(y2)), (int(x_center), int(y_center)), color=1, thickness=2)colors1 = []colors2 = []# 图片处理mage = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)h, s, v = cv2.split(mage)# 增强S通道(饱和度)s_enhanced = cv2.equalizeHist(s)enhanced_image = cv2.merge((h, s_enhanced, v))BGR_image = cv2.cvtColor(enhanced_image, cv2.COLOR_HSV2BGR)GRAY_image = cv2.cvtColor(BGR_image, cv2.COLOR_BGR2GRAY)_, bin_image = cv2.threshold(GRAY_image, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)kernel = np.ones((7, 7), np.uint8)  # 创建一个5x5的方形结构元素bin_image = cv2.dilate(bin_image, kernel, iterations=5)bin_image = cv2.erode(bin_image, kernel, iterations=9)for point in zip(*np.where(line1_iterator)):colors1.append(bin_image[point[0], point[1]])for point in zip(*np.where(line2_iterator)):colors2.append(bin_image[point[0], point[1]])sum_color1 = np.average(colors1)sum_color2 = np.average(colors2) if sum_color1>sum_color2:cv2.arrowedLine(image, (int(x_center), int(y_center)), (int(x1), int(y1)), (255, 0, 0), 2)  # 标记连接点的线段else:cv2.arrowedLine(image, (int(x_center), int(y_center)), (int(x2), int(y2)), (255, 0, 0), 2)  # 标记连接点的线段return imagedef predict(model,image_path):# 在图片上运行推理results = model(image_path) # 推理图片r =  results[0]image = r.plot()  # 获取图片以及检测框等数据boxes = r.obb.xyxyxyxy.detach() # 获取四个角点坐标points =  point_list(boxes) # 获取检测框宽的两个中心点# print(point_list(boxes)) # 遍历每个检测框for i in range(0,len(points)-1, 2):# 将坐标转换为整数point1 = tuple(map(int, points[i])) point2 = tuple(map(int, points[i+1]))image = get_line_colors(image, point1, point2) # 根据两个端点获取线段数据并判断正反new_width = int(image.shape[1] * 0.5)new_height = int(image.shape[0] * 0.5)resized_image =  cv2.resize(image, (new_width, new_height))return resized_image# 加载训练好的模型
model = YOLO("./CBAM300best.pt")
# model.load("../CBAM300best.pt")  
# 图片地址
image_path = r"..\dataset\test\images"
import os
for img in os.listdir(image_path):# if img == "35.jpg":img_path = os.path.join(image_path,img)resized_image = predict(model,img_path)image_name = img.split(".")image_name[0] = image_name[0]+"-1"new_image_path = ".".join(image_name)save_path = os.path.join(image_path,new_image_path)print(save_path)cv2.imwrite(save_path,resized_image)

  效果展示:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/35776.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于大语言模型与知识图谱的智能论文生成工具开发构想

基于大语言模型与知识图谱的智能论文生成工具开发构想 一、研究背景与意义 1.1 学术写作现状分析 #mermaid-svg-FNVHG5EiEgVSCpHK {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-FNVHG5EiEgVSCpHK .error-icon{fil…

学c++的人可以几天速通python?

学了俩天啊&#xff0c;文章写纸上了 还是蛮有趣的

【计算机网络】一二章

一 二 非常棒的例子 相同的传播时延&#xff0c;带宽越大&#xff0c;该链路上所能容纳的比特数越多 相同的传播时延&#xff0c;带宽越大&#xff0c;该链路上所能容纳的比特数越多 往返时间&#xff08;Round-Trip Time&#xff0c;RTT&#xff09;s是指从发送端发送数据分组…

使用Flask和OpenCV 实现树莓派与客户端的视频流传输与显示

使用 Python 和 OpenCV 实现树莓派与客户端的视频流传输与显示 在计算机视觉和物联网领域&#xff0c;经常需要将树莓派作为视频流服务器&#xff0c;通过网络将摄像头画面传输到客户端进行处理和显示。本文将详细介绍如何利用picamera2库、Flask 框架以及 OpenCV 库&#xff…

Kafka跨集群数据备份与同步:MirrorMaker运用

#作者&#xff1a;张桐瑞 文章目录 前言MirrorMaker是什么运行MirrorMaker各个参数的含义 前言 在大多数情况下&#xff0c;我们会部署一套Kafka集群来支撑业务需求。但在某些特定场景下&#xff0c;可能需要同时运行多个Kafka集群。比如&#xff0c;为了实现灾难恢复&#x…

ECharts仪表盘-仪表盘12,附视频讲解与代码下载

引言&#xff1a; ECharts仪表盘&#xff08;Gauge Chart&#xff09;是一种类似于速度表的数据可视化图表类型&#xff0c;用于展示单个或多个变量的指标和状态&#xff0c;特别适用于展示指标的实时变化和状态。本文将详细介绍如何使用ECharts库实现一个仪表盘&#xff0c;…

Harmony OS【 Tabs 导航篇】

设计图&#xff1a; 代码层&#xff1a; Entry Component struct Index {build() {Tabs({ barPosition: BarPosition.End }) {}.scrollable(false).vertical(false).divider({strokeWidth: 0.5,color: #0d182431}).backgroundColor(#F1f3f5).padding({ top: 36, bottom: 28 }…

兆芯大道云行 | 破解高性能云计算数据存储瓶颈

随着数字化转型的加速和数据安全战略的提升&#xff0c;以及国家政策的驱动&#xff0c;政府、金融、能源等关键领域对数据存储的自主可控要求不断提高&#xff0c;传统依赖国外芯片和技术的集中式存储架构面临安全与扩展性瓶颈。例如&#xff0c;政务云场景中原有的非信创服务…

RSI 量化策略实战指南:基于 iTick 报价源的 Python 实现

一、策略原理 相对强弱指标&#xff08;Relative Strength Index, RSI&#xff09;是由 Welles Wilder 提出的经典技术指标&#xff0c;通过计算价格波动的幅度衡量市场超买 / 超卖状态。RSI 取值范围 0-100&#xff0c;常用判断标准&#xff1a; RSI > 70&#xff1a;超买…

12 File文件对象:创建、获取基本信息、遍历文件夹、查找文件;字符集的编解码 (黑马Java视频笔记)

文章目录 File >> 存储数据的方案1. 认识File2. File操作2.1 创建File对象2.2 File操作1&#xff09;对文件对象的信息的操作2&#xff09;文件/文件夹的创建/删除3&#xff09;⭐⭐对文件夹的遍历 3. 方法递归3.1 认识递归3.2 递归算法及其执行流程1) 案例&#xff1a;2…

逻辑派G1 6层高速板学习

逻辑派G1 6层高速板学习 一、原理图分析二、电源分析三、网表导入四、板框导入五、PCB快捷键导入与设置六、模块抓取以及接口器件布局七、模块化布局--预布局&#xff08;先放各模块中的大器件&#xff09;1 HDMI模块布局2 MCU模块布局3 FPGA模块布局4 DDR3模块布局5 DCDC电源模…

图论——广度优先搜索实现

99. 岛屿数量 题目描述 给定一个由 1(陆地)和 0(水)组成的矩阵,你需要计算岛屿的数量。岛屿由水平方向或垂直方向上相邻的陆地连接而成,并且四周都是水域。你可以假设矩阵外均被水包围。 输入描述 第一行包含两个整数 N, M,表示矩阵的行数和列数。 后续 N 行,每行…

PTS-G1K13M RF Generator 1kW / 13MHz User’s Manual 手侧

PTS-G1K13M RF Generator 1kW / 13MHz User’s Manual 手侧

应用分层简介

一、什么是应用分层 应用分层是一种软件开发设计思想&#xff0c;它将应用程序分为多个层次&#xff0c;每个层次各司其职&#xff0c;多个层次之间协同提供完整的功能&#xff0c;根据项目的复杂程度&#xff0c;将项目分为三层或者更多层。 常见的MCV设计模式&#xff0c;就…

conda的基本使用及pycharm里设置conda环境

创建conda环境 conda create --name your_env_name python3.8 把your_env_name换成实际的conda环境名称&#xff0c;python后边的根据自己的需要&#xff0c;选择python的版本。 激活conda环境 conda activate your_env_name 安装相关的包、库 conda install package_name …

E902基于bash与VCS的仿真环境建立

网上看见很多E902仿真的文章&#xff0c;但用到的编译器是类似于这种Xuantie-900-gcc-elf-newlib-x86_64-V3.0.1-20241120&#xff0c;而我按照相应的步骤与对应的编译器&#xff0c;仿真总会报错。后面将编译器换成riscv64-elf-x86_64-20210512&#xff0c;反而成功了。现在开…

PostgreSQL:简介与安装部署

&#x1f9d1; 博主简介&#xff1a;CSDN博客专家&#xff0c;历代文学网&#xff08;PC端可以访问&#xff1a;https://literature.sinhy.com/#/?__c1000&#xff0c;移动端可微信小程序搜索“历代文学”&#xff09;总架构师&#xff0c;15年工作经验&#xff0c;精通Java编…

Git使用和原理(3)

1.远程操作 1.1分布式版本控制系统 我们⽬前所说的所有内容&#xff08;⼯作区&#xff0c;暂存区&#xff0c;版本库等等&#xff09;&#xff0c;都是在本地&#xff01;也就是在你的笔记本或者 计算机上。⽽我们的 Git 其实是分布式版本控制系统&#xff01;什么意思呢&a…

ssm框架之mybatis框架讲解

1&#xff0c;Mybatis 1.1 Mybatis概述 1.1.1 Mybatis概念 MyBatis 是一款优秀的持久层框架&#xff0c;用于简化 JDBC 开发 MyBatis 本是 Apache 的一个开源项目iBatis, 2010年这个项目由apache software foundation 迁移到了google code&#xff0c;并且改名为MyBatis 。2…

方法之笔,驭繁于简.绘场景之魂——方法论引领支撑透明化项目之航

关注作者 项目建设中痛难点剖析&#xff1a; 01 项目策划有缺失&#xff0c;目标风险难管控 ①目标设定不合理&#xff0c;由于项目移交交底不充分&#xff0c;造成项目建设目标与前期立项论证偏差过大&#xff0c;达不到建设预期&#xff1b; ②风险评估不足&#xff0c;未…