6.5.tensorRT高级(1)-alphapose模型导出、编译到推理(无封装)

目录

    • 前言
    • 1. alphapose导出
    • 2. alphapose推理
    • 3. 讨论
    • 总结

前言

杜老师推出的 tensorRT从零起步高性能部署 课程,之前有看过一遍,但是没有做笔记,很多东西也忘了。这次重新撸一遍,顺便记记笔记。

本次课程学习 tensorRT 高级-alphapose模型导出、编译到推理(无封装)

课程大纲可看下面的思维导图

在这里插入图片描述

1. alphapose导出

这节课我们来学习 alphapose 姿态点估计

我们依旧从学习者的角度从零开始,拉取官方代码修改并正常导出 onnx,这个项目的复杂度较高,我们先看官方代码(下载于2022/3/27),使用的模型是 Fast Pose,如下图所示,由于 DCN 写插件比较麻烦,所以没有选择带 DCN的模型(目前 tensorRT_Pro 项目支持 DCN 算子)

在这里插入图片描述

图1-1 模型选择

先去执行 srcipts 脚本文件夹下的 demo_inference.py,看能否正常推理,执行如下:

在这里插入图片描述

图1-2 问题1

可以看到出现 No module named ‘detector’ 错误,执行的 python 文件位于 scripts 文件夹下,而 detector 与 scripts 同级,有两种解决方式:

第一种就是 sys.path.insert() 插入 detector 的路径

第二种更加推荐就是 export PYTHONPATH=.

解决后再去执行,又报错了,如下所示:

在这里插入图片描述

图1-3 问题2

出现 cython_bbox 错误,该模块主要用于 tracker 的,我们不需要 tracker 所以暂时屏蔽它,屏蔽后接着去执行,又报错了,如下所示:

在这里插入图片描述

图1-4 问题3

出现 roi_align_cuda 错误,先屏蔽搁置后续如果需要再说,所以在 simple_transform.py 中屏蔽 21 和 22 行,再去执行,如下所示:

在这里插入图片描述

图1-5 问题4

成功了,剩下的就是来提供参数,建议写一个脚本来输入参数,不要老是在终端去敲命令,infer.sh 内容如下所示:

#!/bin/bashexport PYTHONPATH=.
python scripts/demo_inference.py \--cfg=configs/halpe_136/resnet/256x192_res50_lr1e-3_2x-regression.yaml \--checkpoint=pretrained_models/multi_domain_fast50_regression_256x192.pth \--sp \--image=examples/demo/1.jpg \--save_img

bash 执行下,又报错了,如下所示:

在这里插入图片描述

图1-6 问题5

nms 有两套实现,我们直接屏蔽掉报错的那部分,直接使用手动实现,具体是 yolo_api.py 第 192 行,修改内容如下:

#if nms has to be done
if nms:# if platform.system() != 'Windows':#     #We use faster rcnn implementation of nms (soft nms is optional)#     nms_op = getattr(nms_wrapper, 'nms')#     #nms_op input:(n,(x1,y1,x2,y2,c))#     #nms_op output: input[inds,:], inds#     _, inds = nms_op(image_pred_class[:,:5], nms_conf)#     image_pred_class = image_pred_class[inds]# else:# Perform non-maximum suppressionmax_detections = []while image_pred_class.size(0):# Get detection with highest confidence and save as max detectionmax_detections.append(image_pred_class[0].unsqueeze(0))# Stop if we're at the last detectionif len(image_pred_class) == 1:break# Get the IOUs for all boxes with lower confidenceious = bbox_iou(max_detections[-1], image_pred_class[1:], args)# Remove detections with IoU >= NMS thresholdimage_pred_class = image_pred_class[1:][ious < nms_conf]image_pred_class = torch.cat(max_detections).data

然后屏蔽掉 yolo_api.py 的第 26 和 27 行,再去执行,又报错了,如下所示:

在这里插入图片描述

图1-7 问题6

RoIAlign 没有定义,我们直接置为空,在 simple_transform.py 中的第 80 行修改为如下内容:

if platform.system() != 'Windows':self.roi_align = None #RoIAlign(self._input_size, sample_num=-1)if gpu_device is not None:self.roi_align = self.roi_align.to(gpu_device)

再去执行,又报错了,如下所示:

在这里插入图片描述

图1-8 问题7

接着去屏蔽,如下所示:

if platform.system() != 'Windows':self.roi_align = None #RoIAlign(self._input_size, sample_num=-1)# if gpu_device is not None:#     self.roi_align = self.roi_align.to(gpu_device)

那你可能会问为什么屏蔽?那凭感觉,通过理解罢了(还是得对整个流程熟悉呀😄)

再次执行,又报错了,如下所示:

在这里插入图片描述

图1-9 问题8

上述问题是由于 pytorch 的模型名字不配对导致的,都是官方提供的但是报错了,说明官方没有做足够的 debug 才有这一堆破事,可以看到 checkpoint 的 shape 是 512,而 model 的 shape 是 1024,这说明刚才指定的配置文件和模型不匹配

我们来看下模型的定义,在 fastpost.py 中的第 44 行,错误提示 duc2.conv.weight 即 checkpoint 的 shape 是 512x256,而模型是 1024x256,还有 conv_out.weight 即 checkpoint 的 shape 是 136x128,而模型是 136x256,所以最终我们要把控制条件 slef.conv_dim 从 256 修改为 128,那这个变量是由 yaml 文件来控制的

因此我们需要从 256x192_res50_lr1e-3_2x-regression.yaml 文件中找到 CONV_DIM,将其从 256 修改为 128,再次执行,如下所示

在这里插入图片描述

图1-10 成功运行

成功了,说明我们修改的地方是正确的,模型推理的效果如下所示:

在这里插入图片描述

图1-11 模型推理效果

能正常推理了,接下来就是要把它正确的导出来了,在正式导出之前我们需要自己手动实现下推理过程,因此写一个 predict.py ,需要把它的整个推理过程像 Unet 一样抽出来,怎么去抽呢?官方推理都一堆 bug,这个事情就显得有些繁琐了

我们主要还是去参考 demo_inference.py 中的内容,根据各种分析最后得到的 predict.py 内容如下:

import yaml
from easydict import EasyDict as edict
from alphapose.models import builder
import torch
import numpy as np
import cv2def update_config(config_file):with open(config_file) as f:config = edict(yaml.load(f, Loader=yaml.FullLoader))return configclass MySPPE(torch.nn.Module):def __init__(self):super().__init__()checkpoint = "pretrained_models/multi_domain_fast50_regression_256x192.pth"cfg = update_config("configs/halpe_136/resnet/256x192_res50_lr1e-3_2x-regression.yaml")self.pose_model = builder.build_sppe(cfg.MODEL, preset_cfg=cfg.DATA_PRESET)self.pose_model.load_state_dict(torch.load(checkpoint, map_location="cpu"))def forward(self, x):hm = self.pose_model(x)stride = int(256 / hm.size(2))b, c, h, w = map(int, hm.size())prob = hm.sigmoid()confidence, _ = prob.view(-1, c, h * w).max(dim=2, keepdim=True)prob = prob / prob.sum(dim=[2, 3], keepdim=True)coordx = torch.arange(w, device=prob.device, dtype=torch.float32)coordy = torch.arange(h, device=prob.device, dtype=torch.float32)hmx = (prob.sum(dim=2) * coordx).sum(dim=2, keepdim=True) * stridehmy = (prob.sum(dim=3) * coordy).sum(dim=2, keepdim=True) * stridekeypoint = torch.cat([hmx, hmy, confidence], dim=2)return keypointmodel = MySPPE().eval()x, y, w, h = 158, 104, 176, 693
image = cv2.imread("gril.jpg")[y:y+h, x:x+w]
image = image[..., ::-1]
image = cv2.resize(image, (256, 192))
image = ((image / 255.0) - [0.406, 0.457, 0.480]).astype(np.float32)
image = image.transpose(2, 0, 1)[None]
image = torch.from_numpy(image)with torch.no_grad():keypoint = model(image)print(keypoint.shape)
#return torch.cat([hmx, hmy, confidence], dim=2)dummy = torch.zeros(1, 3, 256, 192)
torch.onnx.export(model, (dummy,), "fastpose.onnx", input_names=["image"], output_names=["predict"], opset_version=11, dynamic_axes={"image": {0:"batch"}, "predict": {0:"batch"}}
)
print("Done")

