MaskFormer语义分割算法测试

MaskFormer是一套基于transformer结构的语义分割代码。

链接地址:

https://github.com/facebookresearch/MaskFormer/tree/main

测试用的数据集:ADE20k Dataset

MIT Scene Parsing Benchmark

 该数据集可通过上述链接下载,其中training含有20210张图片,validation含有2000张图片。SceneParsing中是全景分割的标签图片,InstanceSegmentation是实例分割的标签图片。

1.环境搭建

本人在python3.10,CUDA11.8,torch2.1.0的linux服务器上做实验。通过pip装好torch之后,然后按照INSTALL.md中的提示安装Detectron中的包。

有以下几点需要注意:

1.需要安装opencv-python-headless版本的opnecv

pip install opencv-python-headless

2.需要安装1.*版本的numpy

pip install numpy==1.26.0

3.使用timm加载模型的时候,会遇到某些层不支持的问题,在mask_former/modeling/backbone/swin.py中,修改为如下:

# from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.layers import DropPath, to_2tuple, trunc_normal_

4.安装panopticapi的包

git clone https://github.com/cocodataset/panopticapi.git
python setup.py build_ext --inplace
python setup.py build_ext install

个人配好的环境如下所示:

Package                 Version            Editable project location
----------------------- ------------------ ------------------------------------
absl-py                 2.2.1
antlr4-python3-runtime  4.9.3
black                   25.1.0
certifi                 2025.1.31
charset-normalizer      3.4.1
click                   8.1.8
cloudpickle             3.1.1
coloredlogs             15.0.1
contourpy               1.3.1
cycler                  0.12.1
Cython                  3.0.12
detectron2              0.6                /home/shengpeng/downloads/detectron2
filelock                3.18.0
flatbuffers             25.2.10
fonttools               4.56.0
fsspec                  2025.3.0
fvcore                  0.1.5.post20221221
grpcio                  1.71.0
h5py                    3.13.0
huggingface-hub         0.29.3
humanfriendly           10.0
hydra-core              1.3.2
idna                    3.10
iopath                  0.1.9
Jinja2                  3.1.6
kiwisolver              1.4.8
Markdown                3.7
markdown-it-py          3.0.0
MarkupSafe              3.0.2
matplotlib              3.10.1
mdurl                   0.1.2
mpmath                  1.3.0
mypy-extensions         1.0.0
networkx                3.4.2
numpy                   1.26.0
omegaconf               2.3.0
onnx                    1.17.0
onnx-simplifier         0.4.36
onnxruntime             1.21.0
opencv-python-headless  4.11.0.86
packaging               24.2
panopticapi             0.1
pathspec                0.12.1
pillow                  11.1.0
pip                     25.0
platformdirs            4.3.7
portalocker             3.1.1
protobuf                6.30.2
pycocotools             2.0.8
Pygments                2.19.1
pyparsing               3.2.3
python-dateutil         2.9.0.post0
PyYAML                  6.0.2
requests                2.32.3
rich                    13.9.4
safetensors             0.5.3
scipy                   1.15.2
setuptools              75.8.0
shapely                 2.0.7
six                     1.17.0
sympy                   1.13.3
tabulate                0.9.0
tensorboard             2.19.0
tensorboard-data-server 0.7.2
termcolor               2.5.0
timm                    1.0.15
tomli                   2.2.1
torch                   2.1.0+cu118
torchvision             0.16.0+cu118
tqdm                    4.67.1
triton                  2.1.0
typing_extensions       4.13.0
urllib3                 2.3.0
Werkzeug                3.1.3
wheel                   0.45.1
yacs                    0.1.8

下载预训练模型,即调用demo/demo.py,指定config的配置文件,和预训练权重,对图片进行推理,看预测效果。

python demo/demo.py \
--config-file configs/ade20k-150/maskformer_R50_bs16_160k.yaml \
--input images/ADE/ADE_test_00000001.jpg \
--opts MODEL.WEIGHTS weights/MaskFormer_seg_R50_512x512.pkl

 训练的脚本:

python train_net.py \
--num-gpus 2 \
--config-file configs/ade20k-150/maskformer_R50_bs16_160k.yaml \

在train_net.py中需要指定数据集的路径:

    os.environ['DETECTRON2_DATASETS']='/home/shengpeng/code/github_proj2/ADE2016/SceneParsing'

2张RTX3090的卡,大概跑了一晚上,结果如下:

其中最小模型,基于R50的backbone练出来的模型也有160多M。

