AlphaPose Pytorch 代码详解(一):predict

前言

代码地址:AlphaPose-Pytorch版

本文以图像 1.jpg(854x480)为例对整个预测过程的各个细节进行解读并记录

python demo.py --indir examples/demo --outdir examples/res --save_img

在这里插入图片描述

1. YOLO

1.1 图像预处理

  • cv2读取BGR图像 img [480,854,3] (h,w,c)
  • 按参数 inp_dim=608,将图像保持长宽比缩放至 [341,608,3],并以数值为128做padding至 [608,608,3]
  • BGR转RGB、维度处理、转tensor、数据类型转float、数值除以255做归一化,[1,3,608,608] (b,c,h,w)

输出:

img:		[1,3,608,608]	(yolo输入图像)
orig_img:	[480,854,3]		(原始BGR图像)
im_name:	'examples/demo/1.jpg'
im_dim_list:[854,480,854,480](原图尺寸, 用于把yolo的输出坐标转换到原图坐标系)

1.2 yolo 模型推理

输入:img[1,3,608,608]
输出:pred[1,22743,85]

1 = b a t c h s i z e \mathrm{1=batchsize} 1=batchsize

22743 = [ ( 608 / 32 ) 2 + ( 608 / 16 ) 2 + ( 608 / 8 ) 2 ] × 3 22743=[(608/32)^2+(608/16)^2+(608/8)^2]\times3 22743=[(608/32)2+(608/16)2+(608/8)2]×3

85 = [ x , y , w , h , c o n f , 80 c l a s s e s ] \mathrm{85=[x,y,w,h,conf,80classes]} 85=[x,y,w,h,conf,80classes]

1.3 输出后处理

(1)第一阶段

  • 坐标 xywhxyxy
  • batchsize 循环,[22743,85][22743,7] 7 = [ x , y , x , y , c o n f , c l a s s s c o r e , c l a s s ] \mathrm{7=[x,y,x,y,conf,class\ score,class]} 7=[x,y,x,y,conf,class score,class],即80类得分转换为得分最高的类别索引及其得分
  • 去除 c o n f ≤ 0.05 \mathrm{conf\le0.05} conf0.05 的项 [37,7]
  • 保留类别为人的项,并且按 conf 从高到低排序,得到结果 img_pred[19,7]
  • nms去除重复目标 [19,7]->[6,7]
  • 添加 batch_idx 这里batchsize为1所以都是0,[6,7]->[6,8] 8 = [ b a t c h i d x , x , y , x , y , c o n f , c l a s s s c o r e , c l a s s ] \mathrm{8=[batch\ idx, x,y,x,y,conf,class\ score,class]} 8=[batch idx,x,y,x,y,conf,class score,class]
  • 坐标数值转换,从 [608,608] 转到原图坐标 [854,480],并把坐标 clamp[0,w] [0,h] 之间

输出:

orig_img:[480,854,3]	(原始BGR图像)
im_name:'examples/demo/1.jpg'
boxes:	[6,4]	(x,y,x,y)(原图坐标系)
scores:	[6,1]	(conf)

NMS 细节

  • nms_conf=0.6
  • img_pred 中的第一项放到结果中
  • 剩余所有项与结果中的最后一项计算iou,保留 iou<nms_conf 的项作为新的 img_pred
  • 循环直到 img_pred 中没有目标
  • 当经过nms后的目标数量大于100个,会把 nms_conf-0.05,从最初的 img_pred 开始重新进行nms

(2)第二阶段

  • 原始图像 orig_img [480,854,3] BGR转RGB、维度处理、转tensor、数据类型转float、数值除以255做归一化,得到 inp [3,480,854]
  • 对三通道做处理 inp[0].add_(-0.406), inp[1].add_(-0.457), inp[2].add_(-0.480)
  • 扩大 boxes 中目标框的范围,并把左上角坐标存入 pt1,右下角坐标存入 pt2
  • 根据 boxes 把每个目标从图中抠出来,通过保比例缩放+zero padding,统一成 [3,320,256] 大小的图像存入 inps

输出:

