关于PointHeadBox类的理解

forward函数

 def forward(self, batch_dict):"""Args:batch_dict:batch_size:point_features: (N1 + N2 + N3 + ..., C) or (B, N, C)point_features_before_fusion: (N1 + N2 + N3 + ..., C)point_coords: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]point_labels (optional): (N1 + N2 + N3 + ...)gt_boxes (optional): (B, M, 8)Returns:batch_dict:point_cls_scores: (N1 + N2 + N3 + ..., 1)point_part_offset: (N1 + N2 + N3 + ..., 3)"""if self.model_cfg.get('USE_POINT_FEATURES_BEFORE_FUSION', False):point_features = batch_dict['point_features_before_fusion']else:point_features = batch_dict['point_features']#通过全连接层128-->256-->256-->3生成类别信息point_cls_preds = self.cls_layers(point_features)  # (total_points, num_class)#通过全连接层128-->256-->256-->8生成回归框信息point_box_preds = self.box_layers(point_features)  # (total_points, box_code_size)#在预测的3个类别中求出最大可能的类别作为标签信息,并经过sigmod函数point_cls_preds_max, _ = point_cls_preds.max(dim=-1)batch_dict['point_cls_scores'] = torch.sigmoid(point_cls_preds_max)ret_dict = {'point_cls_preds': point_cls_preds,'point_box_preds': point_box_preds}if self.training:#主要是生成每个点对应的真实的标签信息#以及真实框G相对于预测G_hat的框的参数偏移,每个点对应是1*8维向量targets_dict = self.assign_targets(batch_dict)ret_dict['point_cls_labels'] = targets_dict['point_cls_labels']ret_dict['point_box_labels'] = targets_dict['point_box_labels']if not self.training or self.predict_boxes_when_training:#求出每个点对应的预测的标签信息#以及P相对于预测的框G_hat的参数偏移,每个点对应是1*8维向量point_cls_preds, point_box_preds = self.generate_predicted_boxes(points=batch_dict['point_coords'][:, 1:4],point_cls_preds=point_cls_preds, point_box_preds=point_box_preds)batch_dict['batch_cls_preds'] = point_cls_predsbatch_dict['batch_box_preds'] = point_box_predsbatch_dict['batch_index'] = batch_dict['point_coords'][:, 0]batch_dict['cls_preds_normalized'] = Falseself.forward_ret_dict = ret_dictreturn batch_dict

注意:对于每一个point,point_box_preds是1×8维向量,8维分别表示[xt, yt, zt, dxt, dyt, dzt, cost, sint],[xt, yt, zt]为中心点偏移量,[dxt, dyt, dzt]为长宽高偏移量,[cost, sint]为角度偏移量。

在这里插入图片描述

forward函数得到了每个前景点对应的真实标签值以及标注框信息;(self.assign_targets--------->self.assign_stack_targets-----> self.box_coder.encode_torch调用了PointResidualCoder类中的encode_torch函数)

得到了从G_hat到G的1*8维参数


每个前景点对应的预测标签值以及预测框信息;(self.generate_predicted_boxes--------->self.box_coder.decode_torch调用了PointResidualCoder类中的decode_torch函数)

得到了从P到G_hat的1*8维参数

得到这两组参数后用于后续计算损失时计算的box损失,采用的是L1回归损失

 point_loss_box_src = F.smooth_l1_loss(point_box_preds[None, ...], point_box_labels[None, ...], weights=reg_weights[None, ...])

边框回归(Bounding Box Regression)详解

PointResidualCoder

