【DETR】训练自己的数据集以及YOLO数据集格式(txt)转化成COCO格式(json)

目录

  • 1.DETR介绍
  • 2.数据集处理
  • 3.转化结果可视化
  • 4.数据集训练
    • 4.1修改pth文件
    • 4.2类别参数修改
    • 4.3训练
  • 5.成功运行!
  • 6.参考文献

1.DETR介绍

DETR(Detection with TRansformers)是基于transformer的端对端目标检测,无NMS后处理步骤,无anchor。
代码链接:https://github.com/facebookresearch/detr
在这里插入图片描述

2.数据集处理

DETR需要的数据集格式为coco格式,这里我是用自己的YOLO格式数据集转化成COCO格式,然后进行训练的。
YOLO数据集的组织格式是:
其中images里面分别存放训练集train和验证集val的图片,labels存放训练集train和验证集val的txt标签。
在这里插入图片描述
要转化成适应DETR模型读取的COCO数据集的组织形式是:
其中train2017存放训练集的图片,val2017存放验证集的图片,
annotations文件夹里面存放train和val的json标签。

在这里插入图片描述
下面是转化代码:

  • 需要进行类别映射,每个类别对应的id分别存放在categories里面,这里我没有用classes.txt文件存放,相当于直接把classes.txt里面的类别写出来了。
  • 我的图片是png格式的,如果图片是jpg格式的,将png改成jpg即可。image_name = filename.replace(‘.txt’, ‘.jpg’)
  • 最后修改文件路径,改成自己的路径,这里最后会输出train和val的json文件,图片不会处理,按上述目录组织形式将图片组织起来即可。
  • 生成的文件夹记得改为instances_train2017.json这种样子
import os
import json
from PIL import Image# 定义类别映射
categories = [{"id": 0, "name": "Double hexagonal column"},{"id": 1, "name": "Flange nut"},{"id": 2, "name": "Hexagon nut"},{"id": 3, "name": "Hexagon pillar"},{"id": 4, "name": "Hexagon screw"},{"id": 5, "name": "Hexagonal steel column"},{"id": 6, "name": "Horizontal bubble"},{"id": 7, "name": "Keybar"},{"id": 8, "name": "Plastic cushion pillar"},{"id": 9, "name": "Rectangular nut"},{"id": 10, "name": "Round head screw"},{"id": 11, "name": "Spring washer"},{"id": 12, "name": "T-shaped screw"}
]def yolo_to_coco(yolo_images_dir, yolo_labels_dir, output_json_path):# 初始化 COCO 数据结构data = {"images": [],"annotations": [],"categories": categories}image_id = 1annotation_id = 1def get_image_size(image_path):with Image.open(image_path) as img:return img.width, img.height# 遍历标签目录for filename in os.listdir(yolo_labels_dir):if not filename.endswith('.txt'):continue  # 只处理 .txt 文件image_name = filename.replace('.txt', '.png')# 如果图片是jpg格式的,将png改成jpg即可。image_path = os.path.join(yolo_images_dir, image_name)if not os.path.exists(image_path):print(f"⚠️ 警告: 图像 {image_name} 不存在,跳过 {filename}")continueimage_width, image_height = get_image_size(image_path)image_info = {"id": image_id,"width": image_width,"height": image_height,"file_name": image_name}data["images"].append(image_info)with open(os.path.join(yolo_labels_dir, filename), 'r') as file:lines = file.readlines()for line in lines:parts = line.strip().split()if len(parts) != 5:print(f"⚠️ 警告: 标签 {filename} 格式错误: {line.strip()}")continuecategory_id = int(parts[0])x_center = float(parts[1]) * image_widthy_center = float(parts[2]) * image_heightbbox_width = float(parts[3]) * image_widthbbox_height = float(parts[4]) * image_heightx_min = int(x_center - bbox_width / 2)y_min = int(y_center - bbox_height / 2)bbox = [x_min, y_min, bbox_width, bbox_height]area = bbox_width * bbox_heightannotation_info = {"id": annotation_id,"image_id": image_id,"category_id": category_id,"bbox": bbox,"area": area,"iscrowd": 0}data["annotations"].append(annotation_info)annotation_id += 1image_id += 1os.makedirs(os.path.dirname(output_json_path), exist_ok=True)with open(output_json_path, 'w') as json_file:json.dump(data, json_file, indent=4)print(f"✅ 转换完成: {output_json_path}")# 输入路径 (YOLO 格式数据集)
yolo_base_dir = "/home/yu/Yolov8/ultralytics-main/mydata0"
yolo_train_images = os.path.join(yolo_base_dir, "images/train")
yolo_train_labels = os.path.join(yolo_base_dir, "labels/train")
yolo_val_images = os.path.join(yolo_base_dir, "images/val")
yolo_val_labels = os.path.join(yolo_base_dir, "labels/val")# 输出路径 (COCO 格式)
coco_base_dir = "/home/yu/Yolov8/ultralytics-main/mydata0_coco"
coco_train_json = os.path.join(coco_base_dir, "annotations/instances_train.json")
coco_val_json = os.path.join(coco_base_dir, "annotations/instances_val.json")# 运行转换
yolo_to_coco(yolo_train_images, yolo_train_labels, coco_train_json)
yolo_to_coco(yolo_val_images, yolo_val_labels, coco_val_json)