杜老师通过分析把预处理和后处理给抽出来了,这要是自己分析不得疯,主要是代码封装得太深了,alphapose 的预处理部分在 simple_transform.py 文件中的 test_transform() 函数,输入一张原图和一个 box,进行相关预处理

整个预处理过程就是把 box 抠出来,然后移到中间,再减去均值就结束了。后处理部分是在 transforms.py 中的 heatmap_to_coord_simple_regress() 函数中实现的

值的注意的是我们在这里只是为了推理演示,去除了部分操作,我们直接拿一个已有的 box 塞到网络中去,省去了检测器的部分,同时拿到 box 后其实还是要做仿射变换的,这里为了方便直接使用的 resize,

模型推理结果是 136 维度的 heatmap,后处理就是是将 heatmap 变成回归值的过程,主要是得到我们的关键点坐标,这里把后处理部分也直接塞到 onnx 中,避免提高在 tensorRT 中的复杂度,

执行下 predict.py 如下所示:

在这里插入图片描述

图1-12 执行predict

导出的 onnx 如下图:

在这里插入图片描述

图1-13 onnx

可以看到模型有很多多余的节点,都是 view 造成的,我们需要去除,在 SE_module.py 中 forward 部分修改,修改内容如下:

def forward(self, x):b, c, _, _ = x.size()# y = self.avg_pool(x).view(b, c)y = self.avg_pool(x).view(-1, int(c))# y = self.fc(y).view(b, c, 1, 1)y = self.fc(y).view(-1, int(c), 1, 1)return x * y

再导出下,onnx 如下图所示:

在这里插入图片描述

图1-14 onnx1

可以看到非常干净,是我们想要的效果

其实我们也可以直接拿 onnxsim 优化下,如下图所示:

import onnx
from onnxsim import simplifyonnx_model = onnx.load("fastpose.onnx")
model_simp, check = simplify(onnx_model)onnx.save(model_simp, "fastpose.sim.onnx")

导出的 fastpose.sim.onnx 如下图所示:

在这里插入图片描述

图1-15 fastpose.sim.onnx

清清爽爽,没有多余的节点,也非常 nice

2. alphapose推理

拿到我们想要的 onnx 后,接下来去 C++ 中执行下推理,直接 make run 运行下,如下所示:

在这里插入图片描述

图2-1 make run

来简单解读下代码,模型编译和之前没有任何区别,我们还是主要关注 inference,预处理部分通过 warpAffine 将图像缩放到 256x192,相比之前稍微做了下扩展,具体代码如下:

void get_preprocess_transform(const cv::Size& image_size, const cv::Rect& box, const cv::Size& net_size, float i2d[6], float d2i[6]){cv::Rect box_ = box;if(box_.width == 0 || box_.height == 0){box_.width  = image_size.width;box_.height = image_size.height;box_.x = 0;box_.y = 0;}float rate = box_.width > 100 ? 0.1f : 0.15f;float pad_width  = box_.width  * (1 + 2 * rate);float pad_height = box_.height * (1 + 1 * rate);float scale = min(net_size.width  / pad_width,  net_size.height / pad_height);i2d[0] = scale;  i2d[1] = 0;      i2d[2] = -(box_.x - box_.width  * 1 * rate + pad_width * 0.5)  * scale + net_size.width  * 0.5 + scale * 0.5 - 0.5;  i2d[3] = 0;      i2d[4] = scale;  i2d[5] = -(box_.y - box_.height * 1 * rate + pad_height * 0.5) * scale + net_size.height * 0.5 + scale * 0.5 - 0.5;cv::Mat m2x3_i2d(2, 3, CV_32F, i2d);cv::Mat m2x3_d2i(2, 3, CV_32F, d2i);cv::invertAffineTransform(m2x3_i2d, m2x3_d2i);
}

