RT-DETR+Flask实现目标检测推理案例

今天,带大家利用RT-DETR(我们可以换成任意一个模型)+Flask来实现一个目标检测平台小案例,其实现效果如下:

目标检测案例

这个案例很简单,就是让我们上传一张图像,随后选择一下置信度,即可检测出图像中的目标,那么具体该如何实现呢?

RT-DETR模型推理

在先前的学习过程中,博主对RT-DETR进行来了简要的介绍,作为百度提出的实时性目标检测模型,其无论是速度还是精度均取得了较为理想的效果,今天则主要介绍一下RT-DETR的推理过程,与先前使用DETR中使用pth权重与网络结构相结合的推理方式不同,RT-DETR中使用的是onnx这种权重文件,因此,我们需要先对onnx文件进行一个简单了解:
在这里插入图片描述

ONNX模型文件

import onnx
# 加载模型
model = onnx.load('onnx_model.onnx')
# 检查模型格式是否完整及正确
onnx.checker.check_model(model)
# 获取输出层,包含层名称、维度信息
output = self.model.graph.output
print(output)

在原本的DETR类目标检测算法中,推理是采用权重文件与模型结构代码相结合的方式,而在RT-DETR中,则采用onnx模型文件来进行推理,即只需要该模型文件即可。

首先是将pth文件与模型结构进行匹配,从而导出onnx模型文件

"""by lyuwenyu
"""import os 
import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))import argparse
import numpy as np from src.core import YAMLConfigimport torch
import torch.nn as nn def main(args, ):"""main"""cfg = YAMLConfig(args.config, resume=args.resume)if args.resume:checkpoint = torch.load(args.resume, map_location='cpu') if 'ema' in checkpoint:state = checkpoint['ema']['module']else:state = checkpoint['model']else:raise AttributeError('only support resume to load model.state_dict by now.')# NOTE load train mode state -> convert to deploy modecfg.model.load_state_dict(state)class Model(nn.Module):def __init__(self, ) -> None:super().__init__()self.model = cfg.model.deploy()self.postprocessor = cfg.postprocessor.deploy()print(self.postprocessor.deploy_mode)def forward(self, images, orig_target_sizes):outputs = self.model(images)return self.postprocessor(outputs, orig_target_sizes)model = Model()dynamic_axes = {'images': {0: 'N', },'orig_target_sizes': {0: 'N'}}data = torch.rand(1, 3, 640, 640)size = torch.tensor([[640, 640]])torch.onnx.export(model, (data, size), args.file_name,input_names=['images', 'orig_target_sizes'],output_names=['labels', 'boxes', 'scores'],dynamic_axes=dynamic_axes,opset_version=16, verbose=False)if args.check:import onnxonnx_model = onnx.load(args.file_name)onnx.checker.check_model(onnx_model)print('Check export onnx model done...')if args.simplify:import onnxsimdynamic = True input_shapes = {'images': data.shape, 'orig_target_sizes': size.shape} if dynamic else Noneonnx_model_simplify, check = onnxsim.simplify(args.file_name, input_shapes=input_shapes, dynamic_input_shape=dynamic)onnx.save(onnx_model_simplify, args.file_name)print(f'Simplify onnx model {check}...')
if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--config', '-c',  default="D:\graduate\programs\RT-DETR-main\RT-DETR-main//rtdetr_pytorch\configs/rtdetr/rtdetr_r18vd_6x_coco.yml",type=str, )parser.add_argument('--resume', '-r', default="D:\graduate\programs\RT-DETR-main\RT-DETR-main/rtdetr_pytorch/tools\output/rtdetr_r18vd_6x_coco\checkpoint0024.pth",type=str, )parser.add_argument('--file-name', '-f', type=str, default='model.onnx')parser.add_argument('--check',  action='store_true', default=False,)parser.add_argument('--simplify',  action='store_true', default=False,)args = parser.parse_args()main(args)

随后,便是利用onnx模型文件进行目标检测推理过程了
onnx也有自己的一套流程:

onnx前向InferenceSession的使用

关于onnx的前向推理,onnx使用了onnxruntime计算引擎。
onnx runtime是一个用于onnx模型的推理引擎。微软联合Facebook等在2017年搞了个深度学习以及机器学习模型的格式标准–ONNX,顺路提供了一个专门用于ONNX模型推理的引擎(onnxruntime)。