2.torch模型转onnx

该套代码中没有带转onnx的代码,需要自己想办法转。

找到下载的detectron2的代码,detectron2/detectron2/engine/defaults.py中,重写class DefaultPredictor的__call__函数,如下:

    def __call__(self, original_image):with torch.no_grad(): image = original_image[:, :, ::-1]input_blob = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))input_blob = input_blob.unsqueeze(0)# print('self.cfg.MODEL.DEVICE:', self.cfg.MODEL.DEVICE)pixel_mean = self.cfg.MODEL.PIXEL_MEANpixel_std = self.cfg.MODEL.PIXEL_STDpixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)input_blob = (input_blob-pixel_mean) / pixel_stdinput_blob = input_blob.to(self.cfg.MODEL.DEVICE)print('input_blob.shape:',input_blob.shape)predictions = self.model(input_blob)[0]return predictions

重写MaskFormer/maskformer/mask_former_model.py中的class MaskFormer的forward()的函数:

    def forward(self, input_blob):print('MaskFormer input_blob:', input_blob.shape)print('self.device:', self.device)print('input_blob.device:', input_blob.device)input_h, input_w = input_blob.shape[2], input_blob.shape[3]features = self.backbone(input_blob)outputs = self.sem_seg_head(features)if self.training:# # mask classification target# if "instances" in batched_inputs[0]:#     gt_instances = [x["instances"].to(self.device) for x in batched_inputs]#     targets = self.prepare_targets(gt_instances, images)# else:#     targets = Nonetargets = None# bipartite matching-based losslosses = self.criterion(outputs, targets)for k in list(losses.keys()):if k in self.criterion.weight_dict:losses[k] *= self.criterion.weight_dict[k]else:# remove this loss if not specified in `weight_dict`losses.pop(k)return losseselse:mask_cls_results = outputs["pred_logits"]mask_pred_results = outputs["pred_masks"]# return mask_cls_results, mask_pred_results# upsample masksmask_pred_results = F.interpolate(mask_pred_results,size=(input_h, input_w),mode="bilinear",align_corners=False,)# mask_cls_result=mask_cls_results[0]# mask_pred_result=mask_pred_results[0]# print('mask_cls_result:',mask_cls_result.shape)# print('mask_pred_result:',mask_pred_result.shape)print('mask_cls_results:',mask_cls_results.shape)print('mask_pred_results:',mask_pred_results.shape)processed_results = []            if self.sem_seg_postprocess_before_inference:mask_pred_results = sem_seg_postprocess(mask_pred_results, [input_h, input_w], input_h, input_w)# semantic segmentation inferencer = self.semantic_inference(mask_cls_results, mask_pred_results)print(f'r1:{r.shape}')if not self.sem_seg_postprocess_before_inference:r = sem_seg_postprocess(r, [input_h, input_w], input_h, input_w)print(f'r2:{r.shape}')processed_results.append({"sem_seg": r})print('processed_results num:',len(processed_results))return processed_results

在tools中新建convert_torchvision_to_onnx.py的转模型脚本:

import argparse
import glob
import multiprocessing as mp
import os# fmt: off
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
# fmt: onimport tempfile
import time
import warningsimport cv2
import numpy as np
import tqdmfrom detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.utils.logger import setup_loggerfrom mask_former import add_mask_former_config
from demo.predictor import VisualizationDemoimport onnx
import torchdef setup_cfg(args):# load config from file and command-line argumentscfg = get_cfg()add_deeplab_config(cfg)add_mask_former_config(cfg)cfg.merge_from_file(args.config_file)cfg.merge_from_list(args.opts)cfg.freeze()return cfgdef get_parser():parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")parser.add_argument("--config-file", default="configs/ade20k-150/maskformer_R50_bs16_160k.yaml")parser.add_argument("--input", nargs="+")parser.add_argument("--output", help="A file or directory to save output visualizations. ""If not given, will show output in an OpenCV window.")parser.add_argument("--confidence-threshold", type=float, default=0.5, help="Minimum score for instance predictions to be shown")parser.add_argument("--opts",help="Modify config options using the command-line 'KEY VALUE' pairs",default=['MODEL.WEIGHTS', 'output/model_0159999.pth'],nargs=argparse.REMAINDER,)return parserif __name__ == "__main__":args = get_parser().parse_args()cfg = setup_cfg(args)demo = VisualizationDemo(cfg)net = demo.predictor.modelnet.to('cpu')input_model_path=cfg.MODEL.WEIGHTSprint('input_model_path:%s' % (input_model_path))output_model_path=input_model_path.replace('.pth', '.onnx')im = torch.zeros(1, 3, 512, 512).to('cpu')  # image size(1, 3, 512, 512) BCHWinput_layer_names   = ["images"]output_layer_names  = ["output"]dynamic = False# Export the modelprint(f'Starting export with onnx {onnx.__version__}.')torch.onnx.export(net,im,f               = output_model_path,verbose         = False,opset_version   = 12,training        = torch.onnx.TrainingMode.EVAL,do_constant_folding = True,input_names     = input_layer_names,output_names    = output_layer_names,dynamic_axes    = {'images': {0: 'batch'},'output': {0: 'batch'}} if dynamic else None)# Checksmodel_onnx = onnx.load(output_model_path)  # load onnx modelonnx.checker.check_model(model_onnx)  # check onnx model# Simplify onnxsimplify = 1if simplify:import onnxsimprint(f'Simplifying with onnx-simplifier {onnxsim.__version__}.')# model_onnx, check = onnxsim.simplify(#     model_onnx,#     dynamic_input_shape=False,#     input_shapes=None)onnx_sim_model, check = onnxsim.simplify(model_onnx)assert check, 'assert check failed'onnx.save(model_onnx, output_model_path)print('Onnx model save as {}'.format(output_model_path))

