【RKNN】YOLO V5中pytorch2onnx,pytorch和onnx模型输出不一致,精度降低

yolo v5训练的模型,转onnx,再转rknn后,测试发现:

  1. rknn模型,量化与非量化,相较于pytorch模型,测试精度都有降低
  2. onnx模型,相较于pytorch模型,测试精度也有降低,且与rknn模型的精度更接近

于是,根据这种测试情况,rknn模型的上游,就是onnx。onnx这里发现不对劲,肯定是这步就出现了问题。于是就查pytorch转onnx阶段,就存在转化的精度降低了。

本篇就是记录这样一个过程,也请各位针对本文的问题,给一些建议,毕竟目前是发现了问题,同时还存在一些问题在。

一、pytorch转onnx:torch.onnx.export

yolo v5 export.py: def export_onnx()中,添加下面代码,检查转储的onnx模型,与pytorch模型的输出结果是否一致。代码如下:

torch.onnx.export(model.cpu() if dynamic else model,  # --dynamic only compatible with cpuim.cpu() if dynamic else im,f,verbose=False,opset_version=opset,export_params=True, # 将训练好的权重保存到模型文件中do_constant_folding=True,  # 执行常数折叠进行优化input_names=['images'],output_names=output_names,dynamic_axes={"image": {0: "batch_size"},  # variable length axes"output": {0: "batch_size"},}
)# Checks
model_onnx = onnx.load(f)  # load onnx model
onnx.checker.check_model(model_onnx)  # check onnx modelimport onnxruntime
import numpy as np
print('onnxruntime run start', f)
sess = onnxruntime.InferenceSession('best.onnx')
print('sess run start')
output = sess.run(['output0'], {'images': im.detach().numpy()})[0]
print('pytorch model inference start')pytorch_result = model(im)[0].detach().numpy()
print(' allclose start')
print('output:', output)
print('pytorch_result:', pytorch_result)
assert np.allclose(output, pytorch_result), 'the output is different between pytorch and onnx !!!'

对其中的输出结果进行了打印,将差异性比较明显的地方进行了标记,如下所示:

在这里插入图片描述
也可以直接使用我下面这个版本,在转完onnx后,进行评测,转好的onnx和pt文件之间的差异性。如下:

参考pytorch官方:(OPTIONAL) EXPORTING A MODEL FROM PYTORCH TO ONNX AND RUNNING IT USING ONNX RUNTIME