import onnxruntime
# 创建一个InferenceSession的实例,并将模型的地址传递给该实例
sess = onnxruntime.InferenceSession('onnxmodel.onnx')
# 调用实例sess的润方法进行推理
outputs = sess.run(output_layers_name, {input_layers_name: x})

推理详细代码

推理代码如下:

import torch
import onnxruntime as ort
from PIL import Image, ImageDraw
from torchvision.transforms import ToTensorif __name__ == "__main__":##################classes = ['car','truck',"bus"]################### print(onnx.helper.printable_graph(mm.graph))#############img_path = "1.jpg"#############im = Image.open(img_path).convert('RGB')im = im.resize((640, 640))im_data = ToTensor()(im)[None]print(im_data.shape)size = torch.tensor([[640, 640]])sess = ort.InferenceSession("model.onnx")import timestart = time.time()output = sess.run(output_names=['labels', 'boxes', 'scores'],#output_names=None,input_feed={'images': im_data.data.numpy(), "orig_target_sizes": size.data.numpy()})end = time.time()fps = 1.0 / (end - start)print(fps)# print(type(output))# print([out.shape for out in output])labels, boxes, scores = outputdraw = ImageDraw.Draw(im)thrh = 0.6for i in range(im_data.shape[0]):scr = scores[i]lab = labels[i][scr > thrh]box = boxes[i][scr > thrh]print(i, sum(scr > thrh))#print(lab)print(f'box:{box}')for l, b in zip(lab, box):draw.rectangle(list(b), outline='red',)print(l.item())draw.text((b[0], b[1] - 10), text=str(classes[l.item()]), fill='blue', )#############im.save('2.jpg')#############

前端代码

前端代码包含两部分,一个是上传页面,一个是显示页面

上传页面如下:

<!DOCTYPE html>
<html lang="en">
<head><meta charset="UTF-8"><meta name="viewport" content="initial-scale=1.0, maximum-scale=1.0, user-scalable=no" /><title></title><script src="http://www.jq22.com/jquery/jquery-1.10.2.js"></script><style>#addCommodityIndex {text-align: center;width: 300px;height: 340px;position: absolute;left: 50%;top: 50%;margin: -200px 0 0 -200px;border: solid #ccc 1px;padding: 35px;}#imghead {cursor: pointer;}.btn {width: 100%;height: 40px;text-align: center;}</style><link rel="stylesheet" href="../static/css/bootstrap.min.css"  crossorigin="anonymous">
</head><body><div id="addCommodityIndex"><h2>目标检测</h2><div class="form-group row"><form id="upload"  action="/upload" enctype="multipart/form-data" method="POST"><img src=""><div class="form-group row"><label>上传图像</label><input type="file" class="form-control"  name='file'></div><div class="form-group row"><label>选择置信度</label><select class="form-control" name="score" id="exampleFormControlSelect1"><option value="0.5">0.5</option><option value="0.6">0.6</option><option value="0.7">0.7</option><option value="0.8">0.8</option><option value="0.9">0.9</option></select></div><div class="form-group row"><div class="btn"><input type="submit" class="btn btn-success" value="提交图像" /></div></div></form></div></div></body>
</html>

显示页面:

<!DOCTYPE html>
<html lang="en"><head><meta charset="UTF-8"><meta name="viewport" content="initial-scale=1.0, maximum-scale=1.0, user-scalable=no" /><title></title><script src="http://www.jq22.com/jquery/jquery-1.10.2.js"></script><style>#addCommodityIndex {text-align: center;position: absolute;left: 40%;top: 50%;margin: -200px 0 0 -200px;border: solid #ccc 1px;}#imghead {cursor: pointer;}.result {width: 100%;height: 100%;text-align: center;}</style><link rel="stylesheet" href="../static/css/bootstrap.min.css"  crossorigin="anonymous">
</head><body><div id="addCommodityIndex">
<div class="card mb-3" style="max-width: 680px;"><div class="row no-gutters"><div class="col-md-5"><img src="../static/img/result.jpg" class="result"></div><div class="col-md-5"><div class="card-body"><h5 class="card-title">检测结果</h5><p class="card-text">目标数量:{{num}}</p><p class="card-text">检测速度:{{fps}}/</p><a  href="/home" class="btn btn-success">继续提交</a></div></div></div>
</div>
</div>
</body>
</html>

