DETR实现目标检测(一)-训练自己的数据集

1、DETR架构

DETR(Detection Transformer)是一种新型的目标检测模型,由Facebook AI Research (FAIR) 在2020年提出。DETR的核心思想是将目标检测任务视为一个直接的集合预测问题,而不是传统的两步或多步预测问题。这种方法的创新之处在于它直接预测目标的类别和边界框,而不是先生成大量的候选区域,然后再对这些区域进行分类和边界框回归。

DERT的特点主要有二:

一是Transformer结构在CV网络中的应用;

二是提出了一种新的或者说不同的损失函数(Loss Function)。

2、模型下载

模型代码下载地址:

GitHub - facebookresearch/detr: End-to-End Object Detection with Transformers

预训练模型(即权重文件)下载地址:

GitHub - facebookresearch/detr: End-to-End Object Detection with Transformers

下载后放到项目下待使用:

3、labelme标注文件转为coco模式

首先,labelme标注的文件存放在指定位置,包含json和jpg文件

然后,利用代码将labelme的标注文件转化为coco。包含annotations(两个json文件)、train2017(训练集图片)、val2017(验证集图片)

备注:必须严格按照笔者图中的文件命名方式进行命名,训练集清一色命名为instances_train2017.json,验证集清一色命名为instances_val2017.json,这是模型本身的命名要求,用户需要严格遵守。

实现代码如下:

import json
from labelme import utils
import numpy as np
import glob
import PIL.Imageclass MyEncoder(json.JSONEncoder):def default(self, obj):if isinstance(obj, np.integer):return int(obj)elif isinstance(obj, np.floating):return float(obj)elif isinstance(obj, np.ndarray):return obj.tolist()else:return super(MyEncoder, self).default(obj)class labelme2coco(object):def __init__(self, labelme_json=[], save_json_path='./tran.json'):self.labelme_json = labelme_jsonself.save_json_path = save_json_pathself.images = []self.categories = []self.annotations = []# self.data_coco = {}self.label = []self.annID = 1self.height = 0self.width = 0self.save_json()def data_transfer(self):for num, json_file in enumerate(self.labelme_json):with open(json_file, 'r') as fp:data = json.load(fp)  # 加载json文件self.images.append(self.image(data, num))for shapes in data['shapes']:label = shapes['label']if label not in self.label:self.categories.append(self.categorie(label))self.label.append(label)points = shapes['points']  # 这里的point是用rectangle标注得到的,只有两个点,需要转成四个点points.append([points[0][0], points[1][1]])points.append([points[1][0], points[0][1]])self.annotations.append(self.annotation(points, label, num))self.annID += 1def image(self, data, num):image = {}img = utils.img_b64_to_arr(data['imageData'])  # 解析原图片数据height, width = img.shape[:2]image['height'] = heightimage['width'] = widthimage['id'] = num + 1image['file_name'] = data['imagePath'].split('/')[-1]self.height = heightself.width = widthreturn imagedef categorie(self, label):categorie = {}categorie['supercategory'] = 'Cancer'categorie['id'] = len(self.label) + 1  # 0 默认为背景categorie['name'] = labelreturn categoriedef annotation(self, points, label, num):annotation = {}annotation['segmentation'] = [list(np.asarray(points).flatten())]annotation['iscrowd'] = 0annotation['image_id'] = num + 1annotation['bbox'] = list(map(float, self.getbbox(points)))annotation['area'] = annotation['bbox'][2] * annotation['bbox'][3]annotation['category_id'] = self.getcatid(label)  # 注意,源代码默认为1annotation['id'] = self.annIDreturn annotationdef getcatid(self, label):for categorie in self.categories:if label == categorie['name']:return categorie['id']return 1def getbbox(self, points):polygons = pointsmask = self.polygons_to_mask([self.height, self.width], polygons)return self.mask2box(mask)def mask2box(self, mask):"""从mask反算出其边框mask:[h,w]  0、1组成的图片1对应对象,只需计算1对应的行列号(左上角行列号,右下角行列号,就可以算出其边框)"""# np.where(mask==1)index = np.argwhere(mask == 1)rows = index[:, 0]clos = index[:, 1]# 解析左上角行列号left_top_r = np.min(rows)  # yleft_top_c = np.min(clos)  # x# 解析右下角行列号right_bottom_r = np.max(rows)right_bottom_c = np.max(clos)return [left_top_c, left_top_r, right_bottom_c - left_top_c,right_bottom_r - left_top_r]  # [x1,y1,w,h] 对应COCO的bbox格式def polygons_to_mask(self, img_shape, polygons):mask = np.zeros(img_shape, dtype=np.uint8)mask = PIL.Image.fromarray(mask)xy = list(map(tuple, polygons))PIL.ImageDraw.Draw(mask).polygon(xy=xy, outline=1, fill=1)mask = np.array(mask, dtype=bool)return maskdef data2coco(self):data_coco = {}data_coco['images'] = self.imagesdata_coco['categories'] = self.categoriesdata_coco['annotations'] = self.annotationsreturn data_cocodef save_json(self):self.data_transfer()self.data_coco = self.data2coco()# 保存json文件json.dump(self.data_coco, open(self.save_json_path, 'w'), indent=4, cls=MyEncoder)  # indent=4 更加美观显示if __name__ == '__main__':labelme_json = glob.glob('data/LabelmeData_frame_count/val2017/*.json')  # labelme标注好的.json文件存放目录labelme2coco(labelme_json, 'data/coco_frame_count/annotations/instances_val2017.json')  # 输出结果的存放目录