3.转化结果可视化

COCO数据集JSON文件格式分为以下几个字段。

{"info": info, # dict"licenses": [license], # list ,内部是dict"images": [image], # list ,内部是dict"annotations": [annotation], # list ,内部是dict"categories": # list ,内部是dict}

可以运行以下脚本查看转化后的标签是否与图片目标对应:

  • 修改代码的json_path和img_path,json_path是标签对应的路径,img_path是图像对应的路径
'''
该代码的功能是:读取图像以及对应bbox的信息
'''
import os
from pycocotools.coco import COCO
from PIL import Image, ImageDraw
import matplotlib.pyplot as pltjson_path = "/home/yu/Yolov8/ultralytics-main/mydata0_coco/annotations/instances_val.json"
img_path = ("/home/yu/Yolov8/ultralytics-main/mydata0_coco/images/val")# load coco data
coco = COCO(annotation_file=json_path)# get all image index info
ids = list(sorted(coco.imgs.keys()))
print("number of images: {}".format(len(ids)))# get all coco class labels
coco_classes = dict([(v["id"], v["name"]) for k, v in coco.cats.items()])# 遍历前三张图像
for img_id in ids[:3]:# 获取对应图像id的所有annotations idx信息ann_ids = coco.getAnnIds(imgIds=img_id)# 根据annotations idx信息获取所有标注信息targets = coco.loadAnns(ann_ids)# get image file namepath = coco.loadImgs(img_id)[0]['file_name']# read imageimg = Image.open(os.path.join(img_path, path)).convert('RGB')draw = ImageDraw.Draw(img)# draw box to imagefor target in targets:x, y, w, h = target["bbox"]x1, y1, x2, y2 = x, y, int(x + w), int(y + h)draw.rectangle((x1, y1, x2, y2))draw.text((x1, y1), coco_classes[target["category_id"]])# show imageplt.imshow(img)plt.show()

运行该代码,你将会看到你的标签是否对应:
如果目标没有边界框则说明你转化的json不对!
在这里插入图片描述
在这里插入图片描述

4.数据集训练

4.1修改pth文件

将它的pth文件改一下,因为他是用的coco数据集,而我们只需要训练自己的数据集,就是下图这个文件,这是它原本的
在这里插入图片描述
新建一个.py文件,运行下面代码,就会生成一个你数据集所需要的物体数目的pth,记得改类别数!。

import torch
pretrained_weights  = torch.load('detr-r50-e632da11.pth')num_class = 14 #这里是你的物体数+1,因为背景也算一个
pretrained_weights["model"]["class_embed.weight"].resize_(num_class+1, 256)
pretrained_weights["model"]["class_embed.bias"].resize_(num_class+1)
torch.save(pretrained_weights, "detr-r50_%d.pth"%num_class

这是我们生成的。
在这里插入图片描述

4.2类别参数修改

修改models/detr.py文件,build()函数中,可以将红框部分的代码都注释掉,直接设置num_classes为自己的类别数+1
因为我的类别数是13,所以我这里num_classes=14
在这里插入图片描述

4.3训练

修改main.py文件的epochs、lr、batch_size等训练参数:
以下这些参数都在get_args_parser()函数里面。
在这里插入图片描述

修改自己的数据集路径:
在这里插入图片描述
设置输出路径:
在这里插入图片描述

修改resume为自己的预训练权重文件路径
这里就是你刚才运行脚本生成的pth文件的路径:
在这里插入图片描述
运行main.py文件
或者可以通过命令行运行:

python main.py --dataset_file "coco" --coco_path "/home/yu/Yolov8/ultralytics-main/mydata0_coco" --epoch 300 --lr=1e-4 --batch_size=8 --num_workers=4 --output_dir="outputs" --resume="detr_r50_14.pth"

5.成功运行!

在这里插入图片描述

6.参考文献

1.【DETR】训练自己的数据集-实践笔记
2. yolo数据集格式(txt)转coco格式,方便mmyolo转标签格式
3. windows10复现DEtection TRansformers(DETR)并实现自己的数据集

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/38558.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HashMap学习总结——JDK17

文章目录 HashMap构造方法HashMap(int initialCapacity, float loadFactor)loadFactor 加载因子initialCapacity 初始容量tableSizeFor(int cap) 计算前导零 HashMap(Map<? extends K, ? extends V> m) put(K key, V value)hash(Object key) 求hash值putVal(int hash, …

Linux:进程信号

✨✨所属专栏&#xff1a;Linux✨✨ ✨✨作者主页&#xff1a;嶔某✨✨ Linux&#xff1a;进程信号 在讲信号之前&#xff0c;我们先来从生活中的事情来确定信号的一些特性。 我在网上买了商品&#xff0c;我在等快递。但是在快递没来之前我知道快递来的时候我应该怎么处理。…

c#知识点补充2

1.非静态类能否调用静态方法可以 2.对string类型扩展方法&#xff0c;如何进行 类用静态类&#xff0c;参数是this 调用如下 3.out的用法 一定要给a赋值 这种写法不行 这样才行 4.匿名类 5.委托的使用 无论是匿名委托&#xff0c;还是具命委托&#xff0c;委托实例化后一定要…

0322-数据库、前后端

前端 <!DOCTYPE html> <html> <head> <meta charset"UTF-8"> <title>Insert title here</title> <script srcjs/jquery-3.7.1.min.js></script> <script> //jquaryajax发起请求 //传参形式不同 post用data{}…

Spring Boot02(数据库、Redis)02---java八股

MySQL和Redis的区别&#xff1f; 1. 数据类型&#xff1a; MySQL是一种关系型数据库&#xff0c;表结构化存储&#xff0c;使用 SQL 查询。支持表、列、行等结构化数据。 Redis是一种基于内存的缓存系统&#xff0c;支持多种数据结构&#xff0c;如字符串、哈希表、列表、集合、…

VulnHub-Web-Machine-N7通关攻略

一、信息收集 第一步&#xff1a;确定靶机IP为192.168.0.107 第二步&#xff1a;扫描后台及开放端口 第三步&#xff1a;进行敏感目录及文件扫描 http://192.168.0.107/index.html (CODE:200|SIZE:1620) http://192.168.0.107/server-status (CODE:403|SIZ…

【AI News | 20250322】每日AI进展

AI Repos 1、DeTikZify 可以把草图或图形转换成TikZ代码的模型&#xff0c;可用来绘制复杂的科学图表&#xff0c;输入草图或文字描述即可转换成TikZ代码。DeTikZify强大的地方在于它能理解图表的语义信息&#xff0c; 能识别图表中的不同组成部分及其含义&#xff0c;比如坐标…

三主热备架构

1.要求 角色主机名软件IP地址用户client192.168.72.90keepalivedvip192.168.72.100masterserverAkeepalived, nginx192.168.72.30backupserverBkeepalived, nginx192.168.72.31backupserverCkeepalived, nginx192.168.72.32webtomcat1tomcat192.168.72.41webtomcat2tomcat192.1…

LiteratureReading:[2023] GPT-4: Technical Report

文章目录 一、文献简明&#xff08;zero&#xff09;二、快速预览&#xff08;first&#xff09;1、标题分析2、作者介绍3、引用数4、摘要分析&#xff08;1&#xff09;翻译&#xff08;2&#xff09;分析 5、总结分析&#xff08;1&#xff09;翻译&#xff08;2&#xff09;…

java使用Apache POI 操作word文档

项目背景&#xff1a; 当我们对一些word文档&#xff08;该文档包含很多的标题比如 1.1 &#xff0c;1.2 &#xff0c; 1.2.1.1&#xff0c; 1.2.2.3&#xff09;当我们删除其中一项或者几项时&#xff0c;需要手动的对后续的进行补充。该功能主要是对标题进行自动的补充。 具…

OpenHarmony 开源鸿蒙北向开发——linux使用make交叉编译第三方库

这几天搞鸿蒙&#xff0c;需要编译一些第三方库到鸿蒙系统使用。 头疼死了&#xff0c;搞了一个多星期总算搞定了。 开贴记坑。 一、SDK下载 1.下载 在linux下使用命令 wget https://cidownload.openharmony.cn/version/Master_Version/OpenHarmony_5.1.0.54/20250313_02…

SVN简明教程——下载安装使用

SVN教程目录 一、开发中的实际问题二、简介2.1 版本控制2.2 Subversion2.3 Subversion的优良特性2.4 工作原理2.5 SVN基本操作 三、Subversion的安装与配置1. 服务器端程序版本2. 下载源码包3. 下载二进制安装包4. 安装5. 配置版本库① 为什么要配置版本库&#xff1f;② 创建目…

OpenCV旋转估计(1)用于估计图像间仿射变换关系的类cv::detail::AffineBasedEstimator

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 基于仿射变换的估计器。 这种估计器使用匹配器估算的成对变换来为每个相机估算最终的变换。 cv::detail::AffineBasedEstimator 是 OpenCV 库中…

大数据学习栈记——HBase安装

本文介绍大数据技术中流行的非关系型数据库HBase的安装&#xff0c;操作系统&#xff1a;Ubuntu24.04 安装Zookeeper 安装HBase前需要先安装Zookeeper&#xff0c;HBase使用Zookeeper作为其分布式协同服务&#xff0c;存储了HBase集群的元数据信息&#xff0c;并提供了分布式…

SpringBoot+VUE(Ant Design Vue)实现图片下载预览功能

目录 背景 1.后端实现下载接口 2.前端请求实现 第一步&#xff1a;导入api 第二步&#xff1a;请求接口 3.前端展示实现 4.实现效果展示 5.总结 背景 这段时间通过SpringBootVUE(Ant Design Vue)框架做了一个项目&#xff0c;但是在图片下载&#xff0c;展示的时候在网…

Java 推送钉钉应用消息

前言&#xff1a; 本文的目的是通过手机号获取钉钉成员的userid&#xff0c;实现钉钉应用的消息推送。 一、创建钉钉应用 登录钉钉开放平台 二、应用相关凭证 需要获取 Client ID (原 AppKey 和 SuiteKey) Client Secret (原 AppSecret 和 SuiteSecret) App ID 原企业内部…

SpringCloud介绍

什么是SpringCloud&#xff1f; SpringCloud 是分布式微服务架构下的一站式解决方案&#xff0c;是各个微服务架构落地技术的集合体&#xff0c;俗称微服务全家桶。 官方介绍&#xff1a; SpringCloud是基于SpringBoot提供了一套微服务解决方案&#xff0c;包括服务注册与发现…

YOLOv11 目标检测

本文章不再赘述anaconda的下载以及虚拟环境的配置&#xff0c;博主使用的python版本为3.8 1.获取YOLOv11的源工程文件 链接&#xff1a;GitHub - ultralytics/ultralytics: Ultralytics YOLO11 &#x1f680; 直接下载解压 2.需要自己准备的文件 文件结构如下&#xff1a;红…

【Linux】——环境变量与进程地址空间

文章目录 环境变量环境变量的概念常见的环境变量PATH相关指令 main的三个参数前两个参数第三个参数 程序地址空间进程地址空间 环境变量 环境变量的概念 环境变量一般是指在操作系统中用来指定操作系统运行环境的一些参数&#xff0c;将来会以shell的形式传递给所有进程&…

Kafka--常见问题

1.为什么要使用 Kafka&#xff0c;起到什么作用 Kafka是一个高吞吐量、分布式、基于发布订阅的消息系统&#xff0c;它主要用于处理实时数据流 Kafka 设计上支持高吞吐量的消息传输&#xff0c;每秒可以处理数百万条消息。它能够在处理大量并发请求时&#xff0c;保持低延迟和…