使用onnxruntime加载YOLOv8生成的onnx文件进行目标检测

      在网上下载了60多幅包含西瓜和冬瓜的图像组成melon数据集,使用 LabelMe  工具进行标注,然后使用 labelme2yolov8 脚本将json文件转换成YOLOv8支持的.txt文件,并自动生成YOLOv8支持的目录结构,包括melon.yaml文件,其内容如下:

path: ../datasets/melon # dataset root dir
train: images/train # train images (relative to 'path')
val: images/val  # val images (relative to 'path')
test: # test images (optional)# Classes
names:0: watermelon1: wintermelon

      使用以下python脚本进行训练生成onnx文件

import argparse
import colorama
from ultralytics import YOLOdef parse_args():parser = argparse.ArgumentParser(description="YOLOv8 train")parser.add_argument("--yaml", required=True, type=str, help="yaml file")parser.add_argument("--epochs", required=True, type=int, help="number of training")parser.add_argument("--task", required=True, type=str, choices=["detect", "segment"], help="specify what kind of task")args = parser.parse_args()return argsdef train(task, yaml, epochs):if task == "detect":model = YOLO("yolov8n.pt") # load a pretrained modelelif task == "segment":model = YOLO("yolov8n-seg.pt") # load a pretrained modelelse:print(colorama.Fore.RED + "Error: unsupported task:", task)raiseresults = model.train(data=yaml, epochs=epochs, imgsz=640) # train the modelmetrics = model.val() # It'll automatically evaluate the data you trained, no arguments needed, dataset and settings rememberedmodel.export(format="onnx") #, dynamic=True) # export the model, cannot specify dynamic=True, opencv does not support# model.export(format="onnx", opset=12, simplify=True, dynamic=False, imgsz=640)model.export(format="torchscript") # libtorchif __name__ == "__main__":colorama.init()args = parse_args()train(args.task, args.yaml, args.epochs)print(colorama.Fore.GREEN + "====== execution completed ======")

      以下是使用onnxruntime接口加载onnx文件进行目标检测的实现代码:

namespace {constexpr bool cuda_enabled{ false };
constexpr int image_size[2]{ 640, 640 }; // {height,width}, input shape (1, 3, 640, 640) BCHW and output shape(s) (1, 6, 8400)
constexpr float model_score_threshold{ 0.45 }; // confidence threshold
constexpr float model_nms_threshold{ 0.50 }; // iou threshold#ifdef _MSC_VER
constexpr char* onnx_file{ "../../../data/best.onnx" };
constexpr char* torchscript_file{ "../../../data/best.torchscript" };
constexpr char* images_dir{ "../../../data/images/predict" };
constexpr char* result_dir{ "../../../data/result" };
constexpr char* classes_file{ "../../../data/images/labels.txt" };
#else
constexpr char* onnx_file{ "data/best.onnx" };
constexpr char* torchscript_file{ "data/best.torchscript" };
constexpr char* images_dir{ "data/images/predict" };
constexpr char* result_dir{ "data/result" };
constexpr char* classes_file{ "data/images/labels.txt" };
#endifstd::vector<std::string> parse_classes_file(const char* name)
{std::vector<std::string> classes;std::ifstream file(name);if (!file.is_open()) {std::cerr << "Error: fail to open classes file: " << name << std::endl;return classes;}std::string line;while (std::getline(file, line)) {auto pos = line.find_first_of(" ");classes.emplace_back(line.substr(0, pos));}file.close();return classes;
}auto get_dir_images(const char* name)
{std::map<std::string, std::string> images; // image name, image path + image namefor (auto const& dir_entry : std::filesystem::directory_iterator(name)) {if (dir_entry.is_regular_file())images[dir_entry.path().filename().string()] = dir_entry.path().string();}return images;
}void draw_boxes(const std::vector<std::string>& classes, const std::vector<int>& ids, const std::vector<float>& confidences,const std::vector<cv::Rect>& boxes, const std::string& name, cv::Mat& frame)
{if (ids.size() != confidences.size() || ids.size() != boxes.size() || confidences.size() != boxes.size()) {std::cerr << "Error: their lengths are inconsistent: " << ids.size() << ", " << confidences.size() << ", " << boxes.size() << std::endl;return;}std::cout << "image name: " << name << ", number of detections: " << ids.size() << std::endl;std::random_device rd;std::mt19937 gen(rd());std::uniform_int_distribution<int> dis(100, 255);for (auto i = 0; i < ids.size(); ++i) {auto color = cv::Scalar(dis(gen), dis(gen), dis(gen));cv::rectangle(frame, boxes[i], color, 2);std::string class_string = classes[ids[i]] + ' ' + std::to_string(confidences[i]).substr(0, 4);cv::Size text_size = cv::getTextSize(class_string, cv::FONT_HERSHEY_DUPLEX, 1, 2, 0);cv::Rect text_box(boxes[i].x, boxes[i].y - 40, text_size.width + 10, text_size.height + 20);cv::rectangle(frame, text_box, color, cv::FILLED);cv::putText(frame, class_string, cv::Point(boxes[i].x + 5, boxes[i].y - 10), cv::FONT_HERSHEY_DUPLEX, 1, cv::Scalar(0, 0, 0), 2, 0);}//cv::imshow("Inference", frame);//cv::waitKey(-1);std::string path(result_dir);path += "/" + name;cv::imwrite(path, frame);
}std::wstring ctow(const char* str)
{constexpr size_t len{ 128 };wchar_t wch[len];swprintf(wch, len, L"%hs", str);return std::wstring(wch);
}float image_preprocess(const cv::Mat& src, cv::Mat& dst)
{cv::cvtColor(src, dst, cv::COLOR_BGR2RGB);float resize_scales{ 1. };if (src.cols >= src.rows) {resize_scales = src.cols * 1.f / image_size[1];cv::resize(dst, dst, cv::Size(image_size[1], static_cast<int>(src.rows / resize_scales)));} else {resize_scales = src.rows * 1.f / image_size[0];cv::resize(dst, dst, cv::Size(static_cast<int>(src.cols / resize_scales), image_size[0]));}cv::Mat tmp = cv::Mat::zeros(image_size[0], image_size[1], CV_8UC3);dst.copyTo(tmp(cv::Rect(0, 0, dst.cols, dst.rows)));dst = tmp;return resize_scales;
}template<typename T>
void image_to_blob(const cv::Mat& src, T* blob)
{for (auto c = 0; c < 3; ++c) {for (auto h = 0; h < src.rows; ++h) {for (auto w = 0; w < src.cols; ++w) {blob[c * src.rows * src.cols + h * src.cols + w] = (src.at<cv::Vec3b>(h, w)[c]) / 255.f;}}}
}void post_process(const float* data, int rows, int stride, float xfactor, float yfactor, const std::vector<std::string>& classes,cv::Mat& frame, const std::string& name)
{std::vector<int> class_ids;std::vector<float> confidences;std::vector<cv::Rect> boxes;for (auto i = 0; i < rows; ++i) {const float* classes_scores = data + 4;cv::Mat scores(1, classes.size(), CV_32FC1, (float*)classes_scores);cv::Point class_id;double max_class_score;cv::minMaxLoc(scores, 0, &max_class_score, 0, &class_id);if (max_class_score > model_score_threshold) {confidences.push_back(max_class_score);class_ids.push_back(class_id.x);float x = data[0];float y = data[1];float w = data[2];float h = data[3];int left = int((x - 0.5 * w) * xfactor);int top = int((y - 0.5 * h) * yfactor);int width = int(w * xfactor);int height = int(h * yfactor);boxes.push_back(cv::Rect(left, top, width, height));}data += stride;}std::vector<int> nms_result;cv::dnn::NMSBoxes(boxes, confidences, model_score_threshold, model_nms_threshold, nms_result);std::vector<int> ids;std::vector<float> confs;std::vector<cv::Rect> rects;for (size_t i = 0; i < nms_result.size(); ++i) {ids.emplace_back(class_ids[nms_result[i]]);confs.emplace_back(confidences[nms_result[i]]);rects.emplace_back(boxes[nms_result[i]]);}draw_boxes(classes, ids, confs, rects, name, frame);
}} // namespaceint test_yolov8_detect_onnxruntime()
{// reference: ultralytics/examples/YOLOv8-ONNXRuntime-CPPtry {Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "Yolo");Ort::SessionOptions session_option;if (cuda_enabled) {OrtCUDAProviderOptions cuda_option;cuda_option.device_id = 0;session_option.AppendExecutionProvider_CUDA(cuda_option);}session_option.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);session_option.SetIntraOpNumThreads(1);session_option.SetLogSeverityLevel(3);Ort::Session session(env, ctow(onnx_file).c_str(), session_option);Ort::AllocatorWithDefaultOptions allocator;std::vector<const char*> input_node_names, output_node_names;std::vector<std::string> input_node_names_, output_node_names_;for (auto i = 0; i < session.GetInputCount(); ++i) {Ort::AllocatedStringPtr input_node_name = session.GetInputNameAllocated(i, allocator);input_node_names_.emplace_back(input_node_name.get());}for (auto i = 0; i < session.GetOutputCount(); ++i) {Ort::AllocatedStringPtr output_node_name = session.GetOutputNameAllocated(i, allocator);output_node_names_.emplace_back(output_node_name.get());}for (auto i = 0; i < input_node_names_.size(); ++i)input_node_names.emplace_back(input_node_names_[i].c_str());for (auto i = 0; i < output_node_names_.size(); ++i)output_node_names.emplace_back(output_node_names_[i].c_str());Ort::RunOptions options(nullptr);std::unique_ptr<float[]> blob(new float[image_size[0] * image_size[1] * 3]);std::vector<int64_t> input_node_dims{ 1, 3, image_size[1], image_size[0] };auto classes = parse_classes_file(classes_file);if (classes.size() == 0) {std::cerr << "Error: fail to parse classes file: " << classes_file << std::endl;return -1;}for (const auto& [key, val] : get_dir_images(images_dir)) {cv::Mat frame = cv::imread(val, cv::IMREAD_COLOR);if (frame.empty()) {std::cerr << "Warning: unable to load image: " << val << std::endl;continue;}auto tstart = std::chrono::high_resolution_clock::now();cv::Mat rgb;auto resize_scales = image_preprocess(frame, rgb);image_to_blob(rgb, blob.get());Ort::Value input_tensor = Ort::Value::CreateTensor<float>(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU), blob.get(), 3 * image_size[1] * image_size[0], input_node_dims.data(), input_node_dims.size());auto output_tensors = session.Run(options, input_node_names.data(), &input_tensor, 1, output_node_names.data(), output_node_names.size());Ort::TypeInfo type_info = output_tensors.front().GetTypeInfo();auto tensor_info = type_info.GetTensorTypeAndShapeInfo();std::vector<int64_t> output_node_dims = tensor_info.GetShape();auto output = output_tensors.front().GetTensorMutableData<float>();int stride_num = output_node_dims[1];int signal_result_num = output_node_dims[2];cv::Mat raw_data = cv::Mat(stride_num, signal_result_num, CV_32F, output);raw_data = raw_data.t();float* data = (float*)raw_data.data;auto tend = std::chrono::high_resolution_clock::now();std::cout << "elapsed millisenconds: " << std::chrono::duration_cast<std::chrono::milliseconds>(tend - tstart).count() << " ms" << std::endl;post_process(data, signal_result_num, stride_num, resize_scales, resize_scales, classes, frame, key);}}catch (const std::exception& e) {std::cerr << "Error: " << e.what() << std::endl;return -1;}return 0;
}

      labels.txt文件内容如下:仅2类