这个 warpAffine 后的图像 input-image 如下所示,它其实是有一个扩大的过程,比我们平时的情况要复杂一点点

在这里插入图片描述

图2-2 input-image

后处理由于我们是放在 onnx 的,因此直接获取的就是个关键点,根据置信度来进行过滤即可

检测效果如下图所示:

在这里插入图片描述

图2-3 image-draw

那这就是整个 alphapose 案例,有些地方看起来比较乱,还是需要自己多去实践,多去思考的,比如后处理就是这个算法的关键和核心,我们对官方代码进行解读后一定要自己实现一个版本,这样才能吸收消化从而变成我们自己的知识,还有一点,就是复杂的后处理放到 onnx 中可以解决很多问题

另外就是一个复杂的工程项目中要处理的问题太多了,但是我们要学会怎么化繁为简,这是我们要掌握的知识。

3. 讨论

姿态点估计算法可以分为自下而上和自上而下两种方法:(from chatGPT)

1. 自下而上方法:自下而上的姿态点估计算法是指先检测图像中所有可能的关键点,然后再通过关键点之间的关联关系来估计人体的姿态。这种方法通常从图像中检测出一系列的关键点,然后利用关键点之间的空间关系和约束关系来拟合出人体的姿态。

2. 自上而下方法:自上而下的姿态点估计算法是指先检测出人体的整体姿态或人体框,然后再在特定区域或人体框内估计关键点的位置。这种方法首先通过人体检测算法或目标检测算法找到人体的位置和姿态,然后在检测到的人体框内进行关键点的估计。

两种方法各有优势和适用场景:

  • 自下而上方法的优势在于可以处理多人姿态估计问题,因为它能够检测图像中所有可能的关键点,然后通过关联关系对多人姿态进行建模。这种方法在密集场景中表现较好,但在处理复杂场景时可能存在误检或漏检问题。
  • 自上而下方法的优势在于可以通过先验信息来辅助姿态估计,例如先进行人体检测或目标检测,然后再在检测到的人体框内进行关键点估计。这种方法通常比较高效,并且能够在复杂场景中保持稳定性,但可能不太适用于密集场景或多人姿态估计问题。

那很明显 alphapose 是自下而上的方法。

在 alphapose 中输入到网络中的是缩放到 256x192 尺寸的人体框,输出是一组热力图

在姿态点估计算法中,热力图(Heatmap)是一种用于表示关键点位置的图像。对于每个关键点,热力图是一个二维图像,其中每个像素的值表示该像素处是特定关键点的概率,热力图是如何转换成关键点坐标的呢?也就是后处理具体是如何做的呢?(这可能需要去仔细分析代码了😂)

那正常来说,整个 alphapose 的姿态点估计先通过检测器截取 box,再将截取得到的 box 送入到 alphapose 检测返回结果,如果人多的话,检测器截取到的每个 box 都要放到 alphapose 推理一遍,似乎有点耗时呀🤔

而且像多人密集场景,它则十分依赖检测器的能力,如果检测器的提取到的 box 不行,那后面的姿态点估计也就不准了,人少且比较分散效果应该不错

总结

这节课主要是学习姿态点估计网络 alphapose 的导出、编译到推理,这节课体现了一个非常重要的思想,那就是复杂的后处理放到 onnx 中去,这可以降低我们在 tensorRT 的复杂度。

同时这节课大部分时间都是在跟随杜老师不断解决各种各样的问题,我们实际工作中拿到一个工程项目文件也总是会遇到这样或者那样的问题,学习如何去解决问题才是我们要关注的,还是得多实践积累经验,能做到化繁为简,同时在理解完别人的代码后一定要自己实现一个版本,这样才能更好的去消化吸收变成我们自己的知识。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/83866.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java基础(八)二维数组

数组 二、二维数组 1. 二维数组使用步骤 定义二维数组 格式&#xff1a;数据类型 数组名[][]; 或 数据类型[][] 数组名; int scores[][]; int[][] scores;为二维数组元素分配内存 格式&#xff1a;数据类型 数组名[][]; 或 数据类型[][] 数组名; int scores[][]; scores …