inps:	[6,3,320,256]	(检测目标的子图像,作为Alphapose的输入)
orig_img:[480,854,3]	(原始BGR图像)
im_name:'examples/demo/1.jpg'
boxes:	[6,4]	(x1,y1,x2,y2)(yolo原始输出,原图坐标系)
scores:	[6,1]	(yolo输出conf)
pt1:	[6,2]	(x1,y1)(yolo输出扩大后坐标,原图坐标系)
pt2:	[6,2]	(x2,y2)(yolo输出扩大后坐标,原图坐标系)

2. POSE

2.1 pose 模型推理

输入:inps[6,3,320,256]
输出:hm[6,17,80,64],即6个目标,每个目标17个关键点对应的热力图

2.2 输出后处理

(1)第一阶段:热力图转坐标

  • 获取 hm[6,17,80,64] 中每个关键点的热力图中最大值的索引 preds[6,17,2]
  • 由于 opt.matching=False,此处使用简单的后处理,源码如下
    preds 中某个索引 [x,y] 为例,取出其热力图中相邻的上下左右四个位置的值,并且分别在 x x x y y y 轴上往较高的方向偏移 0.25 0.25 0.25
    x x x 轴为例: p l e f t = h m [ y ] [ x − 1 ] p_\mathrm{left}=\mathrm{hm}[y][x-1] pleft=hm[y][x1] p r i g h t = h m [ y ] [ x + 1 ] p_\mathrm{right}=\mathrm{hm}[y][x+1] pright=hm[y][x+1],若 p l e f t > p r i g h t p_\mathrm{left}>p_\mathrm{right} pleft>pright x + 0.25 x+0.25 x+0.25,若 p l e f t = p r i g h t p_\mathrm{left}=p_\mathrm{right} pleft=pright x x x 保持不变
    最后会在所有的坐标值上 + 0.2 +0.2 +0.2
for i in range(preds.size(0)):for j in range(preds.size(1)):hm = hms[i][j]pX, pY = int(round(float(preds[i][j][0]))), int(round(float(preds[i][j][1])))if 0 < pX < opt.outputResW - 1 and 0 < pY < opt.outputResH - 1:diff = torch.Tensor((hm[pY][pX + 1] - hm[pY][pX - 1], hm[pY + 1][pX] - hm[pY - 1][pX]))preds[i][j] += diff.sign() * 0.25
preds += 0.2
  • 目前得到的坐标 preds[6,17,2] 是相对于输出分辨率 [80,64] 坐标系下的,转换到原图分辨率 [480,854] 坐标系下,得到 preds_tf[6,17,2]

输出:

preds:		[6,17,2]	(经过第二步偏移处理后的坐标,相对于热力图坐标系)
preds_tf:	[6,17,2]	(最终坐标,相对于原图坐标系)
maxval:		[6,17,1]	(热力图最大值)

(2)第二阶段:pose nms

输入:

ori_bboxs:		[6,4]		(yolo原始输出,原图坐标系)
ori_bbox_scores:[6,1]		(yolo输出conf)
ori_pose_preds:	[6,17,2]	(对应preds_tf,关键点坐标,原图坐标系)
ori_pose_scores:[6,17,1]	(对应maxval,热力图最大值)
  • 根据 bboxs 计算每个目标框的 w,h,选择每个目标框中的最大值 max(w,h) 并乘上 alpha=0.1 构成 ref_dists[6]
  • 根据 pose_scores 计算每个目标17个关键点得分的均值,得到 human_scores[6]
  • 开始循环,直到 human_scores 无目标
    1. 选择 human_scores 最高的目标,坐标和得分分别记为 pick_preds[17,2], pick_scores[17,1]
      全部的坐标和得分记为 all_preds[6,17,2], all_scores[6,17,1](此处命名方式与源码略有不同,以便于区分)
    2. 计算距离:final_dist[6] 目标的同类别关键点的距离,距离越近数值越大
      score_dists 计算位置距离非常近的同类别关键点的得分距离
      point_dist = e − d / 2.65 =e^{-d/2.65} =ed/2.65,因为 d ≥ 0 d\ge0 d0,所以 0 < p o i n t _ d i s t ≤ 1 0<\mathrm{point\_dist}\le1 0<point_dist1 d d d 越小, p o i n t _ d i s t \mathrm{point\_dist} point_dist 越大,目标本身则最大全为1
    3. 计算关键点匹配数量:num_match_keypoints[6] 目标之间同类别关键点中距离较近的数量
    4. 去除多余目标:目标之间的距离超过阈值 or 目标之间距离相近的关键点数量超过阈值 → 判定为多余的目标。
      由于选出的目标本身也在其中,因此目标自身必然在去除的队伍中,如果除了自身还有目标被去除,那么会把额外的目标与自身的索引放在一起得到 merge_ids,这些目标相互之间距离很近,用于后续融合目标。
