【三维深度补全模型】PENet

 【版权声明】
本文为博主原创文章,未经博主允许严禁转载,我们会定期进行侵权检索。   

参考书籍:《人工智能点云处理及深度学习算法》

 本文为专栏《Python三维点云实战宝典》系列文章,专栏介绍地址“【python三维深度学习】python三维点云从基础到深度学习_python3d点云从基础到深度学习-CSDN博客”。配套书籍《人工智能点云处理及深度学习算法》提供更加全面和系统的解析。

        PENet是一个由浙江大学和上海华为于ICRA 2021发布的深度补全模型(Sparse-Depth-Completion),即通过RGB图像和雷达稀疏点云来获取更加稠密的点云。论文题目和地址分别为《PENet: Towards Precise and Efficient Image Guided Depth Completion》和“https://arxiv.org/abs/2103.00783”。该模型采用了coarse-refine结构,即粗补全和精补全(精度微调)相结合,并且模型在粗补全阶段对不同尺度图像、稀疏点云和几何特征进行充分融合以提高模型深度补全精度。另一方面,模型对CSPN++网络卷积操作进行优化以提高模型运行速度。PENet提出时在KITTI深度补全数据集上取得了最好成绩,目前排名仍然靠前。下图是其在paperwithcode官网上的排名情况,地址为“https://paperswithcode.com/sota/depth-completion-on-kitti-depth-completion”。

图 PENet排名情况

1 PENet模型结构

        PENet模型总体结构如下图所示,采用了coarse-refine结构。其深度粗补全网络称为ENet,采用两条主干网络进行深度补全特征提取。两条主干网络均融合了雷达所采集的稀疏点云,区别在于第一条主干网络融合了RGB色彩信息,而第二条网络融合了第一条网络预测的深度结果。主干网络采用了类似UNet结构的编码-解码结构,实现对不同尺度特征进行融合。因此,ENet对特征类型和特征空间都进行了充分融合,以获取更加丰富的深度特征。

图 PENet模型结构

        PENet深度粗补全结果是两条主干分支网络预测结果的融合,即图中Fused Depth。由于点深度信息与其邻近点密切关联,作者采用DA CSPN++网络对粗补全结果进行微调,进一步提高模型预测精度。

2 输入数据

2.1 KITTI数据集下载

        PENet模型官方程序地址为“https://github.com/JUGGHM/PENet_ICRA2021”,本节将结合该程序进行详细介绍。程序中模型输入数据集为KITTI补全数据集,需要分别下载KITTI原始数据和补全数据。

        KITTI原始数据集如下图所示,包含City、Residential、Road、Campus、Person 和Calibration6个类别,下载地址为“https://www.cvlibs.net/datasets/kitti/raw_data.php?type=city”。如需进行完整训练和测试验证,程序需要下载这6个类别下全部数据,共包括138个可用数据。如果仅进行程序学习或验证测试,那么我们下载部分数据即可,例如City类别下的2011_09_26_drive_0001、2011_09_26_drive_0002、2011_09_26_drive_0005和2011_09_26_drive_0009。下载数据解压得到以日期命名的文件夹,如2011_09_26。

图 KITTI原始数据集下载

        KITTI深度补全数据集主要包含稠密点云深度,以提供稀疏点云补全的真实标签,下载地址为“https://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_completion”。数据集下载内容包括下图所示部分,即annotated depth maps data set (14 GB)、raw LiDaR scans data set (5 GB)、manually selected validation and test data sets (2 GB)、development kit (48 K)。解压后文件下包含了训练和验证样本数据。

图 KITII深度补全数据集下载

2.2 数据集预处理

        PENet输入数据集目录如下所示,需将上述所下载文件整理成该目录结构形式。我们如果仅下载部分原始数据(2011_09_26_drive_0001、2011_09_26_drive_0002、2011_09_26_drive_0005和2011_09_26_drive_0009),那么需要对深度补全数据集进行相应设置。data_depth_annotated和data_depth_velodyne文件夹下训练train文件夹仅保留2011_09_26_drive_0001_sync和2011_09_26_drive_0009_sync,删除其它文件夹。data_depth_annotated和data_depth_velodyne文件夹下验证val文件夹仅保留2011_09_26_drive_0002_sync和2011_09_26_drive_0005_sync,删除其它文件夹。