watermelon 0
wintermelon 1

      说明

      1.这里使用的onnxruntime版本为1.18.0;

      2.windows下,onnxruntime库在debug和release为同一套库,在debug和release下均可执行;

      3.通过指定变量cuda_enabled判断走cpu还是gpu流程 ;

      4.windows下,onnxruntime中有些接口参数为wchar_t*,而linux下为char*,因此在windows下需要单独做转换,这里通过ctow函数实现从char*到wchar_t的转换;

      5.yolov8中提供的sample有问题,需要作调整。

      执行结果如下图所示:同样的预测图像集,与opencv dnn结果相似,它们具有相同的后处理流程;下面显示的耗时是在cpu下,gpu下仅20毫秒左右

      其中一幅图像的检测结果如下图所示:

      GitHub:https://github.com/fengbingchun/NN_Test

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/337828.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++的第一道门坎:类与对象(二)

目录 一.类中生成的默认成员函数详解 0.类的6个默认成员函数 1.构造函数 1.1概念 1.2特性 2.析构函数 2.1概念 2.2特性 3.拷贝构造函数 3.1概念 3.2特性 3.3拷贝构造的使用方法 4.运算符重载 5.赋值运算符重载 6.const修饰函数 7.取地址及const取地址操作符重载…

【漯河市人才交流中心_登录安全分析报告-Ajax泄漏滑动距离导致安全隐患】