class PointResidualCoder(object):def __init__(self, code_size=8, use_mean_size=True, **kwargs):super().__init__()self.code_size = code_sizeself.use_mean_size = use_mean_sizeif self.use_mean_size:self.mean_size = torch.from_numpy(np.array(kwargs['mean_size'])).cuda().float()assert self.mean_size.min() > 0def encode_torch(self, gt_boxes, points, gt_classes=None):"""Args:gt_boxes: (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...]points: (N, 3) [x, y, z]gt_classes: (N) [1, num_classes]Returns:box_coding: (N, 8 + C)"""gt_boxes[:, 3:6] = torch.clamp_min(gt_boxes[:, 3:6], min=1e-5)xg, yg, zg, dxg, dyg, dzg, rg, *cgs = torch.split(gt_boxes, 1, dim=-1)xa, ya, za = torch.split(points, 1, dim=-1)if self.use_mean_size:assert gt_classes.max() <= self.mean_size.shape[0]point_anchor_size = self.mean_size[gt_classes - 1]dxa, dya, dza = torch.split(point_anchor_size, 1, dim=-1)diagonal = torch.sqrt(dxa ** 2 + dya ** 2)xt = (xg - xa) / diagonalyt = (yg - ya) / diagonalzt = (zg - za) / dzadxt = torch.log(dxg / dxa)dyt = torch.log(dyg / dya)dzt = torch.log(dzg / dza)else:xt = (xg - xa)yt = (yg - ya)zt = (zg - za)dxt = torch.log(dxg)dyt = torch.log(dyg)dzt = torch.log(dzg)cts = [g for g in cgs]return torch.cat([xt, yt, zt, dxt, dyt, dzt, torch.cos(rg), torch.sin(rg), *cts], dim=-1)def decode_torch(self, box_encodings, points, pred_classes=None):"""Args:box_encodings: (N, 8 + C) [x, y, z, dx, dy, dz, cos, sin, ...]points: [x, y, z]pred_classes: (N) [1, num_classes]Returns:"""xt, yt, zt, dxt, dyt, dzt, cost, sint, *cts = torch.split(box_encodings, 1, dim=-1)xa, ya, za = torch.split(points, 1, dim=-1)if self.use_mean_size:assert pred_classes.max() <= self.mean_size.shape[0]point_anchor_size = self.mean_size[pred_classes - 1]dxa, dya, dza = torch.split(point_anchor_size, 1, dim=-1)diagonal = torch.sqrt(dxa ** 2 + dya ** 2)xg = xt * diagonal + xayg = yt * diagonal + yazg = zt * dza + zadxg = torch.exp(dxt) * dxadyg = torch.exp(dyt) * dyadzg = torch.exp(dzt) * dzaelse:xg = xt + xayg = yt + yazg = zt + zadxg, dyg, dzg = torch.split(torch.exp(box_encodings[..., 3:6]), 1, dim=-1)rg = torch.atan2(sint, cost)cgs = [t for t in cts]return torch.cat([xg, yg, zg, dxg, dyg, dzg, rg, *cgs], dim=-1)

decode_torch:如何通过point_box_preds的8维向量得到proposal的7维坐标?将每一个point原始xyz坐标加上坐标偏移量[xt, yt, zt]即可得到proposal中心点坐标,利用作者预设的point_anchor_size乘上长宽高偏移量[dxt, dyt, dzt]得到proposal长宽高,利用atan2函数计算角度heading。

思考:

在这里插入图片描述
代码这里的Pw和Py参数直接是用的diagonal

diagonal = torch.sqrt(dxa ** 2 + dya ** 2)

个人的理解是觉得这样可以同时优化生成的anchor大小并且可以调节中心坐标的偏移。

assign_targets

 def assign_targets(self, input_dict):"""Args:input_dict:point_features: (N1 + N2 + N3 + ..., C)batch_size:point_coords: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]gt_boxes (optional): (B, M, 8)Returns:point_cls_labels: (N1 + N2 + N3 + ...), long type, 0:background, -1:ignoredpoint_part_labels: (N1 + N2 + N3 + ..., 3)"""point_coords = input_dict['point_coords']gt_boxes = input_dict['gt_boxes']assert gt_boxes.shape.__len__() == 3, 'gt_boxes.shape=%s' % str(gt_boxes.shape)assert point_coords.shape.__len__() in [2], 'points.shape=%s' % str(point_coords.shape)batch_size = gt_boxes.shape[0]extend_gt_boxes = box_utils.enlarge_box3d(gt_boxes.view(-1, gt_boxes.shape[-1]), extra_width=self.model_cfg.TARGET_CONFIG.GT_EXTRA_WIDTH).view(batch_size, -1, gt_boxes.shape[-1])targets_dict = self.assign_stack_targets(points=point_coords, gt_boxes=gt_boxes, extend_gt_boxes=extend_gt_boxes,set_ignore_flag=True, use_ball_constraint=False,ret_part_labels=False, ret_box_labels=True)return targets_dict

extend_gt_boxes 主要是将groud truth boxex在长、宽、高方向上扩展

在这里插入图片描述

在这里插入图片描述

assign_stack_targets

