使用合成数据训练语义分割模型

计算机视觉应用现在在各个科技领域无处不在。 这些模型高效且有效,研究人员每年都会尝试新想法。 新的前沿是试图消除深度学习的最大负担:需要大量的标记数据。 正如本文所述,此问题的解决方案是使用合成数据。

从这些研究中获益最多的计算机视觉领域当然是语义分割领域,即预测图像每个像素的标签的任务,以便从图像中检索感兴趣的对象。 正如人们所预料的那样,手动标记训练集是一个昂贵、耗时且容易出错的过程,因此有多种利用合成数据的新方法。

在本文中,我们将看到其中一种方法,它利用生成对抗网络来解决使用合成数据的域适应问题。另一种常用的合成数据生成方法是利用逼真渲染的游戏引擎,例如基于UE5开发的UnrealSynth合成数据生成器:
在这里插入图片描述

在线工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 3D场景编辑器

1、合成数据生成

为了生成语义分割任务的数据,最常见的解决方案是使用与渲染引擎关联的模拟器。 通过这种方式,可以随意生成图像,改变闪电条件、物体的数量和姿势以及它们之间的交互,并始终关联像素完美的语义标签。 例如,一个非常流行的数据集,几乎所有研究都用作基准,是 GTAV [1],其中使用的模拟引擎是同名视频游戏。 该数据集包含从汽车驾驶员的角度拍摄的图像,非常适合自动驾驶等应用。 另一个著名的数据集是 SINTHIA [2],它也包含城市环境的图像。

在这里插入图片描述

图 1.1 — 来自 GTAV 数据集 [2] 的带有标签的图像示例

2、领域适应的生成方法

直接使用合成数据训练模型是不够的,神经网络可能会学习模拟环境中存在的一些不真实的模式,无法很好地概括现实世界的数据。 这称为域适应问题(Domain Adaption Problem)。

为了克服这个问题,模型必须在训练过程中学习重新调整源域 S(合成域)和目标域 T(真实域)之间特征分布的最佳方法。 这可以通过对抗性训练、知识蒸馏和自我监督学习等多种方法来实现。

特别是,对抗性训练的特点是采用生成方法,将源域数据转换为更类似于目标域的分布。 它可以表述如下:

给定源域数据集 Dₛ= {(xᵢˢ, yᵢˢ), i=1…nₛ} 和目标域数据集 Dₜ = {xᵢᵗ, i=1…nₜ},其中 xᵢˢ 和 xᵢᵗ 是输入样本, yᵢˢ 是对应的样本 xᵢˢ 的标签,目标是学习一个映射函数 𝓍ᵢˢ = G(xᵢˢ),称为生成器,它将源域特征映射到目标域特征,以便在转换后的源域图像上训练的深度学习模型可以表现良好 在目标域上。 它是通过判别器来完成的,判别器是一种神经网络,它接收真实图像和变换后的合成图像的输入,并尝试预测输入是否来自真实分布。

网络在对抗性环境中进行训练,只有当鉴别器失败时,生成器才会获胜。 当变换后的图像与真实图像非常相似以至于鉴别器无法区分它们时,该过程会收敛,从而使预测不比随机猜测更好(准确度为 50%)。

3、几何引导的输入输出自适应

各种算法都利用生成方法。 其中之一被称为 GIO-Ada [3],代表几何引导输入输出适应。 该算法相对于简单方法引入了 2 项改进。

它使用可以从模拟引擎轻松检索的另一条信息:深度图。 直觉是,对象的几何信息更好地编码在其深度信息中,而不是其像素的语义标签中。 因此,模型被训练来估计输入图像的深度图,并且这个额外的信息仅在训练期间用作辅助损失。
它在输出级别使用第二个对抗阶段,第二个鉴别器对任务网络的输出(语义标签图和几何深度图)进行操作,经过训练以预测预测的输出来自真实的还是合成的 图像。

在这里插入图片描述

图 1.2 — GIO-Ada 架构概述。 源数据的流向以橙线显示,目标数据的流向以黑线显示

完整的架构由 4 个神经网络组成:生成器(用于转换合成图像)、任务网络(预测真实图像和转换图像的标签和深度图)以及 2 个判别器。 所有网络都经过端到端训练,并采用遵循对抗训练规则的通用优化步骤。

4、Pytorch Lightening实现