前言 由于网站注册入口容易被黑客攻击&#xff0c;存在如下安全问题&#xff1a; 暴力破解密码&#xff0c;造成用户信息泄露短信盗刷的安全问题&#xff0c;影响业务及导致用户投诉带来经济损失&#xff0c;尤其是后付费客户&#xff0c;风险巨大&#xff0c;造成亏损无底洞…

Windows10(家庭版)中DockerDesktop(docker)的配置、安装、修改镜像源、使用

场景 Windows10中Docker的安装与遇到的那些坑: Windows10中Docker的安装与遇到的那些坑_在 docker.core.logging.httpclientexceptionintercept-CSDN博客 上面讲Docker Desktop在windows10非家庭版上的安装&#xff0c;如果是家庭版&#xff0c;则需要执行如下步骤。 注&am…

【python】OpenCV—Tracking(10.2)

文章目录 BackgroundSubtractorcreateBackgroundSubtractorMOG2createBackgroundSubtractorKNN BackgroundSubtractor Opencv 有三种背景分割器 K-Nearest&#xff1a;KNN Mixture of Gaussian&#xff08;MOG2&#xff09; Geometric Multigid&#xff08;GMG&#xff09; …

K210视觉识别模块学习笔记4: 训练与使用自己的模型_识别字母

今日开始学习K210视觉识别模块: 模型训练与使用_识别字母 亚博智能的K210视觉识别模块...... 固件库: maixpy_v0.6.2_52_gb1a1c5c5d_minimum_with_ide_support.bin 文章提供测试代码讲解、完整代码贴出、测试效果图、测试工程下载 这里也算是正式开始进入到视觉识别的领域了…

matplotlib ---词云图

词云图是一种直观的方式来展示文本数据&#xff0c;可以体现出一个文本中词频的使用情况&#xff0c;有利于文本分析&#xff0c;通过词频可以抓住一篇文章的重点 本文通过处理一篇关于分析影响洋流流向的文章&#xff0c;分析影响洋流流向的主要因素都有哪些 文本在文末结尾 …

基于Freertos的工训机器人

一. 工训机器人 V1 1. 实物 将自制的F4开发板放置车底板下方&#xff0c;节省上方空间&#xff0c;且能保证布线方便整齐。 2. SW仿真 使用SolidWorks进行仿真&#xff0c;且绘制3D打印件。 工训仿真 3.3D打印爪测试 机械爪测试 二. 工训机器人 V2 1. 实物 工训机器人V2不同于…

IDEA 打开项目后看不到项目结构怎么办?

1、先把这个项目从 IDEA 中移除 2、再重新打开或导入 3、如果还没有解决&#xff0c;就先把这个项目拷贝出来把原来的路径上的项目给删除&#xff0c;然后再把拷贝后的项目放在一个路径下&#xff0c;再打开就可以了

沟通程序化(1):跟着鬼谷子学沟通—“飞箝”之术

沟通的基础需要倾听&#xff0c;但如果对方听不进你的话&#xff0c;即便你说的再有道理&#xff0c;对方也很难入心。让我们看看鬼谷子的“飞箝”之术能给我们带来什么样的启发吧&#xff01; “飞箝”之术&#xff0c;源自中国古代兵法家、纵横家鼻祖鬼谷子的智慧&#xff0…

​LabVIEW超声波检测

LabVIEW超声波检测 在现代工业生产和科学研究中&#xff0c;超声检测技术因其无损性、高效率和可靠性而被广泛应用于材料和结构的缺陷检测。然而&#xff0c;传统的超声检测仪器往往依赖于操作者的经验和技能&#xff0c;其检测过程不够智能化&#xff0c;且检测结果的解读具有…