Flask框架代码:

# -*- coding: utf-8 -*-
from flask import Flask,request,render_template
import json
import os
import time
app = Flask(__name__)
import infer
@app.route('/home',methods=['GET'])
def home():return render_template('upload.html')@app.route('/upload',methods=['GET','POST'])
def upload():if request.method == 'POST':f = request.files['file'] #获取数据流rootPath = os.path.dirname(os.path.abspath(__file__)) #根目录路径#创建存储文件的文件夹,使用时间戳防止重名覆盖file_path = 'static/upload/' + str(int(time.time()))absolute_path = os.path.join(rootPath,file_path).replace('\\','/') #存储文件的绝对路径,window路径显示\\要转化/if not os.path.exists(absolute_path): #不存在改目录则会自动创建os.makedirs(absolute_path)save_file_name = os.path.join(absolute_path,f.filename).replace('\\','/') #文件存储路径(包含文件名)f.save(save_file_name)score=request.values.to_dict().get("score")num,fps=infer.inference(save_file_name,score)#return json.dumps({'code':200,'url':url_path},ensure_ascii=False)return render_template("show.html",num=num,fps=fps)app.run(port='5000',debug=True)

上述项目博主已经上传到github上

git init
git add README.md
git commit -m "first commit"
git branch -M main
git remote add origin https://github.com/pengxiang1998/rt-detr.git
git push -u origin main

项目地址

在使用onnx时,安装了onnxruntime后,出现了下面的错误:

ImportError: cannot import name 'create_and_register_allocator_v2' from 'onnxruntime.capi._pybind_state'

这是由于onnxruntime-gpu版本与CUDA、CuDNN版本不匹配导致的,可以查看下面的网址来查看匹配版本

https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html

在这里插入图片描述
随后又出现错误:

> This ORT build has ['TensorrtExecutionProvider',
> 'CUDAExecutionProvider', 'CPUExecutionProvider'] enabled. Since ORT
> 1.9, you are required to explicitly set the providers parameter when instantiating InferenceSession. For example,
> onnxruntime.InferenceSession(...,
> providers=['TensorrtExecutionProvider',

这是由于InferenceSession中没有提供对应的provider,修改代码如下:

if torch.cuda.is_available():print("GPU")sess = ort.InferenceSession("model.onnx", None, providers=["CUDAExecutionProvider"])else:print("CPU")sess= ort.InferenceSession("model.onnx", None)

随后运行,发现安装了onnxruntime-gpu后的速度竟然满了下来,fps仅为0.2,而原本使用onnxruntime的fps则为7左右,这到底是怎么回事呢?

在这里插入图片描述

YOLO集成推理

而在YOLO集成的RT-DETR项目中,训练得到的权重 文件为.pt,在推理时需要与RT-DETR搭配使用,从而实现推理过程:
需要注意的是,由于YOLO里面集成了多种模型,因此为了具有适配性,其代码都具有通用性

from ultralytics.models import RTDETR
if __name__ == '__main__':model=RTDETR("weights/best.pt")model.predict(source="images/1.mp4",save=True,conf=0.6)

随后执行predict,代码如下:

def predict(self,source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,stream: bool = False,predictor=None,**kwargs,) -> list:if source is None:source = ASSETSLOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")is_cli = (ARGV[0].endswith("yolo") or ARGV[0].endswith("ultralytics")) and any(x in ARGV for x in ("predict", "track", "mode=predict", "mode=track"))custom = {"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"}  # method defaultsargs = {**self.overrides, **custom, **kwargs}  # highest priority args on the rightprompts = args.pop("prompts", None)  # for SAM-type modelsif not self.predictor:self.predictor = predictor or self._smart_load("predictor")(overrides=args, _callbacks=self.callbacks)self.predictor.setup_model(model=self.model, verbose=is_cli)else:  # only update args if predictor is already setupself.predictor.args = get_cfg(self.predictor.args, args)if "project" in args or "name" in args:self.predictor.save_dir = get_save_dir(self.predictor.args)if prompts and hasattr(self.predictor, "set_prompts"):  # for SAM-type modelsself.predictor.set_prompts(prompts)return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)

这部分代码在功能上具有复用性,因此在理解上存在一定难度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/379159.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ARM体系结构和接口技术(六)KEY按键实验① 按键轮询检测

文章目录 一、按键轮询&#xff08;一&#xff09;分析按键的电路连接1. 按键原理图2. 按键消抖 二、分析芯片手册&#xff08;一&#xff09; GPIO章节&#xff08;二&#xff09;RCC章节 三、代码&#xff08;一&#xff09;key.c&#xff08;二&#xff09;key.h 一、按键轮…

Python 魔法方法小结

目录 引言 &#x1f31f; 实例一&#xff1a;__init__构造方法 &#x1f31f; 实例二&#xff1a;__str__和__repr__方法 &#x1f31f; 实例三&#xff1a;__add__运算符重载 &#x1f31f; 实例四&#xff1a;__len__方法 &#x1f31f; 实例五&#xff1a;__getitem__…

从人工巡检到智能防控:智慧油气田安全生产的新视角

一、背景需求 随着科技的飞速发展&#xff0c;视频监控技术已成为各行各业保障安全生产、提升管理效率的重要手段。特别是在油气田这一特殊领域&#xff0c;由于其工作环境复杂、安全风险高&#xff0c;传统的监控方式已难以满足实际需求。因此&#xff0c;基于视频监控AI智能…

C#绘制阻抗圆图初步

阻抗圆图&#xff0c;或者叫史密斯图&#xff0c;是无线电设计方面用的&#xff1b; 基本的阻抗圆图如下&#xff0c; 下面尝试用C#能不能画一下&#xff1b; 先在网上找一个画坐标的C#类&#xff0c;它的效果如下&#xff1b; 自己再增加一个函数&#xff0c;可以绘制中心在…

【嵌入式Linux】<总览> 网络编程(更新中)

文章目录 前言 一、网络知识概述 1. 网路结构分层 2. socket 3. IP地址 4. 端口号 5. 字节序 二、网络编程常用API 1. socket函数 2. bind函数 3. listen函数 4. accept函数 5. connect函数 6. read和recv函数 7. write和send函数 三、TCP编程 1. TCP介绍 2.…

Android-- 集成谷歌地图

引言 项目需求需要在谷歌地图&#xff1a; 地图展示&#xff0c;设备点聚合&#xff0c;设备站点&#xff0c;绘制点和区域等功能。 我只针对我涉及到的技术做一下总结&#xff0c;希望能帮到开始接触谷歌地图的伙伴们。 集成步骤 1、在项目的modle的build.gradle中添加依赖如…

WSL-Ubuntu20.04部署环境配置

1.更换Ubuntu软件仓库镜像源 为了在WSL上使用TensorRT进行推理加速&#xff0c;需要安装以下环境&#xff0c;下面将按以下顺序分别介绍安装、验证以及删除环境&#xff1a; #1.C环境配置 gcc、gdb、g #2.gpu环境 cuda、cudnn #3.Cmake环境 CMake #4.OpenCV环境 OpenCV #5.Ten…

在mybatis-plus中关于@insert注解自定义批处理sql导致其雪花算法失效而无法自动生成id的解决方法

受到这位作者的启发 > 原文在点这里 为了自己实现批量插入&#xff0c;我在mapper层使用insert注解写了一段自定义sql //自定义的批量插入方法 Insert("<script>" "insert into rpt_material_hour(id,sample_time,rounding_time,cur_month,machine_no…

Web3时代的教育技术革新:智能合约在学习管理中的应用

随着区块链技术的发展和普及&#xff0c;Web3时代正在为教育技术带来前所未有的革新和机遇。智能合约作为区块链技术的核心应用之一&#xff0c;不仅在金融和供应链管理等领域展示了其巨大的潜力&#xff0c;也在教育领域中逐渐探索和应用。本文将探讨智能合约在学习管理中的具…

分词任务介绍-(十)

分词任务 中文分词正向最大匹配实现方式一实现方式二 反向最大匹配双向最大匹配jieba分词上述分词方法的缺点总结基于机器学习 总结分词技术经验总结 中文分词 正向最大匹配 分词的步骤 1.收集整理一个词表&#xff0c;类似于字典。如下图 2.对于待分词的句子&#xff0c;或者…

总结单例模式的写法

一、单例模式的概念 1.1 单例模式的概念 单例模式&#xff08;Singleton Pattern&#xff09;是 Java 中最简单的设计模式之一。这种类型的设计模式属于创建型模式&#xff0c;它提供了一种创建对象的最佳方式。就是当前进程确保一个类全局只有一个实例。 1.2 单例模式的优…

2024 China Joy 前瞻 | 腾讯网易发新作,网易数智携游戏前沿科技、创新产品以及独家礼盒,精彩不断!

今年上半年&#xff0c;CES、MWC和AWE三大国际科技展轮番轰炸&#xff0c;吸引全球科技爱好者的高度关注&#xff0c;无论是新潮的科技产品&#xff0c;还是对人工智能的探索&#xff0c;每一项展出的技术和产品都引起了市场的热议。而到了下半年&#xff0c;一年一度的China J…

Kafka消息队列python开发环境搭建

目录 引言 Kafka 的核心概念和组件 Kafka 的主要特性 使用场景 申请云服务器 安装docker及docker-compose VSCODE配置 开发环境搭建 搭建Kafka的python编程环境 Kafka的python编程示例 引言 Apache Kafka 是一个分布式流处理平台&#xff0c;由 LinkedIn 开发并在 2…

Android View的绘制流程

1.不管是View的添加&#xff0c;还是调用View的刷新方法invalidate()或者requestLayout()&#xff0c;绘制都是从ViewRootImpl的scheduleTraversals()方法开始 void scheduleTraversals() {if (!mTraversalScheduled) {mTraversalScheduled true;mTraversalBarrier mHandler…

SpringCloud教程 | 第九篇: 使用API Gateway

1、参考资料 SpringCloud基础篇-10-服务网关-Gateway_springcloud gateway-CSDN博客 2、先学习路由&#xff0c;参考了5.1 2.1、建了一个cloudGatewayDemo&#xff0c;这是用来配置网关的工程&#xff0c;配置如下&#xff1a; http://localhost:18080/aaa/name 该接口代码如…

科普文:详解23种设计模式

概叙 设计模式是对大家实际工作中写的各种代码进行高层次抽象的总结&#xff0c;其中最出名的当属 Gang of Four&#xff08;GoF&#xff09;的分类了&#xff0c;他们将设计模式分类为 23 种经典的模式&#xff0c;根据用途我们又可以分为三大类&#xff0c;分别为创建型模式…

等保-Linux等保测评

等保-Linux等保测评 1.查看相应文件&#xff0c;账户xiaoming的密码设定多久过期 rootdengbap:~# chage -l xiaoming Last password change : password must be changed Password expires : pass…

理解类与对象:面向对象基础

目录 1. 类的定义1.1 格式1.2 访问限定符1.3 类域 2.实例化2.1 实例化概念2.2 对象大小 3.this指针 1. 类的定义 1.1 格式 class为定义类的关键字&#xff0c;Date为类的名字&#xff0c;{ }中为类的主体&#xff0c;注意类定义结束后面的分号不能省略。类体中内容称为类的成…

【博士每天一篇文献-算法】连续学习算法之HNet:Continual learning with hypernetworks

阅读时间&#xff1a;2023-12-26 1 介绍 年份&#xff1a;2019 作者&#xff1a;Johannes von Oswald&#xff0c;Google Research&#xff1b;Christian Henning&#xff0c;EthonAI AG&#xff1b;Benjamin F. Grewe&#xff0c;苏黎世联邦理工学院神经信息学研究所 期刊&a…

如何在项目中打印sql和执行的时间

目标&#xff1a;打印DAO方法中sql和执行的时间 一种方式是去实现Mybatis的拦截器Interceptor &#xff0c;比较麻烦&#xff1b; 这里介绍一种比较简单的实现方式&#xff1b; 1、如何打印sql&#xff1f; 配置文件加这个可以打印出com.zhenhui.ids.busi.watch包下执行的sq…