4、修改训练模型参数

先在pycharm中新建python脚本文件detr_r50_tf.py,代码如下:

import torchpretrained_weights = torch.load('detr-r50-e632da11.pth')num_class = 1  # 类别数
pretrained_weights["model"]["class_embed.weight"].resize_(num_class + 1, 256)
pretrained_weights["model"]["class_embed.bias"].resize_(num_class + 1)
torch.save(pretrained_weights, "detr-r50_%d.pth" % num_class)

将其中类别数改成自己数据集的类别数即可,执行完成后会在目录下生成适合自己数据集类别的预训练模型:

然后在models文件夹下打开detr.py,修改其中的类别数(一定要全部保持一致):

最后打开main.py,修改其中的coco_path(数据存放路径)、output_dir(结果输出路径)、device(没有cuda就改为cpu)、resume(自己生成的预训练模型)。

5、执行main.py来开始训练模型

如果不想跑太多了轮可以修改epochs数:

训练好的模型会保存在结果输出路径中:

跑起来的效果是这样的:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/348307.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

升级和维护老旧LabVIEW程序

在升级老旧LabVIEW程序至64位环境时,需要解决兼容性、性能和稳定性等问题。本文从软件升级、硬件兼容性、程序优化、故障修复等多个角度详细分析。具体包括64位迁移注意事项、修复页面跳转崩溃、解决关闭程序后残留进程的问题,确保程序在新环境中的平稳运…

RainBond 制作应用并上架【以ElasticSearch为例】

文章目录 安装 ElasticSearch 集群第 1 步:添加组件第 2 步:查看组件第 3 步:访问组件制作 ElasticSearch 组件准备工作ElasticSearch 集群原理尝试 Helm 安装 ES 集群RainBond 制作 ES 思路源代码Dockerfiledocker-entrypoint.shelasticsearch.yml制作组件第 1 步:添加组件…

服务架构的设计原则

墨菲定律与康威定律 在系统设计的时候,可以依据于墨菲定律 任何事情都没有表面上看起来那么简单所有的事情都会比你预计的时间长可能出错的事总会出错担心的某一个事情的发送,那么它就更有可能发生 在系统划分的时候,可以依据康威定律 系…

python3的基本语法说明一

一. 简介 本文开始学习 python3 的基本语法。 二. python3的基本语法 1. 编码 默认情况下,Python 3 源码文件以 UTF-8 编码,所有字符串都是 unicode 字符串。 当然你也可以为源码文件指定不同的编码: # -*- coding: cp-1252 -*- 上述…

录音转文字软件:一键让工作学习更高效

在职场这个大舞台上,每一场会议都是关键的演出,而会议记录就是这场演出的剧本。但剧本要整理得好,才能让演出更精彩,不是吗? 把那些长串的会议音频变成清晰的文字记录,听起来就像变魔术一样难。但不用担心…

MariaDB数据导入与导出操作演示

文章目录 整个数据库导出导入先删除库然后再导入 参考这里: MariaDB数据库导出导入. 整个数据库 该部分演示:导出数据库,然后重建数据库,并导入数据的整个过程。 导出 Win R ,打开运行输入cmd并回车,然…

Unity与Js通信交互

目录 1.Js给Unity传递消息 2.Unity给Js传递消息 简介: Unity 与 JavaScript 通信交互是指在 Unity 项目中实现与 JavaScript 代码进行数据交换和功能调用的过程。 在 Unity 中,可以通过特定的接口和技术来与外部的 JavaScript 环境进行连接。这使得 Unity 能够利…