Appium系列(2)元素定位工具appium-inspector

背景 如实现移动端自动化&#xff0c;依赖任何工具时&#xff0c;都需要针对于页面中的元素进行识别&#xff0c;通过识别到指定的元素&#xff0c;对元素进行事件操作。 识别元素的工具为appium官网提供的appium-inspector。 appium-inspector下载地址 我这里是mac电脑需要下…

使用Python突破网站验证码限制

之前有小伙伴说&#xff0c;在web自动化的过程中&#xff0c;经常会被登录的验证码给卡住&#xff0c;不知道如何去通过验证码的验证&#xff0c;今天专门给大家来聊聊验证码的问题。 常见的验证码一般分为两类&#xff0c;一类是图文验证码&#xff0c;一类是滑块验证码&#…

vue2+antv/x6实现er图

效果图 安装依赖 npm install antv/x6 --save 我目前的项目安装的版本是antv/x6 2.18.1 人狠话不多&#xff0c;直接上代码 <template><div class"er-graph-container"><!-- 画布容器 --><div ref"graphContainerRef" id"gr…

SpringCloud如何实现SSO单点登录?

目录 一、SpringCloud框架介绍 二、什么是SSO单点登录 三、单点登录的必要性 四、SpringCloud如何实现SSO单点登录 一、SpringCloud框架介绍 Spring Cloud是一个基于Spring Boot的微服务架构开发工具集&#xff0c;它整合了多种微服务解决方案&#xff0c;如服务发现、配置…

es的总结

es的collapse es的collapse只能针对一个字段聚合&#xff08;针对大数据量去重&#xff09;&#xff0c;如果以age为聚合字段&#xff0c;则会展示第一条数据&#xff0c;如果需要展示多个字段&#xff0c;需要创建新的字段&#xff0c;如下 POST testleh/_update_by_query {…

C#WPF数字大屏项目实战07--当日产量

1、第2列布局 第2列分三行&#xff0c;第一行分6列 2、当日产量布局 3、产量数据布局 运行效果 4、计划产量和完成度 运行效果 5、良品率布局 1、添加用户控件 2、用户控件绘制圆 2、使用用户控件 3、运行效果 4、注意点 这三个数值目前是静态的&#xff0c;可以由后台程序项…

构建高效稳定的短视频直播系统架构

随着短视频直播的迅猛发展&#xff0c;构建一个高效稳定的短视频直播系统架构成为了互联网企业的重要挑战。本文将探讨如何构建高效稳定的短视频直播系统架构&#xff0c;以提供优质的用户体验和满足日益增长的用户需求。 ### 1. 短视频直播系统的背景 短视频直播近年来蓬勃发…

ARM-V9 RME(Realm Management Extension)系统架构之系统安全能力的信任根服务

安全之安全(security)博客目录导读 目录 一、信任根服务 1、非易失性存储 2、根看门狗 3、随机数生成器 4、加密服务 5、硬件强制安全性 本节定义了系统架构必须支持的一般安全属性和能力&#xff0c;以确保RME安全性。 本章扩展了可能属于系统认证配置文件的一部分的其…

k8s之PV、PVC

文章目录 k8s之PV、PVC一、存储卷1、存储卷定义2、存储卷的作用2.1 数据持久化2.2 数据共享2.3 解耦2.4 灵活性 3、存储卷的分类3.1 emptyDir存储卷3.1.1 定义3.1.2 特点3.1.3 用途3.1.4 示例 3.2 hostPath存储卷3.2.1 定义3.2.2 特点3.2.3 用途3.2.4 示例 3.3 NFS存储卷3.3.1 …

SQL数据库多表创建之一对多、多对多表创建

MySQL多表创建关联及操作_mysql创建关联表-CSDN博客文章浏览阅读1.1k次&#xff0c;点赞21次&#xff0c;收藏20次。表与表之间的关系表语表之间的关系&#xff0c;说的就是表与表数据之间的关系。_mysql创建关联表https://blog.csdn.net/2401_83641392/article/details/137031…