import os
import platform
import sys
import warnings
from pathlib import Path
import torchFILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:sys.path.append(str(ROOT))  # add ROOT to PATH
if platform.system() != 'Windows':ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relativefrom models.experimental import attempt_load
from models.yolo import ClassificationModel, Detect, DetectionModel, SegmentationModel
from utils.dataloaders import LoadImages
from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save)
from utils.torch_utils import select_device, smart_inference_modeimport numpy as np
def cosine_distance(arr1, arr2):# flatten the arrays to shape (16128, 7)arr1_flat = arr1.reshape(-1, 7)arr2_flat = arr2.reshape(-1, 7)# calculate the cosine distancecosine_distance = np.dot(arr1_flat.T, arr2_flat) / (np.linalg.norm(arr1_flat) * np.linalg.norm(arr2_flat))return cosine_distance.mean()def check_onnx(model, im):import onnxruntimeimport numpy as npprint('onnxruntime run start')sess = onnxruntime.InferenceSession('best.onnx')print('sess run start')output = sess.run(['output0'], {'images': im.detach().numpy()})[0]print('pytorch model inference start')with torch.no_grad():pytorch_result = model(im)[0].detach().numpy()print(' allclose start')print('output:', output, output.shape)print('pytorch_result:', pytorch_result, pytorch_result.shape)cosine_dis = cosine_distance(output, pytorch_result)print('cosine_dis:', cosine_dis)# 判断小数点后几位(4),是否相等,不相等就报错# np.testing.assert_almost_equal(pytorch_result, output, decimal=4)# compare ONNX Runtime and PyTorch resultsnp.testing.assert_allclose(pytorch_result, output, rtol=1e-03, atol=1e-05)# assert np.allclose(output, pytorch_result), 'the output is different between pytorch and onnx !!!'import cv2
from utils.augmentations import letterbox
def preprocess(img, device):img = cv2.resize(img, (512, 512))img = img.transpose((2, 0, 1))[::-1]img = np.ascontiguousarray(img)img = torch.from_numpy(img).to(device)img = img.float()img /= 255if len(img.shape) == 3:img = img[None]return img
def main(weights=ROOT / 'weights/best.pt',  # weights pathimgsz=(512, 512),  # image (height, width)batch_size=1,  # batch sizedevice='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpuinplace=False,  # set YOLOv5 Detect() inplace=Truedynamic=False,  # ONNX/TF/TensorRT: dynamic axes):# Load PyTorch modeldevice = select_device(device)model = attempt_load(weights, device=device, inplace=True, fuse=True)  # load FP32 model# Checksimgsz *= 2 if len(imgsz) == 1 else 1  # expand# Inputgs = int(max(model.stride))  # grid size (max stride)imgsz = [check_img_size(x, gs) for x in imgsz]  # verify img_size are gs-multiplesim = torch.zeros(batch_size, 3, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection# im = cv2.imread(r'F:\tmp\yolov5_multiDR\data\0000005_20200929_M_063Y16640.jpeg')# im = preprocess(im, device)print(im.shape)# Update modelmodel.eval()for k, m in model.named_modules():if isinstance(m, Detect):m.inplace = inplacem.dynamic = dynamicm.export = Truewarnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)  # suppress TracerWarningcheck_onnx(model, im)if __name__ == "__main__":main()

测试1:图像是一个全0的数组,一致性检查如下:

Mismatched elements: 76 / 112896 (0.0673%)
Max absolute difference:  0.00053406
Max relative difference:      2.2101output: [[[     3.1054       3.965      8.9553 ...  6.8545e-07     0.36458     0.53113][     9.0205      2.5498       13.39 ...  6.2585e-07     0.18449     0.70698][     20.786      2.2233      13.489 ...  2.3842e-06    0.033101     0.95657]...[     419.42      493.04      106.14 ...  8.4937e-06     0.24135     0.60916][     485.68      500.22      46.923 ...  1.1176e-05     0.33573     0.48875][     488.37      503.87      68.881 ...  5.9605e-08  0.00030029     0.99639]]] (1, 16128, 7)
pytorch_result: [[[     3.1054       3.965      8.9553 ...  7.0523e-07     0.36458     0.53113][     9.0205      2.5498       13.39 ...  6.0181e-07     0.18449     0.70698][     20.786      2.2233      13.489 ...  2.4172e-06    0.033101     0.95657]...[     419.42      493.04      106.14 ...  8.5151e-06     0.24135     0.60916][     485.68      500.22      46.923 ...  1.1174e-05     0.33573     0.48875][     488.37      503.87      68.881 ...  9.3094e-08   0.0003003     0.99639]]] (1, 16128, 7)
cosine_dis: 0.04229331

测试2:图像是加载的本地图像,一致性检查如下:

Mismatched elements: 158 / 112896 (0.14%)
Max absolute difference:   0.0016251
Max relative difference:      1.2584output: [[[     3.0569      2.4338      10.758 ...  2.0862e-07     0.16333     0.78551][     11.028      2.0251      13.407 ...  3.5763e-07    0.090503     0.88087][     19.447      1.8957      13.431 ...  6.8545e-07    0.047358     0.95029]...[     418.66       487.8      80.157 ...  1.4573e-05     0.65453     0.23448][     472.99      491.78      79.313 ...  1.3232e-05     0.79356     0.15061][     496.41      488.49      44.447 ...  2.6256e-05     0.89966     0.08772]]] (1, 16128, 7)
pytorch_result: [[[     3.0569      2.4338      10.758 ...  2.5371e-07     0.16333     0.78551][     11.028      2.0251      13.407 ...  3.3069e-07    0.090503     0.88087][     19.447      1.8957      13.431 ...  6.6051e-07    0.047358     0.95029]...[     418.66       487.8      80.157 ...  1.4618e-05     0.65453     0.23448][     472.99      491.78      79.313 ...  1.3215e-05     0.79356     0.15061][     496.41      488.49      44.447 ...  2.6262e-05     0.89966     0.08772]]] (1, 16128, 7)
cosine_dis: 0.04071107

发现,输出结果中,差异的数据点还是挺多的,那么就说明在模型中,有些部分的参数是有差异的,这才导致相同的输入,在最后的输出结果中存在差异。

但是在一定的误差内,结果是一致的。比如我验证了小数点后3位,都是一样的,但是到第4位的时候,就开始出现了差异性。

那么,如何降低,甚至没有这种差异,该怎么办呢?不知道你们有没有这方面的知识储备或经验,欢迎评论区给出指导,感谢。

二、新的pytorch转onnx:torch.onnx.dynamo_export

在参考pytorch官方,关于torch.onnx.export的模型转换,相关文档中:(OPTIONAL) EXPORTING A MODEL FROM PYTORCH TO ONNX AND RUNNING IT USING ONNX RUNTIME

1
上述案例,是pytorch官方给出评测pytorch和onnx转出模型,在相同输入的情况下,输出结果一致性对比的评测代码。对比这里:

testing.assert_allclose(actual, desired, rtol=1e-07, atol=0, equal_nan=True, err_msg='', verbose=True)

其中:

  • rtol:相对tolerance(容忍度,公差,容许偏差)
  • atol:绝对tolerance
  • 要求 actualdesired 值的差别不超过 atol + rtol * abs(desired),否则弹出错误提示

可以看出,这是在误差允许的范围内,进行的评测。只要满足一定的误差要求,还是满足的。并且在本测试案例中,也确实通过了上述设定值的误差要求。

但是,峰回路转,有个提示,如下:
2
于是,就转到torch.onnx.dynamo_export链接,点击这里直达:EXPORT A PYTORCH MODEL TO ONNX

同样的流程,导出模型,然后进行一致性评价,发现官方竟然没有采用允许误差的评测,而是下面这样:
在这里插入图片描述输出完全一致,这是一个大好消息。至此,开始验证

2.1、验证结果

与此同时,发现yolo v5更新到了v7.0.0的版本,于是就想着把yolo 进行升级,同时将pytorch版本也更新到最新的2.1.0,这样就可以采用torch.onnx.dynamo_export 进行转onnx模型的操作尝试了。

当一起就绪后,采用下面的代码转出onnx模型的时候,却出现了错误提示。

export_output = torch.onnx.dynamo_export(model.cpu() if dynamic else model,im.cpu() if dynamic else im)
export_output.save("my_image_classifier.onnx")

2.2、转出失败

在这里插入图片描述

给出失败的的提示:torch.onnx.OnnxExporterError,转出onnx模型失败,产生了一个SARIF的文件。然后介绍了什么是SARIF文件,可以通过VS Code SARIF,也可以 SARIF web查看。最后说吧这个错误,报告给pytorchGitHubissue地方。

产生了一个名为:report_dynamo_export.sarif是文件,打开文件,记录的信息如下:

{"runs":[{"tool":{"driver":{"name":"torch.onnx.dynamo_export","contents":["localizedData","nonLocalizedData"],"language":"en-US","rules":[],"version":"2.1.0+cu118"}},"language":"en-US","newlineSequences":["\r\n","\n"],"results":[]}],"version":"2.1.0","schemaUri":"https://docs.oasis-open.org/sarif/sarif/v2.1.0/cs01/schemas/sarif-schema-2.1.0.json"
}

这更像是一个运行环境收集的一个记录文件。在我对全网进行搜索时候,发现了类似的报错提示,但并没有解决办法。不知道是不是因为这个函数还在内测阶段,并没有很好的适配。

如果你也遇到了同样的问题,欢迎给评论,指导问题出在了哪里?如何解决这个问题。感谢

三、总结

原本想着验证最终转rknn的模型,与原始pytorch模型是否一致的问题,最后发现在转onnx阶段,这种差异性就已经存在了。并且发现rknn的测试结果,与onnx模型的测试结果更加的贴近。无论是量化后的rknn,还是未量化的,均存在这个问题。

同时发现,量化后的rknn模型,在config阶段改变量化的方式,确实会提升模型的性能,且几乎接近于未量化的模型版本。

原本以为采用pytorch新的转出onnx的模型函数,可以解决这个问题。但是,发现还是内测版本,不知道问题是出在了哪里,还需要大神帮助,暂时未跑通。

最后,如果你也遇到了同样的问题,欢迎给评论,指导问题出在了哪里?如何解决这个问题。感谢

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/156914.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

缓存设计的创新之旅:架构的灵魂之一

缓存在架构设计中占有重要地位。缓存在提升性能中也扮演重要的角色。常见的有对资源的缓存,比如数据库连接池、http连接池,还有对数据的缓存等。缓存的设计可复杂也可简单,但是需要考虑的点却很多。 缓存对象 设计缓存的时候一定要考虑的是&…

大语言模型之十七-QA-LoRA

由于基座模型通常需要海量的数据和算力内存,这一巨大的成本往往只有巨头公司会投入,所以一些优秀的大语言模型要么是大公司开源的,要么是背后有大公司身影公司开源的,如何从优秀的开源基座模型针对特定场景fine-tune模型具有广大的…

香港专用服务器拥有良好的国际网络连接

香港服务器在多个领域有着广泛的应用。无论是电子商务、金融交易、游戏娱乐还是社交媒体等,香港服务器都能够提供高效稳定的服务。对于跨境电商来说,搭建香港服务器可以更好地满足亚洲用户的购物需求;对于金融机构来说,香港服务器…

当涉及到API接口数据分析时,主要可以从以下几个方面展开

当涉及到API接口数据分析时,主要可以从以下几个方面展开: 请求分析:可以统计每个API接口的请求次数、请求成功率、失败率等基础指标。这些指标可以帮助你了解API接口的使用情况,比如哪个API接口被调用的次数最多,哪个…

c++-list

文章目录 前言一、list介绍及使用1、list介绍2、list使用2.1 list构造函数的使用2.2 list iterator的使用2.3 list capacity的使用2.4 list modifiers的使用2.5 list使用算法库中的find模板生成find方法2.6 list中的sort方法 二、list模拟实现1、查看list源码的大致实现思路2、…

身份证实名核验接口,身份证实名认证,身份证二要素实名认证,身份证实名校验,身份证一致性实名认证

一、接口介绍 验证身份证与姓名是否匹配,查询身份证信息。如校验通过,接口返回生日、性别、地址等信息。广泛应用于信贷、安防、银行、保险等行业及各种身份核查场景。 注意:当请求参数符合“【固定同一个参数,其余参数不同】,”…

基于VScode 使用plantUML 插件设计状态机

本文主要记录本人初次在VScode上使用PlantUML设计 本文只讲述操作的实际方法,假设java已安装成功 。 1. 在VScode下安装如下插件 2. 验证环境是否正常 新建一个文件夹并在目录下面新建文件test.plantuml 其内容如下所示: startuml hello world skinparam Style …

基于小波变换的分形信号r指数求解算法matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 matlab2022a 3.部分核心程序 ................................................................... %通过功率谱密度曲线…

WebSocket连接异常 Error parsing HTTP request header Connection reset by peer

问题描述 在使用spring的方式集成websocket时,在配置WebSocketConfigurer后 Configuration EnableWebSocket public class WebSocketConfiguration implements WebSocketConfigurer {ResourceServletWebSocketServerHandler servletWebSocketServerHandler;Overri…

linux总结

cat -n filename 查看文件,-n用来给每一行标行号,可以省略 cat /var/log/mysqld.log | grep password 我们可以通过上述指令,查询日志文件内容中包含password的行信息。 more 作用: 以分页的形式显示文件内容 语法: more fileName 操作说明: 回车键 …

Spring Boot 中的 Redis 数据操作配置和使用

Spring Boot 中的 Redis 数据操作配置和使用 Redis(Remote Dictionary Server)是一种高性能的开源内存数据库,用于缓存、消息队列、会话管理和数据存储。在Spring Boot应用程序中,Redis被广泛用于各种用例,包括缓存、…

【教学类-35-04】学号+姓名+班级(中3班)学号字帖(A4竖版2份 竖版长条)

图片展示: 背景需求: 2022年9-2023年1月我去过小3班带班,但是没有在这个班级投放过学具,本周五是我在本学期第一次带中3班,所以提供了一套学号描字帖。先让我把孩子的名字和脸混个眼熟。 之前试过一页两套名字的纸张切割方法有:…

distcc分布式编译

distcc https://gitee.com/bison-fork/distcc.git 下载工具链 mingw,https://www.mingw-w64.org/downloads/#w64devkitperl,https://strawberryperl.com/releases.html免安装zip版本,autoconf等脚本依赖perlautoconf、automake&#xff0c…

只有正规才有机会,CTF/AWD竞赛标准参考书来了

目录 前言 一、内容简介 二、读者对象 三、目录 前言 随着网络安全问题日益凸显,国家对网络安全人才的需求持续增长,其中,网络安全竞赛在国家以及企业的人才培养和选拔中扮演着至关重要的角色。 在数字化时代,企业为了应对日益…

Flutter:open_file打开本地文件报错问题

相关插件及版本: open_file: ^3.2.1 问题: 项目中一直用的这个插件,突然发现在安卓高版本不能正常使用,报权限问题permissionDenied,断点调试提示相关权限是MANAGE_EXTERNAL_STORAGE,申请权限之后还是不行&…

springboot集成kafka

1、引入依赖 <dependency><groupId>org.springframework.kafka</groupId><artifactId>spring-kafka</artifactId><version>2.8.6</version></dependency> 2、配置 server:port: 9099 spring:kafka:bootstrap-servers: 192.1…

JWT - 令牌认证授权(认证流程、认证原理、Jwt 工具类)

目录 一、JWT 认证 1.1、对 JWT 的认识 1.1.1、JWT 解释 1.1.2、为什么使用的 JWT 认证&#xff0c;而不是 Session 认证&#xff1f; a&#xff09;基于传统的 Session 认证 1.1.3、JWT 认证流程 1.1.4、优势 1.1.5、JWT 的结构 JWT 第一部分&#xff1a;标头 Header …

LeetCode - 318 最大单词长度乘积(Java JS Py C)

目录 题目来源 题目描述 示例 提示 题目解析 算法源码 题目来源 318. 最大单词长度乘积 - 力扣&#xff08;LeetCode&#xff09; 题目描述 给你一个字符串数组 words &#xff0c;找出并返回 length(words[i]) * length(words[j]) 的最大值&#xff0c;并且这两个单词…

RocketMQ核心编程模型以及生产环境最佳实践

文章目录 一、深入理解RocketMQ的消息模型二、消息确认机制消息生产端采用消息确认加多次重试的机制保证消息正常发送到RocketMQ消息消费者端采用状态确认机制保证消费者一定能正常处理对应的消息消费者也可以自行指定起始消费位点 三、广播消息四、顺序消息机制五、延迟消息六…

【Mybatis】动态 SQL

动态 SQL \<if>标签\<trim>标签\<where>标签\<set>标签\<foreach>标签 动态 sql 是 Mybatis 的强⼤特性之⼀&#xff0c;能够完成不同条件下不同的 sql 拼接。 <if>标签 前端用户输入时有些选项是非必填的, 那么此时传到后端的参数是不确…