为了轻松实现和训练这种复杂的算法,pytorch_lightning 是一个可以提供帮助的库。 这是 pytorch 的包装器,有助于避免重新实现一些与 torch 配合使用所需的样板代码,例如实现训练循环、处理超参数和权重的记录和保存、管理 GPU(或多个 GPU)并执行优化器步骤。 在我们的例子中,最后一个功能不是必需的,因为对抗训练的特殊性恰恰在于生成器和判别器之间优化步骤的交替,并且需要定制。

让我们首先导入库并定义一个实用函数,该函数将用于为鉴别器创建标签。

import itertools
from typing import Iteratorimport pytorch_lightning as pl
import torch
from torch import nn
from torchmetrics.classification.jaccard import MulticlassJaccardIndexdef _labels(inputs: torch.Tensor, fill_value: int) -> torch.Tensor:return torch.full((inputs.size(0), 1), fill_value).to(inputs)

神经网络被实现为torch模块。 给定 B = 批量大小、C = 图像通道、K = 类数、W、H= 图像的宽度和高度:

  • 任务网络必须处理形状为 B × C × W × H 的批量图像,并返回形状为 B × K × W × H 的标签预测和形状为 B × 1× W × H 的深度预测。一种可能的架构选择是 使用 DeepLabV3+ [4] 作为任务网络,具有两个不同的头,一个用于类别预测,一个用于深度预测。
  • 图像变换网络必须输入所有合成数据,即形状为 B × C × W × H 的图像、形状为 B × K × W × H 的标签和形状为 B × 1× W × H 的深度图,连接起来 它们,并在输出中生成形状为 B × C × W × H 的变换图像。
  • 鉴别器必须采用形状 B × (C 或 C + K + 1) × W × H 的输入,并产生形状 B × 1 的输出,表示样本为真实样本的概率。