即可转换成功得到对应的onnx模型,可使用onnxruntime加载该onnx模型做推理。

 

3.推理速度测试

在c++代码中,加载onnx转tensorrt测试速度,对比segformer中14M的模型,和该MaskFormer161M的模型,同时基于512x512的分辨率,转fp16的engine,做推理:

segfomer_b0      10ms左右

maskformer_R50     220ms左右

这个实验结果显示,该maskformer的模型不适用于那种速度要求特别高的场景,更适用于类别数较多,全景分割的场景。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/42531.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

javaWeb vue的简单语法

一、简介 两大核心优势: 声明式渲染:Vue 基于标准 HTML 拓展了一套模板语法,使得我们可以声明式地描述最终输出的 HTML 和 JavaScript 状态之间的关系。 响应性:Vue 会自动跟踪 JavaScript 状态并在其发生变化时响应式地更新 D…

vue create创建 Vue-router 工程

vue create创建 Vue-router 工程 参考 创建vue项目的两种方式:vue-create与vite https://www.cnblogs.com/reverse-x/p/16806534.html Vue2 脚手架 创建工程 测试程序 https://blog.csdn.net/wowocpp/article/details/146590400 在 上面的基础上 cd .\vue2-demo\…

CXL UIO Direct P2P学习

前言: 在CXL协议中,UIO(Unordered Input/Output) 是一种支持设备间直接通信(Peer-to-Peer, P2P)的机制,旨在绕过主机CPU或内存的干预,降低延迟并提升效率。以下是UIO的核心概念及UI…

口腔种植全流程AI导航系统及辅助诊疗与耗材智能化编程分析

一、系统架构与编程框架设计 口腔种植全流程人工智能导航系统的开发是一项高度复杂的多学科融合工程,其核心架构需在医学精准性、工程实时性与临床实用性之间实现平衡。系统设计以模块化分层架构为基础,结合高实时性数据流与多模态协同控制理念,覆盖从数据采集、智能决策到…

李宏毅机器学习笔记(1)—机器学习基本概念+深度学习基本概念

机器学习基本概念 1、获取模型 步骤 1.1、假定未知函数 带未知参数的函数 1.2、定义损失函数 真实值:label MAE MSE 几率分布,cross-entropy? 1.3、优化 单独考虑一个参数 让损失函数最小,找导数为零的点 单独考虑w,w…

专注自习室:番茄工作法实践

专注自习室:番茄工作法实践 我需要一个任务管理工具,但在网上找了很多都找不到合适的工具。市面上的大多数产品过于强调任务完成性,给我带来了很强的心理压力,这种压力最终反而降低了我的工作效率。于是我决定自己动手&#xff0…

【银河麒麟高级服务器操作系统 】虚拟机运行数据库存储异常现象分析及处理全流程

更多银河麒麟操作系统产品及技术讨论,欢迎加入银河麒麟操作系统官方论坛 https://forum.kylinos.cn 了解更多银河麒麟操作系统全新产品,请点击访问 麒麟软件产品专区:https://product.kylinos.cn 开发者专区:https://developer…