#此函数传入的都是对应点的真实预测值和真实标注框def assign_stack_targets(self, points, gt_boxes, extend_gt_boxes=None,ret_box_labels=False, ret_part_labels=False,set_ignore_flag=True, use_ball_constraint=False, central_radius=2.0):"""Args:points: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]gt_boxes: (B, M, 8)extend_gt_boxes: [B, M, 8]ret_box_labels:ret_part_labels:set_ignore_flag:use_ball_constraint:central_radius:Returns:point_cls_labels: (N1 + N2 + N3 + ...), long type, 0:background, -1:ignoredpoint_box_labels: (N1 + N2 + N3 + ..., code_size)"""assert len(points.shape) == 2 and points.shape[1] == 4, 'points.shape=%s' % str(points.shape)assert len(gt_boxes.shape) == 3 and gt_boxes.shape[2] == 8, 'gt_boxes.shape=%s' % str(gt_boxes.shape)assert extend_gt_boxes is None or len(extend_gt_boxes.shape) == 3 and extend_gt_boxes.shape[2] == 8, \'extend_gt_boxes.shape=%s' % str(extend_gt_boxes.shape)assert set_ignore_flag != use_ball_constraint, 'Choose one only!'#将数据分批次处理batch_size = gt_boxes.shape[0]bs_idx = points[:, 0]point_cls_labels = points.new_zeros(points.shape[0]).long()point_box_labels = gt_boxes.new_zeros((points.shape[0], 8)) if ret_box_labels else Nonepoint_part_labels = gt_boxes.new_zeros((points.shape[0], 3)) if ret_part_labels else None#将数据分批次处理for k in range(batch_size):bs_mask = (bs_idx == k)#这里以*_single应该是中间缓存变量,作为每一批次处理的变量存储数据#points_single取出对应批次的点云的坐标信息points_single = points[bs_mask][:, 1:4]point_cls_labels_single = point_cls_labels.new_zeros(bs_mask.sum())#将每一个点云数据分配到真实标注框上box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(         points_single.unsqueeze(dim=0), gt_boxes[k:k + 1, :, 0:7].contiguous()).long().squeeze(dim=0)#box_idxs_of_pts是每个点对应分配的标注框索引值,没有匹配的赋值为-1box_fg_flag = (box_idxs_of_pts >= 0) #根据之前扩展的3D框计算被忽略的点if set_ignore_flag:#将每一个点云数据分配到扩展后的标注框上extend_box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(points_single.unsqueeze(dim=0), extend_gt_boxes[k:k+1, :, 0:7].contiguous()).long().squeeze(dim=0)fg_flag = box_fg_flag#异或运算,未扩展前没有包括,扩展后包含到的框,即被忽略的框ignore_flag = fg_flag ^ (extend_box_idxs_of_pts >= 0)point_cls_labels_single[ignore_flag] = -1elif use_ball_constraint:box_centers = gt_boxes[k][box_idxs_of_pts][:, 0:3].clone()box_centers[:, 2] += gt_boxes[k][box_idxs_of_pts][:, 5] / 2ball_flag = ((box_centers - points_single).norm(dim=1) < central_radius)fg_flag = box_fg_flag & ball_flagelse:raise NotImplementedError#记录前景点信息,可以理解为论文中所说的前景点分割gt_box_of_fg_points = gt_boxes[k][box_idxs_of_pts[fg_flag]]#最后一维代表的是标注框对应的类别信息,对应前景点的类别信息point_cls_labels_single[fg_flag] = 1 if self.num_class == 1 else gt_box_of_fg_points[:, -1].long()#记录一次批处理流程中所有点的类别信息point_cls_labels[bs_mask] = point_cls_labels_singleif ret_box_labels and gt_box_of_fg_points.shape[0] > 0:point_box_labels_single = point_box_labels.new_zeros((bs_mask.sum(), 8))#记录每一个前景点从G_hat到G的参数偏移,每个前景点最后输出是1*8维向量fg_point_box_labels = self.box_coder.encode_torch(gt_boxes=gt_box_of_fg_points[:, :-1], points=points_single[fg_flag],gt_classes=gt_box_of_fg_points[:, -1].long())point_box_labels_single[fg_flag] = fg_point_box_labelspoint_box_labels[bs_mask] = point_box_labels_singleif ret_part_labels:point_part_labels_single = point_part_labels.new_zeros((bs_mask.sum(), 3))transformed_points = points_single[fg_flag] - gt_box_of_fg_points[:, 0:3]transformed_points = common_utils.rotate_points_along_z(transformed_points.view(-1, 1, 3), -gt_box_of_fg_points[:, 6]).view(-1, 3)offset = torch.tensor([0.5, 0.5, 0.5]).view(1, 3).type_as(transformed_points)point_part_labels_single[fg_flag] = (transformed_points / gt_box_of_fg_points[:, 3:6]) + offsetpoint_part_labels[bs_mask] = point_part_labels_singletargets_dict = {'point_cls_labels': point_cls_labels,'point_box_labels': point_box_labels,'point_part_labels': point_part_labels}return targets_dict