对应第2def get_parametric_distance(i, all_preds, all_scores):pick_preds, pick_scores = all_preds[i], all_scores[i]'计算坐标位置的欧氏距离 dist[6,17](同类别关键点之间的距离)'dist = torch.sqrt(torch.sum(torch.pow(pick_preds[np.newaxis, :] - all_preds, 2), dim=2))'计算dist<=1的点之间的得分距离 score_dists[6,17]'mask = (dist <= 1)score_dists = torch.zeros(all_preds.shape[0], 17)score_dists[mask] = torch.tanh(pick_scores[mask]/delta1) * torch.tanh(all_scores[mask]/delta1)  'delta1=1''final_dist[6]'point_dist = torch.exp((-1) * dist / delta2)  'delta2=2.65'final_dist = torch.sum(score_dists, dim=1) + mu * torch.sum(point_dist, dim=1)  'mu=1.7'return final_dist
对应第3def PCK_match(pick_pred, all_preds, ref_dist):dist = torch.sqrt(torch.sum(torch.pow(pick_preds[np.newaxis, :] - all_preds, 2), dim=2))ref_dist = min(ref_dist, 7)num_match_keypoints = torch.sum(dist / ref_dist <= 1, dim=1)return num_match_keypoints
对应第4'gamma=22.48, matchThreds=5'
delete_ids = torch.from_numpy(np.arange(human_scores.shape[0]))[(final_dist > gamma) | (num_match_keypoints >= matchThreds)]

输出:

merge_ids:			[6,x]
preds_pick:			[6,17,2]
scores_pick:		[6,17,1]
bbox_scores_pick:	[6,1]
'''
这里的输出是从各个输入 orig_xxxx, 例如 ori_bbox_scores 中挑选出来的(nms后的目标)
只是在这个例子中, nms判断并没有重复的目标, 因此和原始输入保持一致merge_ids 是一个列表, x代表每一项的长度, 本例中x都=1
如果nms判断存在重复目标, 那么会把这些目标在原始输入中的索引记录在 merge_ids 中, 此时x>1
在第三阶段中会把这些目标进行融合
'''

(3)第三阶段:融合与过滤

  • 去除17个关键点中最高得分 m a x _ s c o r e < s c o r e T h r e d s = 0.3 \mathrm{max\_score < scoreThreds = 0.3} max_score<scoreThreds=0.3 的目标
  • 融合目标,具体看下面代码,简单来说就是把距离比较近的关键点根据得分的高低作为权重,把坐标位置和得分进行加权求和作为融合后的目标
  • 去除融合后,17个关键点中最高得分 m a x _ s c o r e < s c o r e T h r e d s = 0.3 \mathrm{max\_score < scoreThreds = 0.3} max_score<scoreThreds=0.3 的目标
  • 根据能包含目标所有关键点的矩形框面积来过滤目标,1.5**2 * (xmax-xmin) * (ymax-ymin) < areaThres=0,具体为外接矩形长宽都乘1.5后计算面积,由于这里阈值为0,过滤基本无效
  • 最后会把所有关键点坐标数值 − 0.3 -0.3 0.3,并且根据关键点得分和目标框得分生成 proposal_score,具体见下面代码

此阶段过滤掉了一个目标,最终得到5个目标。