阿里云数据学习20250327

课堂链接:阿里云培训中心 (aliyun.com) 一、课堂问题 (一)课时3 1.支持字符集的含义是什么

使用QuickReporter将多张图片插入在word多行的表格中

之前有一位QuickReporter的用户提到过一个需求。他有大量的图片需要插入在word里面,他的想法是将图片放在一个文件夹内,按编号1,2,3,...编号,然后自动将这些图片从前到后插入到表格中。 这次偶然发现了该需求是可以实现的,且在当…

【大模型】激活函数之SwiGLU详解

文章目录 1. Swish基本定义主要特点代码实现 2. GLU (Gated Linear Unit)基本定义主要特点代码实现 3. SwiGLU基本定义主要特点代码实现 参考资料 SWiGLU是大模型常用的激活函数,是2020年谷歌提出的激活函数,它结合了Swish和GLU两者的特点。SwiGLU激活函…

vs2017开启性能探测器失败

开启性能探测器失败 错误: 无法启用性能探测器服务没有及时响应启动或控制请求。 (HRESULT: 0xe1110002) Microsoft.DiagnosticsHub.Diagnostics.CollectionStartFailedHubException”的异常。 各种原因排查: 1.管理员启动 2.开启各种诊断服务&…

FPGA——分秒计数器设计(DE2-115开发板)

一、项目创建 1.创建工程 点击File->New Project Wizard...或者直接在页面处点击 在第一行选择文件存放地点,第二行为项目名称,第三行为顶级设计实体名称 (下面的步骤可以暂时不做直接点Finish,因为是先写代码先把它跑出来暂…

香蕉成熟度检测和识别1:香蕉成熟度数据集说明(含下载链接)

一. 前言 本篇博客是《香蕉成熟度检测和识别》系列文章之《香蕉成熟度数据集说明(含下载链接)》,网上有很多香蕉成熟度数据集的数据,百度一下,一搜一大堆,但质量参差不齐,很多不能用,即使一个一个的看也会…

⑦(ACG-网络配置)

网络配置是指对计算机网络的各种参数进行设置和调整,以实现网络正常运行和高效通信。网络配置包括多方面的内容,常见的配置包括: 1. IP地址设置:IP地址是设备在网络中的身份标识,设置IP地址是网络配置的基础&#xff…

DeepSeek反作弊技术方案全解析:AI如何重构数字信任体系

一、技术原理:构建智能防御矩阵 1.1 多维度行为分析引擎 DeepSeek 反作弊技术的基石是多维度行为分析引擎,其借助深度学习算法,对用户行为轨迹展开毫秒级的细致剖析。这一引擎能够构建起涵盖操作频率、设备指纹、网络环境等多达 128 个特征维度的精准行为画像。以教育场景为…

盈亏平衡分析

盈亏平衡分析是一种重要的管理分析方法,广泛应用于企业的成本控制、生产决策、定价策略等方面,以下是对它的详细阐述: 一、基本概念 定义:盈亏平衡分析是通过研究企业在一定时期内的成本、收入与利润之间的关系,确定…

Vue2 脚手架 创建工程 测试程序

Vue2 脚手架 创建工程 测试程序 创建一个 目录 H:\g_web_vue\test 打开 vscode H:\g_web_vue\test 新建文件夹 vue2-demo cd .\vue2-demo vue create demo1 键盘 向下箭头 按键,选中 Vue2, 然后 回车 cd demo1 npm run serve http://localhost:808…

Yolo_v8的安装测试

前言 如何安装Python版本的Yolo,有一段时间不用了,Yolo的版本也在不断地发展,所以重新安装了运行了一下,记录了下来,供参考。 一、搭建环境 1.1、创建Pycharm工程 首先创建好一个空白的工程,如下图&…

IP协议的介绍

网络层的主要功能是在复杂的网络环境中确定一个合适的路径.网络层的协议主要是IP协议.IP协议头格式如下: 1.4位版本号:指定IP协议的版本,常用的是IPV4,对于IPV4来说,这里的值就是4. 2.4位头部长度,单位也是4个字节,4bit表示的最大数字是15,因此IP头部的最大长度就是60字节 3.…

Linux环境上传本地文件安装mysql

windows下载本地文件包,找到文件所在目录 scp 文件名 root192.168.xx.xx:/opt输入ssh密码,成功上传到服务器! //docker拉取镜像 cd /opt && docker load -i 文件名docker run -it -d --restartalways --namemysql5 -p 3106:3306 -v …