经典框架解读 | 论文+代码 | 3D Detection | OpenPCDet | PointRCNN

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/148489.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【算法】排序——归并排序和计数排序

主页点击直达&#xff1a;个人主页 我的小仓库&#xff1a;代码仓库 C语言偷着笑&#xff1a;C语言专栏 数据结构挨打小记&#xff1a;初阶数据结构专栏 Linux被操作记&#xff1a;Linux专栏 LeetCode刷题掉发记&#xff1a;LeetCode刷题 算法头疼记&#xff1a;算法专栏…

在Ubuntu 20.04搭建最小实验环境

sudo apt-get -y install --no-install-recommends wget gnupg ca-certificates安装导入GPG公钥所需的依赖包。 sudo wget -O - https://openresty.org/package/pubkey.gpg | sudo apt-key add -导入GPG密钥。 sudo apt-get -y install --no-install-recommends software-p…

【AI视野·今日NLP 自然语言处理论文速览 第四十七期】Wed, 4 Oct 2023

AI视野今日CS.NLP 自然语言处理论文速览 Wed, 4 Oct 2023 Totally 73 papers &#x1f449;上期速览✈更多精彩请移步主页 Daily Computation and Language Papers Contrastive Post-training Large Language Models on Data Curriculum Authors Canwen Xu, Corby Rosset, Luc…

【VUE·疑难问题】定义 table 中每行的高度(使用 element-UI)

一、如何定义 table 中每一行的 height &#xff1f; 1.table例子 <!-- 二、table --><div style"overflow: hidden;display: block;height: 68vh;width: 100%;"><el-table stripe show-header style"width: 100%" :data"tableData&q…

密码技术 (6) - 证书

一. 前言 前面介绍的公钥密码和数字签名&#xff0c;都无法解决一个问题&#xff0c;那就是判断自己获取的公钥是否期望的&#xff0c;不能确定公钥是否被中间攻击人掉包。所以&#xff0c;证书的作用是用来证明公钥是否合法的。本文介绍的证书就是解决证书的可靠性的技术。 二…

macbook电脑磁盘满了怎么删东西?

macbook是苹果公司的一款高性能笔记本电脑&#xff0c;受到很多用户的喜爱。但是&#xff0c;如果macbook的磁盘空间不足&#xff0c;可能会导致一些问题&#xff0c;比如无法开机、运行缓慢、应用崩溃等。那么&#xff0c;macbook磁盘满了无法开机怎么办&#xff0c;macbook磁…

大数据-玩转数据-Flink SQL编程实战 (热门商品TOP N)

一、需求描述 每隔30min 统计最近 1hour的热门商品 top3, 并把统计的结果写入到mysql中。 二、需求分析 1.统计每个商品的点击量, 开窗2.分组窗口分组3.over窗口 三、需求实现 3.1、创建数据源示例 input/UserBehavior.csv 543462,1715,1464116,pv,1511658000 662867,22…

新版校园跑腿独立版小程序源码 多校版本,多模块,适合跑腿,外卖,表白,二手,快递等校园服务

最新校园跑腿小程序源码 多校版本&#xff0c;多模块&#xff0c;适合跑腿&#xff0c;外卖&#xff0c;表白&#xff0c;二手&#xff0c;快递等校园服务 此版本为独立版本&#xff0c;不需要** 直接放入就可以 需要自己准备好后台的服务器&#xff0c;已认证的小程序&#xf…

网络基础入门(网络基础概念详解)

本篇文章主要是对网络初学的概念进行解释&#xff0c;可以让你对网络有一个大概整体的认知。 文章目录 一、简单认识网络 1、1 什么是网络 1、2 网络分类 二、网络模型 2、1OSI七层模型 2、1、1 简单认识协议 2、1、2 OSI七层模型解释 2、2 TCP/IP五层(或四层)模型 三、网络传…

26-网络通信