merge_pose, merge_score = p_merge_fast(preds_pick[j], ori_pose_preds[merge_id], ori_pose_scores[merge_id], ref_dists[pick[j]])def p_merge_fast(ref_pose, cluster_preds, cluster_scores, ref_dist):'''Score-weighted pose mergingINPUT:															本博客中文别称ref_pose:       reference pose          -- [17, 2]			挑选目标关键点cluster_preds:  redundant poses         -- [n, 17, 2]		多余目标关键点(挑选目标本身包含在多余目标中)cluster_scores: redundant poses score   -- [n, 17, 1]ref_dist:       reference scale         -- ConstantOUTPUT:final_pose:     merged pose             -- [17, 2]final_score:    merged score            -- [17]''''计算与多余目标关键点距离 dist[n,17]'dist = torch.sqrt(torch.sum(torch.pow(ref_pose[np.newaxis, :] - cluster_preds, 2),dim=2))kp_num = 17'回顾一下, ref_dist是挑选目标的目标框的 max(h,w)*0.1'ref_dist = min(ref_dist, 15)mask = (dist <= ref_dist)final_pose = torch.zeros(kp_num, 2)final_score = torch.zeros(kp_num)if cluster_preds.dim() == 2:cluster_preds.unsqueeze_(0)cluster_scores.unsqueeze_(0)if mask.dim() == 1:mask.unsqueeze_(0)# Weighted Merge'根据pose的得分来决定每个目标所占的比例, 具体为该得分占总得分的比例'masked_scores = cluster_scores.mul(mask.float().unsqueeze(-1))normed_scores = masked_scores / torch.sum(masked_scores, dim=0)'根据计算得到的比例做加权和, 得到最终的pose及其得分'final_pose = torch.mul(cluster_preds, normed_scores.repeat(1, 1, 2)).sum(dim=0)final_score = torch.mul(masked_scores, normed_scores).sum(dim=0)return final_pose, final_score
final_result.append({'keypoints': merge_pose - 0.3,'kp_score': merge_score,'proposal_score': torch.mean(merge_score) + bbox_scores_pick[j] + 1.25 * max(merge_score)
})keypoints 	[17,2]
kp_score	[17,1]
proposal_score [1]

3. Alphapose 网络结构

3.1 总流程

在这里插入图片描述

  1. SEResnet 作为 backbone 提取特征
  2. nn.PixelShuffle(2) 提升分辨率
  3. 经过两个 DUC 模块进一步提升分辨率
  4. 通过一个卷积得到输出
  5. 此时获得的输出如图所示 out[6,33,80,64] 有33个关键点,通过 out.narrow(1, 0, 17) 获取前17个关键点作为最终的输出 hm[6,17,80,64]

3.2 DUC 模块

  1. 2个DUC模块结构相同,都是先用卷积升维,再用一个 nn.PixelShuffle(2) 提升分辨率
  2. 图中以 DUC1 模块的参数为例进行绘制

3.3 PixelShuffle 操作

import torch
import torch.nn as nninput_tensor = torch.arange(1, 17).view(1, 16, 1, 1).float()
pixel_shuffle = nn.PixelShuffle(2)
output_tensor = pixel_shuffle(input_tensor)print(output_tensor)>>>
tensor([[[[ 1.,  2.],[ 3.,  4.]],[[ 5.,  6.],[ 7.,  8.]],[[ 9., 10.],[11., 12.]],[[13., 14.],[15., 16.]]]])

3.4 SEResnet 框架

在这里插入图片描述
  图中省略了 batchsize 维度,主要分为4层,分别相对原图下采样4、8、16、32倍

3.5 SEResnet 细节

在这里插入图片描述
  仿照代码,把这4个由 Bottleneck_SEBottleneck 构成的层级记作 l a y e r 1 ∼ 4 \mathrm{layer1\sim4} layer14,图中为 l a y e r 1 \mathrm{layer1} layer1 的数据。

  每个层中两种 Bottleneck 都会通过三个卷积层,先把特征维度控制为输出特征维度的 1/4,第二个保持不变,第三个达到输出特征维度,再以第二层为例:
l a y e r 2 : \mathrm{layer2:} layer2:
   B o t t l e n e c k _ S E : 256 → 128 → 128 → 512 \mathrm{Bottleneck\_ SE:256\to128\to128\to512} Bottleneck_SE:256128128512
   B o t t l e n e c k : 512 → 128 → 128 → 512 \mathrm{Bottleneck:512\to128\to128\to512} Bottleneck:512128128512

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/156959.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LATR:3D Lane Detection from Monocular Images with Transformer

参考代码&#xff1a;LATR 动机与主要工作&#xff1a; 之前的3D车道线检测算法使用诸如IPM投影、3D anchor加NMS后处理等操作处理车道线检测&#xff0c;但这些操作或多或少会存在一些负面效应。IPM投影对深度估计和相机内外参数精度有要求&#xff0c;anchor的方式需要一些如…

【图像融合】差异的高斯:一种简单有效的通用图像融合方法[用于融合红外和可见光图像、多焦点图像、多模态医学图像和多曝光图像](Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

淘宝价格,淘宝商品优惠券数据接口,淘宝商品销量接口,淘宝商品详情数据接口,淘宝API接口

淘宝价格和商品优惠券数据接口是淘宝平台提供的官方数据接口&#xff0c;通过调用接口&#xff0c;可以获取到淘宝商品的价格信息和优惠券数据。 获取淘宝价格和商品优惠券数据接口的步骤如下&#xff1a; 输入淘宝网址登陆淘宝账号密码。点击获取key和secret。调用获取buyer…

android 与 flutter 之间的通信

文章目录 前言集成 flutter 混合开发android 与 flutter 之间的通信总结 一、前言 因为flutter 具有跨平台的属性&#xff0c;既可以在android上跑&#xff0c;也能在ios 上跑&#xff0c;所以为了节约开发的成本&#xff0c;减少人力&#xff0c;势必就会用到它。然而已有的…

04在命令行中使用Maven命令创建Maven版的Web工程,并将工程部署到服务器的步骤

创建Maven版的Web工程 使用命令生成Web工程 使用mvn archetype:generate命令生成Web工程时&#xff0c;需要使用一个专门生成Web工程骨架的archetype(参照官网看到它的用法) -D表示后面要附加命令的参数&#xff0c;字母D和后面的参数是紧挨着的&#xff0c;中间没有任何其它…

Ceph介绍与部署

Ceph介绍与部署 一、存储基础1.1、单机存储设备1.1.1、单机存储的问题 1.2、商业存储解决方案1.3、分布式存储&#xff08;软件定义的存储 SDS&#xff09;1.3.1、分布式存储的类型 二、Ceph 简介三、Ceph 优势四、Ceph 架构五、Ceph 核心组件5.1、Pool中数据保存方式支持两种类…

FlashDuty Changelog 2023-09-21 | 自定义字段和开发者中心

FlashDuty&#xff1a;一站式告警响应平台&#xff0c;前往此地址免费体验&#xff01; 自定义字段 FlashDuty 已支持接入大部分常见的告警系统&#xff0c;我们将推送内容中的大部分信息放到了 Lables 进行展示。尽管如此&#xff0c;我们用户还是会有一些扩展或定制性的需求…

数据可视化

pip install matplotlib 一、各种图 #线形图 import numpy as np import pandas as pd df1pd.DataFrame(datanp.random.randn(1000,4),indexpd.date_range(start10/10/2023,periods1000),columnslist(ABCD)) df1.cumsum().plot()#2、条形图 df2pd.DataFrame(datanp.random.ra…

文本编辑器去除PDF水印

用文本编辑器打开pdf&#xff0c;搜索水印的特殊文字&#xff0c;全部替换。 另外一个水印字母间有空格。 替换完后保存。 重新打开pdf&#xff1a;

通常用哪些软件做数据可视化大屏?

数据可视化大屏是一种展示数据的方式&#xff0c;通过使用软件工具可以将数据以图表、地图、仪表盘等形式直观地呈现给用户。以下是常用的数据可视化大屏软件及其特点&#xff1a; 1. Datainside&#xff1a; - 特点&#xff1a;Datainside是一款功能强大且易于使用的数据可视…

绘制X-Bar-S和X-Bar-R图,监测过程,计算CPK过程能力指数

X-Bar-S图和X-Bar-R图是统计质量控制中常用的两种控制图&#xff0c;用于监测过程的稳定性和一致性。它们的主要区别在于如何计算和呈现数据的变化以及所关注的问题类型。 X-Bar-S图&#xff08;平均值与标准偏差图&#xff09;&#xff1a; X-Bar代表样本均值&#xff0c;S代表…

HttpServletRequest对象与RequestDispatcher对象

一、HttpServletRequest对象 1.介绍 在Servlet API中&#xff0c;定义了一个HttpServletRequest接口&#xff0c;它继承自ServletRequest接口&#xff0c;专门用来封装HTTP请求消息。由于HTTP请求消息分为请求行、请求消息头和请求消息体三部分&#xff0c;因此&#xff0c;在…

ODrive移植keil(五)—— 开环控制和电流变换

目录 一、开环控制1.1、控制原理1.2、硬件接线1.3、代码说明1.4、程序演示1.5、程序架构的体现 二、电流变换2.1、理论说明2.2、代码说明 ODrive、VESC和SimpleFOC 教程链接汇总&#xff1a;请点击 一、开环控制 在SimpleFOC系列中有开环控制的教程&#xff0c;SimpleFOC移植S…

《Java极简设计模式》第08章:外观模式(Facade)

作者&#xff1a;冰河 星球&#xff1a;http://m6z.cn/6aeFbs 博客&#xff1a;https://binghe.gitcode.host 文章汇总&#xff1a;https://binghe.gitcode.host/md/all/all.html 源码地址&#xff1a;https://github.com/binghe001/java-simple-design-patterns/tree/master/j…

SpringBoot篇之集成Mybatis-plus

目录 前言一、Mybatis-plus介绍1.1 官网 二、代码生成器总结 前言 大家好&#xff0c;我是AK&#xff0c;整理的SpringBoot集成Mybatis-plus以及代码生成器的使用&#xff0c;时间原因简单的整理下&#xff0c;有问题的可以评论区见或私信我。 一、Mybatis-plus介绍 1.1 官网…

完整教程:Java+Vue+Websocket实现OSS文件上传进度条功能

引言 文件上传是Web应用开发中常见的需求之一&#xff0c;而实时显示文件上传的进度条可以提升用户体验。本教程将介绍如何使用Java后端和Vue前端实现文件上传进度条功能&#xff0c;借助阿里云的OSS服务进行文件上传。 技术栈 后端&#xff1a;Java、Spring Boot 、WebSock…

23种经典设计模式:单例模式篇(C++)

前言&#xff1a; 博主将从此篇单例模式开始逐一分享23种经典设计模式&#xff0c;并结合C为大家展示实际应用。内容将持续更新&#xff0c;希望大家持续关注与支持。 什么是单例模式&#xff1f; 单例模式是设计模式的一种&#xff08;属于创建型模式 (Creational Pa…

PHP 员工工资管理系统mysql数据库web结构apache计算机软件工程网页wamp

一、源码特点 PHP 员工工资管理系统是一套完善的web设计系统&#xff0c;对理解php编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。 php员工工资管理系统 代码 https://download.csdn.net/download/qq_41221322/884215…

python+opencv+深度学习实现二维码识别 计算机竞赛

0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; pythonopencv深度学习实现二维码识别 &#x1f947;学长这里给一个题目综合评分(每项满分5分) 难度系数&#xff1a;3分工作量&#xff1a;3分创新点&#xff1a;3分 该项目较为新颖&…

Python高效实现网站数据挖掘

在当今互联网时代&#xff0c;SEO对于网站的成功至关重要。而Python爬虫作为一种强大的工具&#xff0c;为网站SEO带来了革命性的改变。通过利用Python爬虫&#xff0c;我们可以高效地实现网站数据挖掘和关键词分析&#xff0c;从而优化网站的SEO策略。本文将为您详细介绍如何利…