├── kitti_depth|   ├── depth|   |   ├──data_depth_annotated|   |   |  ├── train|   |   |  ├── val|   |   ├── data_depth_velodyne|   |   |  ├── train|   |   |  ├── val|   |   ├── data_depth_selection|   |   |  ├── test_depth_completion_anonymous|   |   |  |── test_depth_prediction_anonymous|   |   |  ├── val_selection_cropped├── kitti_raw|   ├── 2011_09_26|   ├── 2011_09_28|   ├── 2011_09_29|   ├── 2011_09_30|   ├── 2011_10_03

        完整训练集包括138个文件夹,而这里仅使用如上两个文件夹数据进行模型解析。完整验证集包括1000个样本。

        模型输入数据由rgb、d、gt、g、position和K等6部分组成。

        (1)rgb

        rgb数据来自于KITTI的2号和3号彩色相机,即彩色图像数据。训练集和验证集图片路径分别为“kitti_raw/*/*_sync/image_0[2,3]/*.png”和“data_depth_selection/val_selection_cropped/image/*.png”。原始图片维度为3x375x1242,经过固定裁剪和随机裁剪后维度为3x320x1216。图片像素深度为8bit,像素取值范围0~255。

        (2)d

        d为激光雷达所采集的稀疏点云深度数据,以16位png图片存储,取值范围0~65535。取值除以256可得到深度值,且取值为零的点表示无效点,即未采集到深度数据。训练集和验证集路径分别为“data_depth_velodyne/train/*_sync/proj_depth/velodyne_raw/image_0[2,3]/*.png”和“data_depth_selection/val_selection_cropped/velodyne_raw/*.png”。深度图片维度为3x375x1242,经过固定裁剪和随机裁剪后维度为3x320x1216。

        (3)gt

        gt为稠密点云深度的真实标签数据,以16位png图片存储,取值范围0~65535。取值除以256可得到深度值,且取值为零的点表示无效点,即未采集到深度数据。训练集和验证集路径分别为“data_depth_annotated/train/*_sync/proj_depth/groundtruth/image_0[2,3]/*.png”和“data_depth_selection/val_selection_cropped/groundtruth_depth/*.png”。深度图片维度为3x375x1242,经过固定裁剪和随机裁剪后维度为3x320x1216。

        (4)g

        g为rgb彩色图像数据转换后的灰度图像数据,维度为1x320x1216。

        (5)position

        positon是图片像素坐标经过归一化后取值,归一化范围为-1~1。由于像素横纵坐标分别进行归一化处理,因而position维度为2x352x1216,并经过随机裁剪后维度为2x320x1216。

xx_channel = xx_channel.astype('float32') / (self.y_dim - 1)#除以最大值,0~1
yy_channel = yy_channel.astype('float32') / (self.x_dim - 1)#除以最大值,0~1
xx_channel = xx_channel*2 - 1#变换到-1~1
yy_channel = yy_channel*2 - 1#变换到-1~1
ret = np.concatenate([xx_channel, yy_channel], axis=-1)#拼接

        (6)K

        K为3x3维度相机内参矩阵,包含了x、y方向上焦距和光心偏移信息,用于像素坐标和相机坐标系间坐标变换。除直接从calib_cam_to_cam.txt标定文件中读取原始内参矩阵之外,此时K矩阵还需要根据图像裁剪情况对光心偏移进行调整。

def load_calib():"""Temporarily hardcoding the calibration matrix using calib file from 2011_09_26"""calib = open("dataloaders/calib_cam_to_cam.txt", "r")lines = calib.readlines()P_rect_line = lines[25]Proj_str = P_rect_line.split(":")[1].split(" ")[1:]Proj = np.reshape(np.array([float(p) for p in Proj_str]), (3, 4)).astype(np.float32)K = Proj[:3, :3]  # camera matrix# note: we will take the center crop of the images during augmentation# that changes the optical centers, but not focal lengths# K[0, 2] = K[0, 2] - 13  # from width = 1242 to 1216, with a 13-pixel cut on both sides# K[1, 2] = K[1, 2] - 11.5  # from width = 375 to 352, with a 11.5-pixel cut on both sidesK[0, 2] = K[0, 2] - 13;K[1, 2] = K[1, 2] - 11.5;return K

3 ENet主干网络

        ENet主干网络包含两条分支,其中一条支路是图像rgb和稀疏深度d融合对稠密深度的预测,另一条支路是预测结果进一步与稀疏深度d融合并对稠密深度进行再次预测。两条支路预测结果融合得到ENet对稠密深度最终预测结果。

3.1 ENet主干支路一

        程序首先通过平均值池化对position(2x320x1216)进行下采样,采样倍数分别为2、4、8、16、32,从而得到6种不同尺度分辨率的像素坐标(vnorm_sx和unorm_sx)。同样地,激光雷达稀疏深度d也采用最大值池化得到相应分辨率下的深度图d_sx。像素坐标与相机坐标系的对应关系可通过如下公式进行计算,那么根据像素坐标和深度坐标可计算得到目标在相机坐标系下的空间坐标(x,y,z)。

        程序相应函数为GeometryFeature,具体计算过程如下所示。由于position坐标已归一化到-1~1,因此需要结合图片尺寸恢复出像素坐标绝对值,然后使用内参和距离参数得到相机坐标,并称该坐标为几何特征。6种分辨率下的像素坐标和稀疏深度分别进行计算,从而得到6种不同分辨率的几何特征geo_sx。

x = z*(0.5*h*(vnorm+1)-ch)/fh
y = z*(0.5*w*(unorm+1)-cw)/fw
return torch.cat((x, y, z),1)

        第一条主干支路输入的图像rgb和深度d拼接并经过卷积Conv2d(4, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)运算后得到32x320x1216维特征rgb_feature。程序将rgb特征和几何特征按照如下图所示过程逐步进行特征融合与特征提取(特征编码),进而得到不同尺度融合特征rgb_featurex。

Concate(rgb_feature_x, geo_sx)->Conv(s=2) 1
Concate(1, geo_sx) Conv1 2
Downsample(rgb_feaure) 3
Add(rgb_feature_sx, 3) rgb_feature_x+1
Concate(rgb_feature_x+1, geo_sx+1)->Conv(s=1) 4
Concate(4, geo_sx+1) Conv1  5
Add(rgb_feature_sx+1, 5) rgb_feature_x+2

图 图像特征与几何特征融合

        rgb特征与几何特征融合过程如下:

  1. rgb输入特征rgb_feature_x维度为C1xH1xW1,两种尺度几何特征geo_sx和geo_sx+1维度分别为3xH1xW1和3xH2xW2,且H2=H1/2、W2=W1/2。
  2. rgb_feature_x与geo_sx进行拼接后经过卷积Conv(C1+3, 2*C1, 3, 2)得到2*C1xH2xW2维度特征。
  3. (2)中特征进一步与geo_sx+1拼接并经过卷积Conv(2*C1+3, 2*C1, 3, 1)得到2*C1xH2xW2维度特征。
  4. rgb_feature_x与geo_sx进行拼接后经过卷积Conv(C1+3, 2*C1, 3, 2)下采样得到2*C1xH2xW2维度特征。
  5. (3)和(4)中特征进行求和得到融合后rgb特征rgb_feature_x+1,维度为2*C1xH2xW2。
  6. rgb_feature_x+1与geo_sx+1进行拼接后经过卷积Conv(2*C1+3, 2*C1, 3, 1)得到2*C1xH2xW2维度特征。
  7. (6)中特征进一步与geo_sx+1拼接并经过卷积Conv(2*C1+3, 2*C1, 3, 1)得到2*C1xH2xW2维度特征。
  8. (7)中特征和rgb_feature_x+1进行求和得到新的融合后rgb特征rgb_feature_x+2,维度为2*C1xH2xW2。

        从上述步骤可以看到,rgb特征与几何特征进行多次融合,以获取更加充分的几何特征信息。融合后rgb特征rgb_feature10、rgb_feature8、rgb_feature6、rgb_feature4、rgb_feature2、rgb_feature的维度分别为1024x10x38、512x20x76、256x40x152、128x80x304、64x160x608、32x320x1216。

        rgb特征解码阶段从最小尺度rgb特征逐步通过逆卷积上采样与特征融合得到解码后的不同尺度rgb特征,分别为rgb_feature8_plus(512x20x76)、rgb_feature6_plus(256x40x152)、rgb_feature4_plus(128x80x304)、rgb_feature2_plus(64x160x608)、rgb_feature0_plus(32x320x1216)。rgb_feature0_plus经过逆卷积ConvTranspose2d(32, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)得到2x320x1216维度预测结果,这两个维度分别为第一条主干支路深度预测结果(rgb_depth,1x320x1216)及其置信度(rgb_conf,1x320x1216)。

        以上为第一条主干分支的模型处理过程,关键程序解析如下所示。

vnorm_s2 = self.pooling(vnorm)#每两个点平均池化,160x608
vnorm_s3 = self.pooling(vnorm_s2)#每两个点平均池化,80x304
vnorm_s4 = self.pooling(vnorm_s3)#每两个点平均池化,40x152
vnorm_s5 = self.pooling(vnorm_s4)#每两个点平均池化,20x76
vnorm_s6 = self.pooling(vnorm_s5)#每两个点平均池化,10x38
unorm_s2 = self.pooling(unorm)#每两个点平均池化,160x608
unorm_s3 = self.pooling(unorm_s2)#每两个点平均池化,80x304
unorm_s4 = self.pooling(unorm_s3)#每两个点平均池化,40x152
unorm_s5 = self.pooling(unorm_s4)#每两个点平均池化,20x76
unorm_s6 = self.pooling(unorm_s5)#每两个点平均池化,10x38
#不同尺度深度图
valid_mask = torch.where(d>0, torch.full_like(d, 1.0), torch.full_like(d, 0.0))#深度大于0的点为有效点
d_s2, vm_s2 = self.sparsepooling(d, valid_mask)#深度最大值池化,160x608
d_s3, vm_s3 = self.sparsepooling(d_s2, vm_s2)#深度最大值池化,80x304
d_s4, vm_s4 = self.sparsepooling(d_s3, vm_s3)#深度最大值池化,40x152
d_s5, vm_s5 = self.sparsepooling(d_s4, vm_s4)#深度最大值池化,20x76
d_s6, vm_s6 = self.sparsepooling(d_s5, vm_s5)#深度最大值池化,10x38
geo_s1 = self.geofeature(d, vnorm, unorm, 352, 1216, c352, c1216, f352, f1216)#像素坐标到相机坐标,x,y,z,3x320x1216
geo_s2 = self.geofeature(d_s2, vnorm_s2, unorm_s2, 352 / 2, 1216 / 2, c352, c1216, f352, f1216)#像素坐标到相机坐标,x,y,z,3x160x608
geo_s3 = self.geofeature(d_s3, vnorm_s3, unorm_s3, 352 / 4, 1216 / 4, c352, c1216, f352, f1216)#像素坐标到相机坐标,x,y,z,3x80x304
geo_s4 = self.geofeature(d_s4, vnorm_s4, unorm_s4, 352 / 8, 1216 / 8, c352, c1216, f352, f1216)#像素坐标到相机坐标,x,y,z,3x40x152
geo_s5 = self.geofeature(d_s5, vnorm_s5, unorm_s5, 352 / 16, 1216 / 16, c352, c1216, f352, f1216)#像素坐标到相机坐标,x,y,z,3x20x76
geo_s6 = self.geofeature(d_s6, vnorm_s6, unorm_s6, 352 / 32, 1216 / 32, c352, c1216, f352, f1216)#像素坐标到相机坐标,x,y,z,3x10x38
rgb_feature = self.rgb_conv_init(torch.cat((rgb, d), dim=1))#rgbd特征提取,4x320x1216 -> 32x320x1216
rgb_feature1 = self.rgb_encoder_layer1(rgb_feature, geo_s1, geo_s2) #64x160x608,不同尺度rgb与坐标特征融合
rgb_feature2 = self.rgb_encoder_layer2(rgb_feature1, geo_s2, geo_s2) #64x160x608,不同尺度rgb与坐标特征融合
rgb_feature3 = self.rgb_encoder_layer3(rgb_feature2, geo_s2, geo_s3) #128x80x304,不同尺度rgb与坐标特征融合
rgb_feature4 = self.rgb_encoder_layer4(rgb_feature3, geo_s3, geo_s3) #128x80x304,不同尺度rgb与坐标特征融合
rgb_feature5 = self.rgb_encoder_layer5(rgb_feature4, geo_s3, geo_s4) #256x40x152,不同尺度rgb与坐标特征融合
rgb_feature6 = self.rgb_encoder_layer6(rgb_feature5, geo_s4, geo_s4) #256x40x152,不同尺度rgb与坐标特征融合
rgb_feature7 = self.rgb_encoder_layer7(rgb_feature6, geo_s4, geo_s5) #512x20x76,不同尺度rgb与坐标特征融合
rgb_feature8 = self.rgb_encoder_layer8(rgb_feature7, geo_s5, geo_s5) #512x20x76,不同尺度rgb与坐标特征融合
rgb_feature9 = self.rgb_encoder_layer9(rgb_feature8, geo_s5, geo_s6) #1024x10x38,不同尺度rgb与坐标特征融合
rgb_feature10 = self.rgb_encoder_layer10(rgb_feature9, geo_s6, geo_s6) #1024x10x38,不同尺度rgb与坐标特征融合
rgb_feature_decoder8 = self.rgb_decoder_layer8(rgb_feature10)#逆卷积上采样,512x20x76
rgb_feature8_plus = rgb_feature_decoder8 + rgb_feature8#特征融合,512x20x76
rgb_feature_decoder6 = self.rgb_decoder_layer6(rgb_feature8_plus)#逆卷积上采样,256x40x152
rgb_feature6_plus = rgb_feature_decoder6 + rgb_feature6#特征融合,256x40x152
rgb_feature_decoder4 = self.rgb_decoder_layer4(rgb_feature6_plus)#逆卷积上采样,128x80x304
rgb_feature4_plus = rgb_feature_decoder4 + rgb_feature4#特征融合,128x80x304
rgb_feature_decoder2 = self.rgb_decoder_layer2(rgb_feature4_plus)#逆卷积上采样,64x160x608
rgb_feature2_plus = rgb_feature_decoder2 + rgb_feature2#特征融合,64x160x608
rgb_feature_decoder0 = self.rgb_decoder_layer0(rgb_feature2_plus)#逆卷积上采样,32x320x1216
rgb_feature0_plus = rgb_feature_decoder0 + rgb_feature#特征融合,32x320x1216
rgb_output = self.rgb_decoder_output(rgb_feature0_plus)#深度和置信度预测,2x320x1216
rgb_depth = rgb_output[:, 0:1, :, :]#1x320x1216
rgb_conf = rgb_output[:, 1:2, :, :]#1x320x1216

3.2 ENet主干支路二

        ENet第二条主干支路输入为稀疏深度d和支路一预测深度rgb_depth,二者拼接并经过卷积Conv2d(2, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)得到32x320x1216维度融合特征。模型将该特征定义为稀疏特征,即sparsed_feature。该支路仍然采用特征编码-解码的结构进行特征提取。

        与支路一操作类似,不同尺度下稀疏特征也与几何特征进行两次融合。除此之外,稀疏特征sparsed_featurex还与相同尺度的rgb特征rgb_featurex_plus进行拼接融合。稀疏特征、几何特征和rgb特征相互融合,完成特征编码,主要输出为sparsed_feature10(1024x10x38)、sparsed_feature8(512x20x76)、sparsed_feature6(256x40x152)、sparsed_feature4(128x80x304)、sparsed_feature2(64x160x608)。

        稀疏特征解码阶段从最小尺度稀疏特征逐步通过逆卷积上采样与特征融合得到解码后的不同尺度稀疏特征,分别为decoder_feature1(512x20x76)、decoder_feature2(256x40x152)、decoder_feature3(128x80x304)、decoder_feature4(64x160x608)、decoder_feature5(32x320x1216)。rgb_feature0_plus经过逆卷积ConvTranspose2d(32, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)得到2x320x1216维度预测结果,这两个维度分别为第二条主干支路深度预测结果(d_depth,1x320x1216)及其置信度(d_conf,1x320x1216)。

        以上为第二条主干分支的模型处理过程,关键程序解析如下所示。

sparsed_feature = self.depth_conv_init(torch.cat((d, rgb_depth), dim=1))#雷达特征与RGB预测深度融合后提取特征,32x320x1216
sparsed_feature1 = self.depth_layer1(sparsed_feature, geo_s1, geo_s2)#类似深度信息与几何坐标信息融合,64x160x608
sparsed_feature2 = self.depth_layer2(sparsed_feature1, geo_s2, geo_s2) #64x160x608
sparsed_feature2_plus = torch.cat([rgb_feature2_plus, sparsed_feature2], 1)#128x160x608
sparsed_feature3 = self.depth_layer3(sparsed_feature2_plus, geo_s2, geo_s3) #128x80x304
sparsed_feature4 = self.depth_layer4(sparsed_feature3, geo_s3, geo_s3)#128x80x304
sparsed_feature4_plus = torch.cat([rgb_feature4_plus, sparsed_feature4], 1)#256x80x304
sparsed_feature5 = self.depth_layer5(sparsed_feature4_plus, geo_s3, geo_s4) #256x40x152
sparsed_feature6 = self.depth_layer6(sparsed_feature5, geo_s4, geo_s4) #256x40x152
sparsed_feature6_plus = torch.cat([rgb_feature6_plus, sparsed_feature6], 1)#512x40x152
sparsed_feature7 = self.depth_layer7(sparsed_feature6_plus, geo_s4, geo_s5) #512x20x76
sparsed_feature8 = self.depth_layer8(sparsed_feature7, geo_s5, geo_s5) #512x20x76
sparsed_feature8_plus = torch.cat([rgb_feature8_plus, sparsed_feature8], 1)#1024x20x76
sparsed_feature9 = self.depth_layer9(sparsed_feature8_plus, geo_s5, geo_s6) #1024x10x38
sparsed_feature10 = self.depth_layer10(sparsed_feature9, geo_s6, geo_s6) #1024x10x38
fusion1 = rgb_feature10 + sparsed_feature10#1024x10x38
decoder_feature1 = self.decoder_layer1(fusion1)#逆卷积上采样,512x20x76
fusion2 = sparsed_feature8 + decoder_feature1#特征融合,512x20x76
decoder_feature2 = self.decoder_layer2(fusion2)#逆卷积上采样,256x40x152
fusion3 = sparsed_feature6 + decoder_feature2#特征融合,256x40x152
decoder_feature3 = self.decoder_layer3(fusion3)#逆卷积上采样,128x80x304
fusion4 = sparsed_feature4 + decoder_feature3#特征融合,128x80x304
decoder_feature4 = self.decoder_layer4(fusion4)#逆卷积上采样,64x160x608
fusion5 = sparsed_feature2 + decoder_feature4#特征融合,64x160x608
decoder_feature5 = self.decoder_layer5(fusion5)#逆卷积上采样,32x320x1216
depth_output = self.decoder_layer6(decoder_feature5)#卷积,2x320x1216
d_depth, d_conf = torch.chunk(depth_output, 2, dim=1)#1x320x1216,1x320x1216

3.3 分支融合

        ENet两条支路均预测了深度及其置信度,其中第一条支路预测结果为深度(rgb_depth,1x320x1216)及其置信度(rgb_conf,1x320x1216);第二条支路预测结果为深度(d_depth,1x320x1216)及其置信度(d_conf,1x320x1216)。融合时最终预测深度来源于两条支路预测深度的加权求和,权重由置信度经过softmax得到,即置信度概率越大,权重占比越大。关键程序解析如下所示。

rgb_conf, d_conf = torch.chunk(self.softmax(torch.cat((rgb_conf, d_conf), dim=1)), 2, dim=1)#将两条支路的置信度转换为权重
output = rgb_conf*rgb_depth + d_conf*d_depth#深度预测结果,1x320x1216

        模型返回值为rgb_depth、d_depth、output(融合预测深度)。

3.4 ENet损失函数

        ENet训练损失包含rgb深度损失、稀疏深度损失和融合深度损失,对应预测结果为rgb_depth、d_depth、output。其损失函数均为MaskedMSELoss,衡量预测深度与真实深度标签gt之间的偏差。

        训练前两个迭代周期中,rgb深度损失和稀疏深度损失的权重为0.2,并在第3~4个周期内降为0.05。从第5个训练周期开始,ENet训练损失函数仅包括融合深度损失depth_loss。

        ENet训练损失关键程序解析如下所示。

st1_pred, st2_pred, pred = model(batch_data)#rgb_depth、d_depth、output(融合预测深度)
round1, round2, round3 = 1, 3, None
if(actual_epoch <= round1):w_st1, w_st2 = 0.2, 0.2
elif(actual_epoch <= round2):w_st1, w_st2 = 0.05, 0.05else:w_st1, w_st2 = 0, 0
depth_loss = depth_criterion(pred, gt)#MaskedMSELoss()
st1_loss = depth_criterion(st1_pred, gt)#MaskedMSELoss()
st2_loss = depth_criterion(st2_pred, gt)#MaskedMSELoss()
loss = (1 - w_st1 - w_st2) * depth_loss + w_st1 * st1_loss + w_st2 * st2_loss

4 DA CSPN++

        DA (dilated and acceleratedm,膨胀加速)CSPN++网络是对ENet预测结果进行微调以获取更加准确的深度信息。其输入包括ENet所提取特征(feature_s1 64x320x1216,feature_s2 128x160x608)与深度预测结果(coarse_depth,1x320x1216),其中特征feature_s1 和feature_s2是rgb深度特征和融合特征的融合。根据膨胀比例,模型设置相应尺度的输入特征。卷积膨胀的作用是为了使卷积核覆盖范围更大,从而使卷积视野范围更广。从另外一个角度上来说,卷积膨胀相当于在下采样的特征图上进行普通卷积操作,该模型的后续操作便是采用这种方法。因此,假设膨胀系数为2,那么所需特征图尺寸为160x608。模型输入特征包含两部分,一部分为原始尺度特征feature_s1,即rgb_feature0_plus和 decoder_feature5拼接融合,维度为64x320x1216;另一部分为用于膨胀操作的特征feature_s2,即rgb_feature2_plus和 decoder_feature4拼接融合,维度为128x160x608。

#ENet输出
torch.cat((rgb_feature0_plus, decoder_feature5), 1), torch.cat((rgb_feature2_plus, decoder_feature4),1), output
feature_s1, feature_s2, coarse_depth = self.backbone(input)#由ENet得到的特征与预测深度,64x320x1216,128x160x608,1x320x1216
depth = coarse_depth#1x320x1216

        CSPN++网络核心思想是采用模型来自主学习卷积核权重,而不是使用卷积直接对输入进行操作,这一点类似于transformer的QK操作。DA CSPN++用于学习卷积核权重的输入特征为feature_s2(128x160x608)。CSPN++网络的另一个特点为采用多种尺度卷积核进行特征提取,参考程序使用了尺寸为3、5、7的卷积核。每个卷积核所提取特征采用加权求和的方法进行融合,其中置信度权重网路的输入也为feature_s2。

        作者对CSPN++网络进行了加速设计,将卷积操作转换为矩阵乘法,从而实现并行计算。例如,3x3卷积核在HxW维度特征图进行滑动操作可转换为9xHxW维度卷积和9xHxW特征图的矩阵乘法。

图 CSPN++加速

        模型对原始输入特征和膨胀特征均会进行CSPN++微调,主要包括卷积核参数及其权重学习、DA CSPN++结果微调、feature_s1 CSPN++结果微调、特征加权求和融合等步骤。

4.1 卷积核参数及其权重学习

        feature_s2(128x160x608)经过卷积 Conv2d(128, 3)和softmax后得到3x160x608维度卷积核权重,对应kernel_conf3_s2(1x160x608)、kernel_conf5_s2(1x160x608)、kernel_conf7_s2(1x160x608)。另一方面,feature_s2经过卷积 Conv2d和padding操作得到卷积核参数,分别为guide3_s2(9x162x610)、guide5_s2(25x164x612)、guide7_s2(49x166x614)。

        feature_s1(64x320x1216)经过卷积 Conv2d(64, 3)和softmax后得到3x320x1216维度卷积核权重,对应kernel_conf3(1x320x1216)、kernel_conf5(1x320x1216)、kernel_conf7(1x320x1216)。另一方面,feature_s1经过卷积 Conv2d和padding操作得到卷积核参数,分别为guide3(9x322x1218)、guide5(25x324x1220)、guide7(49x326x1222)。

kernel_conf_s2 = self.kernel_conf_layer_s2(feature_s2)#128x320x1216 -> 3x160x608
kernel_conf_s2 = self.softmax(kernel_conf_s2)#转换为权重,3x160x608,通道维度和为1,即不同卷积核特征的置信度权重
kernel_conf3_s2 = kernel_conf_s2[:, 0:1, :, :]#1x160x608,3x3卷积核特征权重
kernel_conf5_s2 = kernel_conf_s2[:, 1:2, :, :]#1x160x608,5x5卷积核特征权重
kernel_conf7_s2 = kernel_conf_s2[:, 2:3, :, :]#1x160x608,7x7卷积核特征权重
kernel_conf = self.kernel_conf_layer(feature_s1)#64x320x1216 -> 3x320x1216
kernel_conf = self.softmax(kernel_conf)#转换为权重,3x320x1216,通道维度和为1,即不同卷积核特征的置信度权重
kernel_conf3 = kernel_conf[:, 0:1, :, :]#1x320x1216,3x3卷积核特征权重
kernel_conf5 = kernel_conf[:, 1:2, :, :]#1x320x1216,5x5卷积核特征权重
kernel_conf7 = kernel_conf[:, 2:3, :, :]#1x320x1216,7x7卷积核特征权重
guide3_s2 = self.iter_guide_layer3_s2(feature_s2)#学习3x3卷积CSPN,9x162x610
guide5_s2 = self.iter_guide_layer5_s2(feature_s2)#学习5x5卷积CSPN,25x164x612
guide7_s2 = self.iter_guide_layer7_s2(feature_s2)#学习7x7卷积CSPN,49x166x614
guide3 = self.iter_guide_layer3(feature_s1)#学习3x3卷积CSPN,9x322x1218
guide5 = self.iter_guide_layer5(feature_s1)#学习3x3卷积CSPN,25x324x1220
guide7 = self.iter_guide_layer7(feature_s1)#学习3x3卷积CSPN,49x326x1222

4.2 DA CSPN++结果微调

        ENet预测深度(1x320x1216)下采样成4张子深度图(1x160x608),由于子深度图可构成完整原始深度图,因而这种下采样不会带来信息丢失。DA CSPN++对这4种特征图(depth_s2_00、depth_s2_01、depth_s2_10、depth_s2_11)分别进行深度微调结果预测。每个子深度图进行6次连续CSPN++操作以利用更深层次特征来预测新的微调深度,并且每次进行CSPN++操作时都会与ENet子深度图和激光雷达稀疏深度图d_s2进行融合。子深度图预测结果为3种卷积核提取特征的加权求和。DA CSPN++预测深度(depth_s2_00、depth_s2_01、depth_s2_10、depth_s2_11)重新拼接成原始尺寸,即depth_s2(1x320x1216)。

d_s2, valid_mask_s2 = self.downsample(d, valid_mask)#原始雷达深度最大值池化,1x160x608
mask_s2 = self.mask_layer_s2(feature_s2)#128x320x1216 -> 1x160x608
mask_s2 = torch.sigmoid(mask_s2)#转化为权重, 1x160x608,即DA CSPN++输出的权重
mask_s2 = mask_s2*valid_mask_s2#深度mask与预测mask相乘,1x160x608
feature_12 = torch.cat((feature_s1, self.upsample(self.dimhalf_s2(feature_s2))), 1)#128x320x1216,两种输入特征融合
att_map_12 = self.softmax(self.att_12(feature_12))#2x320x1216,用于ENet预测深度和DA CSPN++微调深度融合
depth_s2 = depth#1x320x1216
depth_s2_00 = depth_s2[:, :, 0::2, 0::2]#深度图拆分,1x160x608
depth_s2_01 = depth_s2[:, :, 0::2, 1::2]#深度图拆分,1x160x608
depth_s2_10 = depth_s2[:, :, 1::2, 0::2]#深度图拆分,1x160x608
depth_s2_11 = depth_s2[:, :, 1::2, 1::2]#深度图拆分,1x160x608
depth3_s2_00 = self.CSPN3(guide3_s2, depth3_s2_00, depth_s2_00_h0)#1x160x608,CSPN特征提取
depth3_s2_00 = mask_s2*d_s2 + (1-mask_s2)*depth3_s2_00#1x160x608,与原始输入稀疏特征加权求和融合
depth5_s2_00 = self.CSPN5(guide5_s2, depth5_s2_00, depth_s2_00_h0)#1x160x608,CSPN特征提取
depth5_s2_00 = mask_s2*d_s2 + (1-mask_s2)*depth5_s2_00#1x160x608,与原始输入稀疏特征加权求和融合
depth7_s2_00 = self.CSPN7(guide7_s2, depth7_s2_00, depth_s2_00_h0)#1x160x608,CSPN特征提取
depth7_s2_00 = mask_s2*d_s2 + (1-mask_s2)*depth7_s2_00#1x160x608,与原始输入稀疏特征加权求和融合
depth_s2_00 = kernel_conf3_s2*depth3_s2_00 + kernel_conf5_s2*depth5_s2_00 + kernel_conf7_s2*depth7_s2_00#不同卷积核特征加权求和融合,1x160x608
depth_s2[:, :, 0::2, 0::2] = depth_s2_00#将深度重新拼接成原始尺度,1x320x1216
refined_depth_s2 = depth*att_map_12[:, 0:1, :, :] + depth_s2*att_map_12[:, 1:2, :, :]#与ENet深度加权求和融合,1x320x1216

4.3 feature_s1 CSPN++结果微调

        模型再次使用feature_s1学习的三种卷积核参数对DA CSPN++的预测深度结果depth_s2进行结果微调。模型此时同样采用连续6次CSPN++操作,,并且每次进行CSPN++操作时都会与depth_s2和激光雷达稀疏深度图d进行融合。三种尺寸卷积核对应的CSPN预测结果(depth3、depth5、depth7)加权求和即可得到模型最终微调后的预测深度refined_depth(1x320x1216)。

mask = self.mask_layer(feature_s1)#64x320x1216 -> 1x320x1216
mask = torch.sigmoid(mask)#转化为权重,1x320x1216,非膨胀CSPN++卷积输出的权重
mask = mask*valid_mask#深度mask与预测mask相乘,1x320x1216
for i in range(6):depth3 = self.CSPN3(guide3, depth3, depth)#采用CSPN再次进行深度微调,1x320x1216depth3 = mask*d + (1-mask)*depth3#与原始输入稀疏特征加权求和融合,1x320x1216depth5 = self.CSPN5(guide5, depth5, depth)#采用CSPN再次进行深度微调,1x320x1216depth5 = mask*d + (1-mask)*depth5#与原始输入稀疏特征加权求和融合,1x320x1216depth7 = self.CSPN7(guide7, depth7, depth)#采用CSPN再次进行深度微调,1x320x1216depth7 = mask*d + (1-mask)*depth7#与原始输入稀疏特征加权求和融合,1x320x1216
refined_depth = kernel_conf3*depth3 + kernel_conf5*depth5 + kernel_conf7*depth7#深度加权求和融合,1x320x1216

4.4 损失函数

        DA CSPN++阶段训练损失函数仅由depth_loss组成,即CsPN++深度预测结果与真实标签之间偏差,损失函数类型为MaskedMSELoss。

5 模型训练

        PENet模型训练包括三个步骤,分别是ENet训练、DA CSPN++训练和PENet训练,分别对应下图中I、II、III。

图PENet训练示意图

        ENet训练命令为“CUDA_VISIBLE_DEVICES="0,1" python main.py -b 6 -n e”,CUDA_VISIBLE_DEVICES="0,1"部分可以根据实际情况设置GPU序号。作者提供的Enet预训练模型下载地址为“https://drive.google.com/file/d/1TRVmduAnrqDagEGKqbpYcKCT307HVQp1/view?usp=sharing”。

        DA-CSPN++训练命令为“CUDA_VISIBLE_DEVICES="0,1" python main.py -b 6 -f -n pe --resume [enet-checkpoint-path]”。“-f”表示训练DA-CSPN++网络时ENet主干网络是固定的,即不通过梯度传播更新参数。

        当ENet和DA-CSPN++分别训练完成后,模型再次进行整体训练,训练命令为“CUDA_VISIBLE_DEVICES="0,1" python main.py -b 10 -n pe -he 160 -w 576 --resume [penet-checkpoint-path]”。该训练程序与DA CSPN++训练的区别在于不再使用-f参数,即ENet主干网络也需要进行训练更新。预训练模型可以使用DA-CSPN++训练得到的模型,也可直接使用作者提供的PENet预训练模型,下载地址为“https://drive.google.com/file/d/1RDdKlKJcas-G5OA49x8OoqcUDiYYZgeM/view?usp=sharing”。

6 【python三维深度学习】python三维点云从基础到深度学习_python3d点云从基础到深度学习-CSDN博客

【版权声明】
本文为博主原创文章,未经博主允许严禁转载,我们会定期进行侵权检索。 
 

更多python与C++技巧、三维算法、深度学习算法总结、大模型请关注我的博客,欢迎讨论与交流:https://blog.csdn.net/suiyingy,或”乐乐感知学堂“公众号。Python三维领域专业书籍推荐:《人工智能点云处理及深度学习算法》。

 本文为专栏《Python三维点云实战宝典》系列文章,专栏介绍地址“【python三维深度学习】python三维点云从基础到深度学习_python3d点云从基础到深度学习-CSDN博客”。配套书籍《人工智能点云处理及深度学习算法》提供更加全面和系统的解析。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/407767.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

shell脚本中$0 $1 $# $@ $* $? $$ 的各种符号意义详解

文章目录 一、概述1.1、普通字符1.2、元字符 二、转义字符$2.1、实例12.2、实例22.3、实例32.4、实例42.5、实例5 三、linux命令执行返回值$?说明 一、概述 shell中有两类字符&#xff1a;普通字符、元字符。 1.1、普通字符 在Shell中除了本身的字面意思外没有其他特殊意义…

校友林小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;用户管理&#xff0c;树木管理管理&#xff0c;所属科管理&#xff0c;树木领取管理&#xff0c;树跟踪状态管理&#xff0c;用户信息统计管理&#xff0c;树木捐款管理&#xff0c;留言板管理 微信端…

基于vue框架的毕业设计管理系统5n36i(程序+源码+数据库+调试部署+开发环境)系统界面在最后面。

系统程序文件列表 项目功能&#xff1a;学生,教师,课题信息,题目分类,选题信息,任务书,中期检查,提交论文,论文成绩,答辩成绩,校园公告,教研主任,申报课题 开题报告内容 基于Vue框架的毕业设计管理系统开题报告 一、引言 随着高等教育的不断发展&#xff0c;毕业设计作为培…

AITDK SEO扩展:为网站优化提供一站式解决方案

AITDK SEO扩展&#xff1a;为网站优化提供一站式解决方案 想提升你的网站在搜索引擎中的排名&#xff1f;让我们来看看AITDK SEO扩展&#xff0c;它是你网站优化的得力助手&#xff01;在这篇文章中&#xff0c;我将为你介绍AITDK SEO扩展的功能特点&#xff0c;以及它如何帮助…

警惕!低血糖来袭,这些“隐形信号”你中招了吗?

在这个快节奏的时代&#xff0c;我们往往忙于工作、学习与生活&#xff0c;却容易忽视身体发出的微妙警告。其中&#xff0c;低血糖作为一种常见但易被忽视的健康问题&#xff0c;正悄悄影响着许多人的生活质量。今天&#xff0c;就让我们一起揭开低血糖的神秘面纱&#xff0c;…

Java:包装类

文章目录 引入原因包装类代码演示包装类的其他常见操作 使用到的有关ArrayList的方法 引入原因 泛型和集合不支持基本数据类型&#xff0c;只能支持引用数据类型 包装类 包装类就是把基本类型的数据包装成对象 就是说不再是一个int类型的数&#xff0c;而是一个Integer类型的…

Stable Diffusion 使用详解(8)--- layer diffsuion

背景 layer diffusion 重点在 layer&#xff0c;顾名思义&#xff0c;就是分图层的概念&#xff0c;用过ps 的朋友再熟悉不过了。没使用过的&#xff0c;也没关系&#xff0c;其实很简单&#xff0c;本质就是各图层自身的编辑不会影响其他图层&#xff0c;这好比OS中运行了很多…

文件树控件开发

文件树控件和获取驱动信息功能 然后添加上查看文件信息的按钮 双击这个按钮添加上如下代码 void CRemoteClientDlg::OnBnClickedBtnFileinfo() {int ret SendCommandPacket(1);if (ret -1) {AfxMessageBox(_T("命令处理失败!!!"));return;}ClientSocket* pClient…

AI大模型独角兽 MiniMax 基于 Apache Doris 升级日志系统,PB 数据秒级查询响应

作者&#xff1a;MiniMax 基础架构研发工程师 Koyomi、香克斯、Tinker 导读&#xff1a;早期 MiniMax 基于 Grafana Loki 构建了日志系统&#xff0c;在资源消耗、写入性能及系统稳定性上都面临巨大的挑战。为此 MiniMax 开始寻找全新的日志系统方案&#xff0c;并基于 Apache …

Ubuntu 22安装和配置PyCharm详细教程(图文详解)

摘要&#xff1a;本文提供了在 Ubuntu 22 上通过官方 .tar.gz 文件安装 PyCharm 的详细教程。包括从 JetBrains 官方网站下载适合的 PyCharm 版本&#xff08;Community 或 Professional&#xff09;&#xff0c;在终端中解压并将其移动到 /opt 目录&#xff0c;配置适当的权限…

【C++题解】1147. 求1/1+1/2+2/3+3/5+5/8+8/13+13/21……的前n项的和

欢迎关注本专栏《C从零基础到信奥赛入门级&#xff08;CSP-J&#xff09;》 问题&#xff1a;1147. 求1/11/22/33/55/88/1313/21……的前n项的和 类型&#xff1a;函数 题目描述&#xff1a; 求1/11/22/33/55/88/1313/2121/34…的前 n 项的和。 输入&#xff1a; 输入一个…

Unity读取Android本地图片

unity读取Android本地图片 一、安卓读取路径 安卓路径&#xff1a;“file:///storage/emulated/0/”自己图片的路径 例&#xff1a;“file:///storage/emulated/0/small.jpg” 二、unity搭建 使用UI简单搭个界面 三、新建一个脚本 代码内容如下 using System.Collectio…

谷粒商城实战笔记-251-商城业务-消息队列-Exchange类型

文章目录 一&#xff0c;Exchange二&#xff0c;Exchange的四种类型1&#xff0c;direct2&#xff0c;fanout3&#xff0c;topic 三&#xff0c;实操1&#xff0c;创建一个exchange2&#xff0c;创建一个queue3&#xff0c;将queue绑定到exchange 一&#xff0c;Exchange AMQP …

本地部署docker文档

由于访问 https://docs.docker.com/ 文档慢&#xff0c;直接本地部署官方文档 如果不想执行以下操作&#xff0c;也可以直接使用官方文档仓库地址提供的 Dockerfile 和 compose.yaml 进行操作 以下操作环境为Windows系统&#xff0c;根据 Dockerfile 相关操作来生成 html 页面…

金融帝国实验室(Capitalism Lab)官方技术支持中文汉化包_v4.09

<FCT汉化小组>Vol.001号作品 ————————————— ◎ 作品名称&#xff1a;金融帝国实验室&#xff08;Capitalism Lab&#xff09;官方中文汉化包 ◎ 制作发布&#xff1a;FCT汉化小组 ◎ 发布版本&#xff1a;CapLab Simplified Chinese loc v4.09 ◎ 发布时…

记录一次经历:使用flask_sqlalchemy集成flask造成循环导入问题

前言&#xff1a; 工作需求&#xff0c;写一个接口&#xff0c;用Python来编写&#xff0c;我首先想到用flask小型框架来支撑&#xff0c;配置sqlalchemy来实现&#xff0c;但是在实现的过程中&#xff0c;发生循环导入问题 我想到用蓝图来解决此问题&#xff0c;但是仍然会出死…

认知杂谈22

今天分享 有人说的一段争议性的话 I I 私人空间&#xff0c;成长的温床 咱一说到成长啊&#xff0c;可不能小瞧了外部环境对咱的影响。这环境啊&#xff0c;那可不是无关紧要的事儿&#xff0c;实际上呢&#xff0c;它对咱的成长起着特别关键的作用。你就想想看&#xff0c…

NRC-SIM:基于Node-RED的多级多核缓存模拟器

整理自&#xff1a; 《NRC-SIM: A NODE-RED Based Multi-Level, Many-Core Cache Simulator》&#xff0c;由 Ezequiel Trevio 撰写&#xff0c;作为他在德克萨斯大学里奥格兰德河谷分校攻读电气工程硕士学位的部分成果。以下是论文的详细主要内容&#xff1a; 摘要(Abstract…

uni-app01

工具:HuilderX noed版本:node-v17.3.1 npm版本:8.3.0 淘宝镜像:https://registry.npmmirror.com/ 未安装nodejs可以进入这里https://blog.csdn.net/a1241436267/article/details/141326585?spm1001.2014.3001.5501 目录 1.项目搭建​编辑 2.项目结构 3.使用浏览器运行…

鸿蒙OS promptAction的使用

效果如下&#xff1a; import { promptAction } from kit.ArkUIlet customDialogId: number 0Builder function customDialogBuilder() {Column() {Blank().height(30)Text(确定要删除吗&#xff1f;).fontSize(15)Blank().height(40)Row() {Button("取消").onClick…