网络通信 什么是网络编程&#xff1f; 可以让设备中的程序与网络上其他设备中的程序进行数据交互&#xff08;实现网络通信的&#xff09;。 java.net.包下提供了网络编程的解决方案&#xff01; 基本的通信架构有2种形式&#xff1a;CS架构&#xff08; Client客户端/Server服…

深度学习:基于长短时记忆网络LSTM实现情感分析

目录 1 LSTM网络介绍 1.1 LSTM概述 1.2 LSTM网络结构 1.3 LSTM门机制 1.4 双向LSTM 2 Pytorch LSTM输入输出 2.1 LSTM参数 2.2 LSTM输入 2.3 LSTM输出 2.4 隐藏层状态初始化 3 基于LSTM实现情感分析 3.1 情感分析介绍 3.2 数据集介绍 3.3 基于pytorch的代码实现 3…

Ventoy万能U盘安装系统,支持任何的操作系统安装

Ventoy万能U盘安装系统&#xff0c;支持任何的操作系统安装&#xff1a; Download . VentoyVentoy is an open source tool to create bootable USB drive for ISO files. With ventoy, you dont need to format the disk again and again, you just need to copy the iso fil…

LLMs 用强化学习进行微调 RLHF: Fine-tuning with reinforcement learning

让我们把一切都整合在一起&#xff0c;看看您将如何在强化学习过程中使用奖励模型来更新LLM的权重&#xff0c;并生成与人对齐的模型。请记住&#xff0c;您希望从已经在您感兴趣的任务上表现良好的模型开始。您将努力使指导发现您的LLM对齐。首先&#xff0c;您将从提示数据集…

数据结构 2.1 线性表的定义和基本操作

数据结构三要素——逻辑结构、数据的运算、存储结构&#xff08;物理结构&#xff09; 线性表的逻辑结构 线性表是具有相同数据类型的n&#xff08;n>0&#xff09;个数据元素的有限序列&#xff0c;其中n为表长&#xff0c;当n0时&#xff0c;线性表是一个空表。 每个数…

亲,您的假期余额已经严重不足了......

引言 大家好&#xff0c;我是亿元程序员&#xff0c;一位有着8年游戏行业经验的主程。 转眼八天长假已经接近尾声了&#xff0c;今天来总结一下大家的假期&#xff0c;聊一聊假期关于学习的看法&#xff0c;并预估一下大家节后大家上班时的样子。 1.放假前一天 即将迎来八天…

『力扣每日一题12』:只出现一次的数字

一、题目 给你一个 非空 整数数组 nums &#xff0c;除了某个元素只出现一次以外&#xff0c;其余每个元素均出现两次。找出那个只出现了一次的元素。 你必须设计并实现线性时间复杂度的算法来解决此问题&#xff0c;且该算法只使用常量额外空间。 示例 1 &#xff1a; 输入&…

QQ怎么上传大于1G的视频啊?视频压缩这样做

当我们想要在QQ上分享一段大容量的视频时&#xff0c;往往会因为超过1G的限制而感到无助。不过&#xff0c;不用担心&#xff0c;今天我们将为你介绍三种可以压缩视频大小的方法&#xff0c;一起来看看吧~ 一、嗨格式压缩大师 嗨格式压缩大师是一款专业的视频压缩软件&#xf…

【Excel】快速提取某个符号前面的数据内容

【问题描述】 在使用excel整理数据过程中&#xff0c;经常与需要调整数据后&#xff0c;进行使用。 例如凭证导出后&#xff0c;科目列是包含科目编码和科目名称的。 但由于要将数据复制到其他的导入模板上使用&#xff0c;对应的模板只需要科目编码&#xff0c;不需要科目名称…

Spring 原理

它是一个全面的、企业应用开发一站式的解决方案&#xff0c;贯穿表现层、业务层、持久层。但是 Spring仍然可以和其他的框架无缝整合。 1 Spring 特点 轻量级控制反转面向切面容器框架集合 2 Spring 核心组件 3 Spring 常用模块 4 Spring 主要包 5 Spring 常用注解 bean…

Springboot学习笔记——1

Springboot学习笔记——1 一、快速上手Springboot1.1、Springboot入门程序开发1.1.1、IDEA创建Springboot项目1.1.2、官网创建Springboot项目1.1.3、阿里云创建Springboot项目1.1.4、手工制作Springboot项目 1.2、隐藏文件或文件夹1.3、入门案例解析1.3.1、parent1.3.2、starte…