基于ROPNet项目训练modelnet40数据集进行3d点云的配置

项目地址: https://github.com/zhulf0804/ROPNet 在 MVP Registration Challenge (ICCV Workshop 2021)(ICCV Workshop 2021)中获得了第二名。项目可以在win10环境下运行。
论文地址: https://arxiv.org/abs/2107.02583

网络简介: 一种新的深度学习模型,该模型利用具有区别特征的代表性重叠点进行配准,将部分到部分配准转换为部分完全配准。基于pointnet输出的特征设计了一个上下文引导模块,使用一个编码器来提取全局特征来预测点重叠得分。为了更好地找到有代表性的重叠点,使用提取的全局特征进行粗对齐。然后,引入一种变压器来丰富点特征,并基于点重叠得分和特征匹配去除非代表性点。在部分到完全的模式下建立相似度矩阵,最后采用加权支持向量差来估计变换矩阵。
在这里插入图片描述
实施效果: 从数据上看ROPNet与RPMNet与保持了断崖式的领先地位
在这里插入图片描述

1、运行环境安装

1.1 项目下载

打开https://github.com/zhulf0804/ROPNet,点Download ZIP然后将代码解压到指定目录下即可。
在这里插入图片描述

1.2 依赖项安装

在装有pytorch的环境终端,进入ROPNet-master/src目录,执行以下安装命令。如果已经安装了torch 环境和open3d包,则不用再进行安装了

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118pip install open3d

1.3 模型与数据下载

modelnet40数据集 here [435M]
数据集下载后存储为以下路径即可。
在这里插入图片描述

官网预训练模型,无。
第三方预训练模型:使用ROPNet项目在modelnet40数据集上训练的模型

2、关键代码

2.1 dataloader

作者所提供的dataloader只能加载https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip 数据集,其所返回的tgt_cloud, src_cloud实质上是基于一个点云采样而来的。 其中的self.label2cat, self.cat2label, self.symmetric_labels等对象代码实际上是没有任何作用的。