什么是设计模式?

目录 概述: 什么是模式&#xff01;&#xff01; 为什么学习模式&#xff01;&#xff01; 模式和框架的比较&#xff1a; 设计模式研究的历史 关于pattern的历史 Gang of Four(GoF) 关于”Design”Pattern” 重提&#xff1a;指导模式设计的三个概念 1.重用(reuse)…

基于微信小程序的传染病酒店隔离平台设计与实现(Java+spring boot+MySQL+微信小程序)

获取源码或者论文请私信博主 演示视频&#xff1a; 基于微信小程序的传染病酒店隔离平台设计与实现&#xff08;Javaspring bootMySQL微信小程序&#xff09; 使用技术&#xff1a; 前端&#xff1a;html css javascript jQuery ajax thymeleaf 微信小程序 后端&#xff1a;…

【windows】windows上如何使用linux命令?

前言 windows上的bat命令感觉不方便&#xff0c;想在windows上使用linux命令。 有人提供了轮子&#xff0c;本文简单介绍一些该轮子的安装与使用&#xff0c;希望能够帮助到和我有一起需求的网友。 我的答案是busybox。 1.安装busybox.exe 在这个网站上安装busybox busyb…

两个状态的马尔可夫链

手动推导如下公式。 证明&#xff1a; 首先将如下矩阵对角化&#xff1a; { 1 − a a b 1 − b } \begin {Bmatrix} 1-a & a \\ b & 1-b \end {Bmatrix} {1−ab​a1−b​} (1)求如下矩阵的特征值&#xff1a; { 1 − a a b 1 − b } { x 1 x 2 } λ { x 1 x 2 }…

vscode终端背景颜色修改以及报错信息颜色修改

引言 刚从pycharm转到vscode上时&#xff0c;很不喜欢vscode终端信息一片白色&#xff0c;于是想尽办法去修改vscode终端风格 这里提供vscode终端背景颜色的修改和vscode终端报错提示信息颜色的修改方法 (1)vscode终端背景颜色优化 步骤一&#xff0c;ctrlshiftp打开设置搜索…

Unity-UGUI优化策略

界面出栈规则&#xff1a; 界面目录导航、策划界面回退需求造成界面套娃问题&#xff0c;夹带一系列层级问题&#xff0c;应该和策划进行友好沟通&#xff0c;避免界面不合理的出栈入栈规则 overdraw&#xff1a; 尽量减少同屏 半透明物体渲染 Unity 之 UGUI优化&#xff08;…

iOS开发-JsonModel的学习及使用

IOS JsonModel的学习及使用 当我们从服务端获取到json数据后的时候&#xff0c;我们需要在界面上展示或者保存起来&#xff0c;下面来看下直接通过NSDictionary取出数据的情况。 NSDictionary直接取出数据的诟病。 NSString *name [self.responseObj objectForKey:"nam…

github上有哪些值得读源码的react项目?

前言 下面是我整理的关于值得一读源码的react项目&#xff0c;希望对你有所帮助~ 1、 calcom Star: 21.6k calcom是一个开源的计算器应用程序。它提供了基本的数学运算功能&#xff0c;例如加法、减法、乘法和除法&#xff0c;还支持 科学计算、进制转换和单位转换等高级功能…

vmwera中安装的centos8出现ifconfig不可用

刚刚在虚拟机中装好centos结果发现自己的ifconfig命令不可用。 看一下环境变量里有没有ifconfig命令的路径&#xff0c;因为ifconfig是在/sbin路径下的&#xff0c;root用户登录进去才可以运行&#xff0c;先看一下root用户的环境变量。 root用户的环境变量里是有/sbin路径的&a…

java.lang.ClassNotFoundException: com.mysql.cj.jdbc.Driver的解决办法