【git使用四】git分支理解与操作(详解)

目录 (1)理解git分支 主分支(主线) 功能分支 主线和分支关系 将分支合并到主分支 快速合并 非快速合并 git代码管理流程 (2)理解git提交对象 提交对象与commitID Git如何保存数据 示例讲解 &a…

探索乡村振兴新模式:发挥科技创新在乡村振兴中的引领作用,构建智慧农业体系,助力美丽乡村建设

随着科技的不断进步,乡村振兴工作正迎来前所未有的发展机遇。科技创新作为推动社会发展的重要力量,在乡村振兴中发挥着越来越重要的引领作用。本文旨在探讨如何发挥科技创新在乡村振兴中的引领作用,通过构建智慧农业体系,助力美丽…

六轴机械臂电控介绍

六轴机械臂通常是10kg负载,常用的有埃夫特,库卡,发那科等,其组成部分及用法如下。 本体部分,由底座和六个依次连接的连杆组成,每个连杆由一个电机带动,从底座到末端的轴依次是,轴一带…

超越 Transformer开启高效开放语言模型的新篇章

在人工智能快速发展的今天,对于高效且性能卓越的语言模型的追求,促使谷歌DeepMind团队开发出了RecurrentGemma这一突破性模型。这款新型模型在论文《RecurrentGemma:超越Transformers的高效开放语言模型》中得到了详细介绍,它通过…

SFTP共享配置

SFTP一般指SSH文件传输协议,在计算机领域,SSH文件传输协议(英语:SSH File Transfer Protocol,也称Secret File Transfer Protocol,中文:安全文件传送协议,英文:Secure FT…

HTTP3版本和实现验证

HTTP3协议基于Google的 QUIC 协议,由互联网工程任务组(IETF)来制定。目录还是草案,已经进行到第33版。 HTTP3 是基于 QUIC 协议的 http。传输层是UDPQUIC,应用层仍是HTTP,即request/respose, request里也仍…

【培训】企业档案管理专题(私货)

导读:通过该专题培训,可以系统了解企业档案管理是什么、为什么、怎么做。尤其是对档案的价值认知,如何构建与新质生产力发展相适应的企业档案工作体系将有力支撑企业新质生产力的发展,为企业高质量发展贡献档案力量,提…

网络安全 - ARP 欺骗原理+实验

APR 欺骗 什么是 APR 为什么要用 APR A P R \color{cyan}{APR} APR(Address Resolution Protocol)即地址解析协议,负责将某个 IP 地址解析成对应的 MAC 地址。 在网络通信过程中会使用到这两种地址,逻辑 IP 地址和物理 MAC 地址&…

外网如何访问公司内网服务器?

在现代商业环境中,随着信息技术的快速发展,越来越多的公司有需求让远程用户在外网环境下访问公司内网服务器。这在很大程度上提高了远程办公的灵活性和效率。由于安全和网络限制等问题,实现这一目标并不是一件容易的事情。 在处理这个问题时…

《软件定义安全》之一:SDN和NFV:下一代网络的变革

第1章 SDN和NFV:下一代网络的变革 1.什么是SDN和NFV 1.1 SDN/NFV的体系结构 SDN SDN的体系结构可以分为3层: 基础设施层由经过资源抽象的网络设备组成,仅实现网络转发等数据平面的功能,不包含或仅包含有限的控制平面的功能。…

系统思考与创新解决

刚刚完成了为期两天的《系统思考与创新解决》课程,专门面向前端销售管理者。在这两天里,我们深入讨论了众多与公司当前状况密切相关的议题。通过绘制系统环路图,我们一起探索了包括客户满意度、交付周期、市场份额、研发投入、产能利用率、营…

Kubernetes 集群架构

etcd 集群状态存储:etcd 存储所有 Kubernetes 对象的状态,例如部署、pod、服务、配置映射和机密。配置管理:集群配置的更改存储在 etcd 中,允许 Kubernetes 管理和维护集群的所需状态。 注意:etcd 可能位于 kube-syst…

【ARM Coresight Debug 系列 -- ARMv8/v9 Watchpoint 软件实现地址监控详细介绍】

请阅读【嵌入式开发学习必备专栏 】 文章目录 ARMv8/v9 Watchpoint exceptionsWatchpoint 配置信息读取Execution conditionsWatchpoint data address comparisonsSize of the data accessWatchpoint 软件配置流程Watchpoint Type 使用介绍WT, Bit [20]: Watchpoint TypeLBN, B…