class TaskNetwork(nn.Module):def __init__(self,input_channels: int,num_classes: int,pretrained_backbone: bool = False,) -> None:...def forward(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:...class ImageTransformNetwork(nn.Module):def __init__(self,input_channels: int,output_channels: int,) -> None:...def forward(self,fake_images: torch.Tensor,labels: torch.Tensor,depths: torch.Tensor,) -> torch.Tensor:...class Discriminator(nn.Module):def __init__(self, input_channels: int) -> None:...def forward(self, inputs: torch.Tensor) -> torch.Tensor:...

其余代码将在 LightningModule 内实现。 在这里,我们在 __init__方法中传递所有超参数,在该方法中我们实例化了 4 个神经网络,以及损失和指标。 卷积层的权重从正态分布初始化,任务网络的权重除外,其中权重可以预先训练,例如使用 ImageNet 数据集。

class GIOAda(pl.LightningModule):REAL_LABEL = 1FAKE_LABEL = 0def __init__(self,num_classes: int,pretrained_backbone: bool,init_lr: float,betas: tuple[float, float],num_epochs: int,num_steps_per_epoch: int,lam_input: float,lam_output: float,lam_depth: float,) -> None:super().__init__()self.save_hyperparameters() # saved in the dictionary self.hparams# disabling automatic optimization, as it willl be made manuallyself.automatic_optimization = Falseself.task_network = TaskNetwork(input_channels=3,  # RGB Channelsnum_classes=num_classes,  # Classespretrained_backbone=pretrained_backbone,)self.fake_transformation = ImageTransformNetwork(input_channels=num_classes + 4,  # RGB Channels + Classes + Depthoutput_channels=3,  # RGB Channels)self.input_discriminator = Discriminator(input_channels=3,  # RGB Channels)self.output_discriminator = Discriminator(input_channels=num_classes + 1,  # Classes + Depth)self.depths_loss = nn.L1Loss()self.labels_loss = nn.CrossEntropyLoss()self.discriminator_loss = nn.BCELoss()self.miou_index = MulticlassJaccardIndex(num_classes)self.weight_init(pretrained_backbone=pretrained_backbone)def weight_init(self, pretrained_backbone: bool = False):for name, module in self.named_modules():if "task" in name and pretrained_backbone:continueif isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):module.weight.data.normal_(0, 0.001)if module.bias is not None:module.bias.data.zero_()

然后我们定义优化器和学习率调度器。 我们需要一个优化器来处理“生成器”的权重,即生成器和任务网络,以及另一个优化器来处理鉴别器的权重。 作为学习率调度程序,我们将使用 OneCycle 策略,该策略在训练的第一部分通过提高学习率和降低动量来“预热”网络,从而允许早期探索权重空间并找到更好的起点 观点。 然后,在最后部分,通过余弦退火策略降低学习率。

    def configure_optimizers(self,) -> tuple[list[torch.optim.Adam], list[torch.optim.lr_scheduler.OneCycleLR]]:params_g = itertools.chain(self.fake_transformation.parameters(),self.task_network.parameters(),)params_d = itertools.chain(self.input_discriminator.parameters(),self.output_discriminator.parameters(),)optimizer_g, lr_sched_g = self._optimizer_lr_scheduler(params_g)optimizer_d, lr_sched_d = self._optimizer_lr_scheduler(params_d)return [optimizer_g, optimizer_d], [lr_sched_g, lr_sched_d]def _optimizer_lr_scheduler(self,parameters: Iterator[torch.nn.Parameter],) -> tuple[torch.optim.Adam, torch.optim.lr_scheduler.OneCycleLR]:optimizer = torch.optim.Adam(parameters,lr=self.hparams["init_lr"],betas=self.hparams["betas"],)lr_sched = torch.optim.lr_scheduler.OneCycleLR(optimizer,max_lr=self.hparams["init_lr"],epochs=self.hparams["num_epochs"],steps_per_epoch=self.hparams["num_steps_per_epoch"],base_momentum=self.hparams["betas"][0],)return optimizer, lr_sched

训练步骤接收输入:

  • 从真实数据集中采样的一批真实图像
  • 一批合成图像,以及从合成数据集中采样的相应标签和深度图。

然后它执行 2 个操作:

  • 鉴别器的优化步骤,需要所有输入
  • 生成器的优化步骤,仅需要合成数据

步骤的顺序对于确保模型的收敛至关重要。 由于生成器更容易崩溃,我们应该让判别器“引导”训练路径。 这样,在生成器步骤中,鉴别器的工作会更好一点,为生成器留下“更好”的梯度。

    def training_step(self, batch: tuple[torch.Tensor, ...]) -> None:optimizer_g, optimizer_d = self.optimizers()  real_images, fake_images, labels, depths = batch# Update D network: minimize log(D(x)) + log(1 - D(G(z)))self.toggle_optimizer(optimizer_d)optimizer_d.zero_grad()self._discriminator_step(real_images, fake_images, labels, depths)optimizer_d.step()self.untoggle_optimizer(optimizer_d)# Update G network: maximize log(D(G(z))) and minimize task lossself.toggle_optimizer(optimizer_g)optimizer_g.zero_grad()self._generator_step(fake_images, labels, depths)optimizer_g.step()self.untoggle_optimizer(optimizer_g)

鉴别器步骤只是最小化鉴别器输出的二元交叉熵损失。 首先在真实批次上完成此操作,其中预期标签全部为 1,然后在合成批次上完成,其中预期标签全部为零。

    def _discriminator_step(self,real_images: torch.Tensor,fake_images: torch.Tensor,labels: torch.Tensor,depths: torch.Tensor,) -> None:disc_lab = _labels(real_images, self.REAL_LABEL)disc_input = self.input_discriminator(real_images)disc_output = self.output_discriminator(torch.concat(self.task_network(real_images), dim=1))loss_input = (self.discriminator_loss(disc_input, disc_lab)* self.hparams["lam_input"])loss_output = (self.discriminator_loss(disc_output, disc_lab)* self.hparams["lam_output"])self.manual_backward(loss_input + loss_output)transformed = self.fake_transformation(fake_images, labels, depths)disc_lab = _labels(transformed, self.FAKE_LABEL)disc_input = self.input_discriminator(transformed)disc_output = self.output_discriminator(torch.concat(self.task_network(transformed), dim=1))loss_input = (self.discriminator_loss(disc_input, disc_lab)* self.hparams["lam_input"])loss_output = (self.discriminator_loss(disc_output, disc_lab)* self.hparams["lam_output"])self.manual_backward(loss_input + loss_output)# Log losses and metrics# ...

相反,生成器步骤最小化标签的交叉熵损失和深度估计的 L1Loss,并且还最大化鉴别器的二元交叉熵损失。 这是通过使用与之前相反的标签计算损失来完成的,因此所有标签都用于合成输入。 没有必要计算实际输入的损失,因为生成器的权重对此输出没有影响。

    def _generator_step(self,fake_images: torch.Tensor,labels: torch.Tensor,depths: torch.Tensor,) -> None:# Set disc_lab = REAL in order to maximize the loss for the # discriminator when inputs are all fakesdisc_lab = _labels(fake_images, self.REAL_LABEL)# Forward pass on all the networks to collect gradients for Gtransformed = self.fake_transformation(fake_images, labels, depths)fake_mask, fake_depth = self.task_network(transformed)disc_input = self.input_discriminator(transformed)disc_output = self.output_discriminator(torch.concat((fake_mask, fake_depth), dim=1))# Calculate lossesloss_input = (self.discriminator_loss(disc_input, disc_lab)* self.hparams["lam_input"])loss_output = (self.discriminator_loss(disc_output, disc_lab)* self.hparams["lam_output"])loss_depths = (self.depths_loss(fake_depth, depths) * self.hparams["lam_depth"])loss_labels = self.labels_loss(fake_mask, labels)# Calculate Gradientsself.manual_backward(loss_input + loss_output + loss_depths + loss_labels)# Log losses and metrics# ...

5、结束语

事实证明,这里解释的方法在各种数据集上都非常有效。 在下图中,我们可以看到,利用 sintetic 数据训练的模型优于仅在小型 KITTI 数据集上训练的模型。 从大量合成数据中获取的知识使模型能够从真实图像中提取更细粒度的细节。

在这里插入图片描述

图 1.3 — KITTI 数据集上的语义分割定性结果。 从左到右:左:输入图像,中:非自适应结果,右:GIO-Ada 方法的结果。

该算法也有一些缺点。 首先,对抗性训练可能非常不稳定,这可以从之前看到的不寻常的训练步骤中猜测出来。 因此,详尽的超参数搜索对于获得良好结果至关重要。 另一个主要问题是训练生成网络是一项内存非常密集的工作,尤其是对于高分辨率图像。

最新的研究集中在其他方法(例如自学习)上,利用变压器层中注意力机制的强泛化特性以及特定领域的数据增强技术。

尽管如此,生成方法(例如本文中讨论的生成方法)由于易于适应新领域以及生成学习研究的不断发展,继续在该领域占据一席之地。


原文链接:用合成数据进行语义分割 — BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/180809.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FMC+DAM驱动LVGL刷屏

前提条件 使用FMC驱动LCD刷屏 LVGL移植 开启DMA 需要开启MEMTOMEMDMA。 开启MPU 有MPU时需要 使能I-cache D-cache时 使用DMA传输数据时要保证数据的完整行和准确性 修改代码 逻辑 等待DMA传输完成然后再刷屏。 修改 在DMA初始化函数中最后添加 注册DMA传输完成调用函…

k8s:二进制搭建 Kubernetes v1.20

目录 1 操作系统初始化配置 2 部署 etcd 集群 2.1 准备签发证书环境 2.2 生成Etcd证书 3 部署 docker引擎 4 部署 Master 组件 5 部署 Worker Node 组件 k8s集群master01:192.168.30.105 kube-apiserver kube-controller-manager kube-scheduler etcd k8s集…

05-流媒体-RTMP协议介绍

05-流媒体-RTMP协议介绍 1.RTMP概述 RTMP是一种常见的流媒体协议,是基于TCP/IP协议模型的应用层协议,工作在TCP协议上,端口是1935。通过TCP三次握手保证传输的可靠。 2.握手 2.1 握手过程 RTMP基于传输层TCP连接后,通过RTMP握手协议来完成RTMP连接。客户端和服务端各发…

将Series中每个值v替换为v在Series中升序排列时的位置值s.rank()

【小白从小学Python、C、Java】 【计算机等考500强证书考研】 【Python-数据分析】 将Series中每个值v 替换为v在Series中 升序排列时的位置值 s.rank() 选择题 下列代码执行三次排名索引a的名次值分别为? import pandas as pd s pd.Series([3,2,0,3],index list…

分享96个工作总结PPT,总有一款适合您

分享96个工作总结PPT,总有一款适合您 工作总结PPT下载链接:https://pan.baidu.com/s/18AriuVGxwINHrmgndX74dw?pwd8888 提取码:8888 Python采集代码下载链接:采集代码.zip - 蓝奏云 学习知识费力气,收集整理更不易…

【kafka】记一次kafka基于linux的原生命令的使用

环境是linux,4台机器,版本3.6,kafka安装在node 1 2 3 上,zookeeper安装在node2 3 4上。 安装好kafka,进入bin目录,可以看到有很多sh文件,是我们执行命令的基础。 启动kafka,下面的…

python 深度学习 解决遇到的报错问题9

本篇继python 深度学习 解决遇到的报错问题8-CSDN博客 目录 一、can only concatenate str (not "int") to str 二、cant convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, in…

74HC138逻辑芯片

文章目录 74系列逻辑芯片——74HC138基础信息描述特征应用范围 功能信息封装引脚基本电路 扩展性能分析 74系列逻辑芯片——74HC138 基础信息 描述 74HC138器件设计用于需要极短传播延迟时间的高性能存储器解码或数据路由应用;在高性能存储系统中,可使用…

二叉树采用二叉链表存储:编写计算二叉树最大宽度的算法(二叉树的最大宽度是指二叉树所有层中结点个数的最大值)

二叉树采用二叉链表存储:编写计算二叉树最大宽度的算法 (二叉树的最大宽度是指二叉树所有层中结点个数的最大值) 和二叉树有关的代码,基本都逃不过“先中后层”,这四种遍历 而我们这里是让你计算最大宽度&#xff0c…

Hydra post登录框爆破

文章目录 无token时的Hydra post登录框爆破带Token时的Hydra post登录框爆破 无token时的Hydra post登录框爆破 登录一个无验证码和token的页面,同时抓包拦截 取出发送数据包:usernameadb&password133&submitLogin 将用户名和密码替换 userna…

【2024最新】PE工具箱【下载安装】零基础到大神【附下载链接】

下载链接:点这里 1.PE (Portable Executable) 工具箱通常用于处理Windows可执行文件和动态链接库(DLL)的二进制文件格式。这些工具对于进行逆向工程、软件分析和系统维护等任务非常有用。以下是PE工具箱的一些常见功能和用法: 查…

【C语言进阶】之动态内存管理笔试题及柔性数组

【C语言进阶】之动态内存管理笔试题 1.动态内存管理笔试题汇总1.1第一道题1.2第二道题1.3第三道题1.4第四道题 2.C/C内存管理3.柔性数组3.1什么是柔性数组3.2柔性数组的使用3.2柔性数组的优点 📃博客主页: 小镇敲码人 🚀 欢迎关注&#xff1a…

Webpack搭建本地服务器

一、搭建webpack本地服务 1.为什么要搭建本地服务器? 目前我们开发的代码,为了运行需要有两个操作: 操作一:npm run build,编译相关的代码;操作二:通过live server或者直接通过浏览器&#x…

Draft-P802.11be-D3.2协议学习__$9-Frame-Format__$9.3.1.22-Trigger-frame-format

Draft-P802.11be-D3.2协议学习__$9-Frame-Format__$9.3.1.22-Trigger-frame-format 9.3.1.22.1 Genreal9.3.1.22.2 Common Info field9.3.1.22.3 Special User Info field9.3.1.22.4 HE variant User Info field9.3.1.22.5 EHT variant User Info field9.3.1.22.6 Basic Trigge…

4K Video Downloader Pro v4.28.0(视频下载器)

4K Video Downloader Pro是一款专业的视频下载软件,支持从YouTube、Vimeo、Facebook、Instagram、TikTok等主流视频网站下载高质量的4K、HD和普通视频。它的操作流程简单,只需复制视频链接并粘贴到软件中即可开始下载。此外,该软件还提供了多…

C# Winform串口助手

界面设置 修改控件name属性 了解SerialPort类 实现串口的初始化,开关 创建虚拟串口 namespace 串口助手 {public partial class Form1 : Form{public Form1(){InitializeComponent();}private void Form1_Load(object sender, EventArgs e){//在设计页面已经预先…

在Linux上通过NTLM认证连接到AD服务器(未完结)

这篇文章目前还没有实现具体的功能,只实现了明文登录,因为我缺少一些数据,比如通过密码生成hash,以及通过challenge生成response,我不知道怎么实现,因此这篇文章也是一个交流的文章,希望大佬看见…

深入理解网络IO复用并发模型

本文主要介绍服务端对于网络并发模型以及Linux系统下常见的网络IO复用并发模型。文章内容一共分为两个部分。 第一部分主要介绍网络并发中的一些基本概念以及我们Linux下常见的原生IO复用系统调用(epoll/select)等。第二部分主要介绍并发场景下常见的网…

el-table树状表格末行合计

首先,由于我的表头是动态的,所以就稍微复杂一点 效果图 表头数据格式是这样的 表格的数据格式是这样的 然后用合并的方法,此处就需要递归去计算,根据props去匹配每一列的数据,然后加起来,关键代码 //合计处理getSummaries(param) {const { columns, data } param;const su…

树结构及其算法-二叉排序树

目录 树结构及其算法-二叉排序树 C代码 树结构及其算法-二叉排序树 事实上,二叉树是一种很好的排序应用模式,因为在建立二叉树的同时,数据已经经过初步的比较,并按照二叉树的建立规则来存放数据,规则如下&#xff1…