springcloudAlibaba项目连接mysql时&#xff08;mysql版本8.0.31&#xff0c;Springboot2.2.2,spring cloud Hoxton.SR1,spring cloud alibaba 2.1.0.RELEASE&#xff09;&#xff0c;驱动名称报红&#xff0c;配置如下&#xff1a; 原因&#xff1a;引入的jdbc驱动包和使用的m…

pytest fixture 用于teardown工作

fixture通过scope参数控制setup级别&#xff0c;setup作为用例之前前的操作&#xff0c;用例执行完之后那肯定也有teardown操作。这里用到fixture的teardown操作并不是独立的函数&#xff0c;用yield关键字呼唤teardown操作。 举个例子&#xff1a; 输出&#xff1a; 说明&…

MongoDB文档-基础使用-在客户端(dos窗口)/可视化工具中使用MongoDB基础语句

阿丹&#xff1a; 本文章将描述以及研究mongodb在客户端的基础应用以及在spring-boot中整合使用mongodb来完成基本的数据增删改查。 传送门&#xff1a; MongoDB文档--基本概念_一单成的博客-CSDN博客 MongoDB文档--基本安装-linux安装&#xff08;mongodb环境搭建&#xff0…

Celery嵌入工程的使用

文章目录 1.config 1.1 通过app.conf进行配置1.2 通过app.conf.update进行配置1.3 通过配置文件进行配置1.4 通过配置类的方式进行配置2.任务相关 2.1 任务基类(base)2.2 任务名称(name)2.3 任务请求(request)2.4 任务重试(retry) 2.4.1 指定最大重试次数2.4.2 设置重试间隔时间…

RTC晶振两端要不要挂电容

发现GD32的RTC晶振两端需要挂电容&#xff0c;STM32的RTC晶振两端不需要挂电容。 STM32的RTC晶振两端&#xff0c;不需要挂电容&#xff0c;这样晶振启振很容易&#xff0c;挂大了&#xff0c;却难启动&#xff0c;且温度越低&#xff0c;启动越难。 有人说负载电容为6pF的晶振…

分享21年电赛F题-智能送药小车-做题记录以及经验分享

这里写目录标题 前言一、赛题分析1、车型选择2、巡线1、OpenMv循迹2、灰度循迹 3、装载药品4、识别数字5、LED指示6、双车通信7、转向方案1、开环转向2、位置环速度环闭环串级转向3、MPU6050转向 二、调试经验分享1、循迹2、识别数字3、转向4、双车通信5、逻辑处理6、心态问题 …

IoTDB1.X windows运行失败问题的处理

在windows运行 IoTDB1.x时 会出现如图所示的问题 为什么会出现这样的问题&#xff1f;java没有安装还是未调用成功&#xff0c;我是JAVA8~11~17各种更换都未能解决问题&#xff0c;最后对其bat文件进行查看&#xff0c;发现在conf\datanode-env.bat、conf\confignode-env.bat这…

深入学习JVM —— GC垃圾回收机制

前言 前面荔枝已经梳理了有关JVM的体系结构和类加载机制&#xff0c;也详细地介绍了JVM在类加载时的双亲委派模型&#xff0c;而在这篇文章中荔枝将会比较详细地梳理有关JVM学习的另一大重点——GC垃圾回收机制的相关知识&#xff0c;重点了解的比如对象可达性的判断、四种回收…

uniapp+vue3项目中使用vant-weapp

创建项目 通过vue-cli命令行创建项目 Vue3/Vite版要求 node 版本^14.18.0 || >16.0.0 uni-app官网 (dcloud.net.cn) npx degit dcloudio/uni-preset-vue#vite my-vue3-project打开项目 点击顶部菜单栏终端/新建终端 执行安装依赖指令 yarn install 或 npm install 安装vant…

2023年上海国际车展,英信翻译提供中日英同传服务

在2023年4月上海车展期间&#xff0c;日产汽车展示了一系列搭载智能网联技术和电驱动技术的车型&#xff0c;包括首次亮相的Max-Out概念车和专为中国消费者设计的纯电动SUV概念车——日产Arizon。备受全球汽车行业瞩目。 日产是日本第二大汽车公司&#xff0c;也是世界十大汽车…