import copy
import h5py
import math
import numpy as np
import os
import torchfrom torch.utils.data import Dataset
import sysBASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOR_DIR = os.path.dirname(BASE_DIR)
sys.path.append(ROOR_DIR)
from utils import  random_select_points, shift_point_cloud, jitter_point_cloud, \generate_random_rotation_matrix, generate_random_tranlation_vector, \transform, random_crop, shuffle_pc, random_scale_point_cloud, flip_pchalf1 = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl','car', 'chair', 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser','flower_pot', 'glass_box', 'guitar', 'keyboard', 'lamp']
half1_symmetric = ['bottle', 'bowl', 'cone', 'cup', 'flower_pot', 'lamp']half2 = ['laptop', 'mantel', 'monitor', 'night_stand', 'person', 'piano','plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 'stool','table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox']
half2_symmetric = ['tent', 'vase']class ModelNet40(Dataset):def __init__(self, root, split, npts, p_keep, noise, unseen, ao=False,normal=False):super(ModelNet40, self).__init__()self.single = False # for specific-class visualizationassert split in ['train', 'val', 'test']self.split = splitself.npts = nptsself.p_keep = p_keepself.noise = noiseself.unseen = unseenself.ao = ao # Asymmetric Objectsself.normal = normalself.half = half1 if split in 'train' else half2self.symmetric = half1_symmetric + half2_symmetricself.label2cat, self.cat2label = self.label2category(os.path.join(root, 'shape_names.txt'))self.half_labels = [self.cat2label[cat] for cat in self.half]self.symmetric_labels = [self.cat2label[cat] for cat in self.symmetric]files = [os.path.join(root, 'ply_data_train{}.h5'.format(i))for i in range(5)]if split == 'test':files = [os.path.join(root, 'ply_data_test{}.h5'.format(i))for i in range(2)]self.data, self.labels = self.decode_h5(files)print(f'split: {self.split}, unique_ids: {len(np.unique(self.labels))}')if self.split == 'train':self.Rs = [generate_random_rotation_matrix() for _ in range(len(self.data))]self.ts = [generate_random_tranlation_vector() for _ in range(len(self.data))]def label2category(self, file):with open(file, 'r') as f:label2cat = [category.strip() for category in f.readlines()]cat2label = {label2cat[i]: i for i in range(len(label2cat))}return label2cat, cat2labeldef decode_h5(self, files):points, normal, label = [], [], []for file in files:f = h5py.File(file, 'r')cur_points = f['data'][:].astype(np.float32)cur_normal = f['normal'][:].astype(np.float32)cur_label = f['label'][:].flatten().astype(np.int32)if self.unseen:idx = np.isin(cur_label, self.half_labels)cur_points = cur_points[idx]cur_normal = cur_normal[idx]cur_label = cur_label[idx]if self.ao and self.split in ['val', 'test']:idx = ~np.isin(cur_label, self.symmetric_labels)cur_points = cur_points[idx]cur_normal = cur_normal[idx]cur_label = cur_label[idx]if self.single:idx = np.isin(cur_label, [8])cur_points = cur_points[idx]cur_normal = cur_normal[idx]cur_label = cur_label[idx]points.append(cur_points)normal.append(cur_normal)label.append(cur_label)points = np.concatenate(points, axis=0)normal = np.concatenate(normal, axis=0)data = np.concatenate([points, normal], axis=-1).astype(np.float32)label = np.concatenate(label, axis=0)return data, labeldef compose(self, item, p_keep):tgt_cloud = self.data[item, ...]if self.split != 'train':np.random.seed(item)R, t = generate_random_rotation_matrix(), generate_random_tranlation_vector()else:tgt_cloud = flip_pc(tgt_cloud)R, t = generate_random_rotation_matrix(), generate_random_tranlation_vector()src_cloud = random_crop(copy.deepcopy(tgt_cloud), p_keep=p_keep[0])src_size = math.ceil(self.npts * p_keep[0])tgt_size = self.nptsif len(p_keep) > 1:tgt_cloud = random_crop(copy.deepcopy(tgt_cloud),p_keep=p_keep[1])tgt_size = math.ceil(self.npts * p_keep[1])src_cloud_points = transform(src_cloud[:, :3], R, t)src_cloud_normal = transform(src_cloud[:, 3:], R)src_cloud = np.concatenate([src_cloud_points, src_cloud_normal],axis=-1)src_cloud = random_select_points(src_cloud, m=src_size)tgt_cloud = random_select_points(tgt_cloud, m=tgt_size)if self.split == 'train' or self.noise:src_cloud[:, :3] = jitter_point_cloud(src_cloud[:, :3])tgt_cloud[:, :3] = jitter_point_cloud(tgt_cloud[:, :3])tgt_cloud, src_cloud = shuffle_pc(tgt_cloud), shuffle_pc(src_cloud)return src_cloud, tgt_cloud, R, tdef __getitem__(self, item):src_cloud, tgt_cloud, R, t = self.compose(item=item,p_keep=self.p_keep)if not self.normal:tgt_cloud, src_cloud = tgt_cloud[:, :3], src_cloud[:, :3]return tgt_cloud, src_cloud, R, tdef __len__(self):return len(self.data)

2.2 模型设计

模型设计如下:
在这里插入图片描述

2.3 loss设计

其主要包含Init_loss、Refine_loss和Ol_loss。
其中Init_loss是用于计算 预测点 云 0 预测点云_0 预测点0与目标点云的mse或mae loss,
Refine_loss用于计算 预测点 云 [ 1 : ] 预测点云_{[1:]} 预测点[1:]与目标点云的加权mae loss
Ol_loss用于计算两个输入点云输出的重叠分数,使两个点云对应点的重叠分数是一样的。
在这里插入图片描述

具体实现代码如上:


import math
import torch
import torch.nn as nn
from utils import square_distsdef Init_loss(gt_transformed_src, pred_transformed_src, loss_type='mae'):losses = {}num_iter = 1if loss_type == 'mse':criterion = nn.MSELoss(reduction='mean')for i in range(num_iter):losses['mse_{}'.format(i)] = criterion(pred_transformed_src[i],gt_transformed_src)elif loss_type == 'mae':criterion = nn.L1Loss(reduction='mean')for i in range(num_iter):losses['mae_{}'.format(i)] = criterion(pred_transformed_src[i],gt_transformed_src)else:raise NotImplementedErrortotal_losses = []for k in losses:total_losses.append(losses[k])losses = torch.sum(torch.stack(total_losses), dim=0)return lossesdef Refine_loss(gt_transformed_src, pred_transformed_src, weights=None, loss_type='mae'):losses = {}num_iter = len(pred_transformed_src)for i in range(num_iter):if weights is None:losses['mae_{}'.format(i)] = torch.mean(torch.abs(pred_transformed_src[i] - gt_transformed_src))else:losses['mae_{}'.format(i)] = torch.mean(torch.sum(weights * torch.mean(torch.abs(pred_transformed_src[i] -gt_transformed_src), dim=-1)/ (torch.sum(weights, dim=-1, keepdim=True) + 1e-8), dim=-1))total_losses = []for k in losses:total_losses.append(losses[k])losses = torch.sum(torch.stack(total_losses), dim=0)return lossesdef Ol_loss(x_ol, y_ol, dists):CELoss = nn.CrossEntropyLoss()x_ol_gt = (torch.min(dists, dim=-1)[0] < 0.05 * 0.05).long() # (B, N)y_ol_gt = (torch.min(dists, dim=1)[0] < 0.05 * 0.05).long() # (B, M)x_ol_loss = CELoss(x_ol, x_ol_gt)y_ol_loss = CELoss(y_ol, y_ol_gt)ol_loss = (x_ol_loss + y_ol_loss) / 2return ol_lossdef cal_loss(gt_transformed_src, pred_transformed_src, dists, x_ol, y_ol):losses = {}losses['init'] = Init_loss(gt_transformed_src,pred_transformed_src[0:1])if x_ol is not None:losses['ol'] = Ol_loss(x_ol, y_ol, dists)losses['refine'] = Refine_loss(gt_transformed_src,pred_transformed_src[1:],weights=None)alpha, beta, gamma = 1, 0.1, 1if x_ol is not None:losses['total'] = losses['init'] + beta * losses['ol'] + gamma * losses['refine']else:losses['total'] = losses['init'] + losses['refine']return losses

3、训练与预测

先进入src目录,并将modelnet40_ply_hdf5_2048.zip解压在src目录下
在这里插入图片描述

3.1 训练

训练命令及训练输出如下所示

python train.py --root modelnet40_ply_hdf5_2048/ --noise --unseen

python请添加图片描述
在训练过程中会在work_dirs\models\checkpoints目录下生成两个模型文件
在这里插入图片描述

3.2 验证

训练命令及训练输出如下所示

python eval.py --root modelnet40_ply_hdf5_2048/  --unseen --noise  --cuda --checkpoint work_dirs/models/checkpoints/min_rot_error.pth

请添加图片描述

3.3 测试

测试训练数据的命令如下

python vis.py --root modelnet40_ply_hdf5_2048/  --unseen --noise  --checkpoint work_dirs/models/checkpoints/min_rot_error.pth

具体配准效果如下所示,其中绿色点云为输入点云,红色点云为参考点云,蓝色点云为配准后的点云。可以看到蓝色点云基本与红色点云重合,可以确定其配准效果十分完好。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.4 处理自己的数据集

基于该项目训练并处理自己数据的教程后续会给出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/210715.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于H5“汉函谷关起点新安县旅游信息系统”设计与实现

目 录 摘 要 1 ABSTRACT 2 第1章 绪论 3 1.1 系统开发背景及意义 3 1.2 系统开发的目标 3 第2章 主要开发技术介绍 5 2.1 H5技术介绍 5 2.2 Visual Studio 技术介绍 5 2.3 SQL Server数据库技术介绍 6 第3章 系统分析与设计 7 3.1 可行性分析 7 3.1.1 技术可行性 7 3.1.2 操作…

HTTP请求

前言 HTTP是应用层的一个协议。实际我们访问一个网页&#xff0c;都会像该网页的服务器发送HTTP请求&#xff0c;服务器解析HTTP请求&#xff0c;返回HTTP响应。如此就是我们获取资源或者上传资源的原理 HTTP请求报头格式 图片来自网络 HTTP请求报头总体有四部分&#xff1a;…

pycharm中绘制一个3D曲线

import numpy as np import matplotlib.pyplot as plt # 中文的设置 import matplotlib as mp1 from mpl_toolkits.mplot3d import Axes3D mp1.rcParams["font.sans-serif"] ["kaiti"] mp1.rcParams["axes.unicode_minus"] False # 数据创建 X…

忽略python运行出现的大量警告

添加以下代码即可 import warnings warnings.filterwarnings(ignore)

制作红木家具3d模型

在线工具推荐&#xff1a; 3D数字孪生场景编辑器 - GLTF/GLB材质纹理编辑器 - 3D模型在线转换 - Three.js AI自动纹理开发包 - YOLO 虚幻合成数据生成器 - 三维模型预览图生成器 - 3D模型语义搜索引擎 在家居行业中&#xff0c;设计师可以通过在3D建模中添加实际的家具、…

【python】包(package)与模块(module)、import、__name__与__main__

导入模块一般写在程序最前面&#xff0c;且顺序为&#xff1a;内置模块、第三方模块、自定义模块 一、模块&#xff08;module&#xff09;与包&#xff08;package&#xff09; 模块&#xff08;module&#xff09;可以理解为是一个.py文件&#xff0c;import 模块 相当于执行…

应用于智慧园区的AI边缘计算盒子+AI算法软硬一体化方案

工业园区多为生产型和物流型企业&#xff0c;劳动人员密集&#xff0c;外来人口多&#xff0c;农民工多&#xff0c;人员流动大&#xff0c;车流量大&#xff0c;易引发车祸、破坏公共设施和绿化工程等案件; 英码智慧园区方案&#xff0c;可实现100%管理所有出入人员&#xff1…

ViVo小游戏对接sdk

1.安装环境&#xff1a; 电脑环境&#xff1a;adb环境和oppo一样&#xff0c;npm环境和oppo一样 升级npm&#xff1a; npm install -g npm 清除npm缓存&#xff1a;npm cache clean -f 安装vivo初始化小游戏的工具&#xff1a; npm install -g vivo-minigame/cli 解决办法&…

【Linux】:信号(三)捕捉

信号捕捉 一.sigaction1.基本使用2.sa_mask字段 二.可重入函数三.volatile四.SIGCHLD信号 承接上文 果信号的处理动作是用户自定义函数,在信号递达时就调用这个函数,这称为捕捉信号。由于信号处理函数的代码是在用户空间的,处理过程比较复杂,举例如下: 用户程序注册了SIGQUIT信…

Lambda表达式与方法引用

作者简介&#xff1a;大家好&#xff0c;我是smart哥&#xff0c;前中兴通讯、美团架构师&#xff0c;现某互联网公司CTO 联系qq&#xff1a;184480602&#xff0c;加我进群&#xff0c;大家一起学习&#xff0c;一起进步&#xff0c;一起对抗互联网寒冬 引子 先来看一个案例 …

Vue3获取阴历/农历日期

安装插件 pnpm add chinese-lunar-calendar引入阳历/阴历切换函数 import {getLunar} from chinese-lunar-calendarexport function lunarDate(pDate){const year pDate.getFullYear()const month pDate.getMonth() 1const day pDate.getDate()const result getLunar(yea…

房屋租赁出售经纪人入驻小程序平台

一款专为房屋中介开发的小程序平台&#xff0c;支持独立部署&#xff0c;源码交付&#xff0c;数据安全无忧。 核心功能&#xff1a;房屋出租、经纪人独立后台、分佣后台、楼盘展示、房型展示、在线咨询、地址位置配套设施展示。 程序已被很多房屋交易中介体验使用过&#x…

leetcode 287. 寻找重复数

2023.11.29 本题比较朴素得一个思路是利用map集合的key存储nums中的值&#xff0c;value存储对应值出现的次数&#xff0c;然后再遍历这个map集合的value&#xff0c;如果这个value大于1&#xff0c;说明对应的key出现的次数超过了1次&#xff0c;并且题目说这个key唯一&#x…

frp内网穿透

frp内网穿透 内网穿透是一种网络技术&#xff0c;允许您从互联网访问内部网络中的设备或服务&#xff0c;即使这些设备或服务位于防火墙或路由器等网络边界设备之后&#xff0c;也可以实现远程访问。 0x01 功能介绍 frp是一种代理工具&#xff0c;允许用户通过互联网轻松访问其…

mybatis源码(五)springboot pagehelper实现查询分页

1、背景 springboot的pagehelper插件能够实现对mybatis查询的分页管理&#xff0c;而且在使用时只需要提前声明即可&#xff0c;不需要修改已有的查询语句。使用如下&#xff1a; 之前对这个功能一直很感兴趣&#xff0c;但是一直没完整看过&#xff0c;今天准备详细梳理下。按…

Docker下安装Tomcat

目录 Tomcat简介 Tomcat安装 免修改版Tomcat安装 Tomcat简介 Tomcat是Apache软件基金会Jakarta 项目中的一个核心项目&#xff0c;因为Tomcat 技术先进、性能稳定&#xff0c;而且免费&#xff0c;因而深受Java 爱好者的喜爱并得到了部分软件开发商的认可&#xff0c;成为比…

docker 手工redis7.x cluster

IP端口192.168.0.816379/6380192.168.0.826379/6380192.168.0.1146379/6380 mdkir /data/{6379,6380}cat <<END> /data/6379.conf # 端口号 port 6379# 设置客户端连接后进行任何其他指定前需要使用的密码 #requirepass 123456 ## 当master服务设置了密码保护时(用re…

CKafka 一站式搭建数据流转链路,助力长城车联网平台降低运维成本

关于长城智能新能源 长城汽车是一家全球化智能科技公司&#xff0c;业务包括汽车及零部件设计、研发、生产、销售和服务&#xff0c;旗下拥有魏牌、哈弗、坦克、欧拉及长城皮卡。2022年&#xff0c;长城汽车全年销售1,067,523辆&#xff0c;连续7年销量超100万辆。长城汽车面向…

同旺科技 USB TO SPI / I2C --- 调试W5500

所需设备&#xff1a; 内附链接 1、USB转SPI_I2C适配器(专业版); 首先&#xff0c;连接W5500模块与同旺科技USB TO SPI / I2C适配器&#xff0c;如下图&#xff1a; 读取重试时间值寄存器&#xff0c;默认值0x07D0 输出结果与默认值一致&#xff0c;芯片基本功能已经调通&am…

Java-宋红康-(P133-P134)-多线程创建方式(Thread and Runnable)

b站视频 133-多线程-线程创建方式1&#xff1a;继承Thread类_哔哩哔哩_bilibili 目录 3.1 继承Thread 3.1.1 继承Thread类方式 3.1.2 线程的执行流程 3.1.3 线程内存图 3.1.4 run()方法和start()方法 3.1.5 线程名字的设置和获取 3.1.6 获取运行main方法线程的名字 3.…