Pytorch 基于 deeplabv3_resnet50 迁移训练自己的图像语义分割模型

一、图像语义分割

图像语义分割是计算机视觉领域的一项重要任务,旨在将图像中的每个像素分配到其所属的语义类别,从而实现对图像内容的细粒度理解。与目标检测不同,图像语义分割要求对图像中的每个像素进行分类,而不仅仅是确定物体的边界框。deeplabv3_resnet50 就是一个常用的语义分割模型,它巧妙地将两个强大的神经网络架构融合在一起,为像素级别的图像理解提供了强大的解决方案。

首先,DeepLabV3是一种专门设计用于语义分割的架构。通过采用扩张卷积(也称为空洞卷积)能够在不损失空间分辨率的情况下捕捉多尺度信息。这使得模型能够对图像进行精细的分割,识别并分类每个像素的语义信息。

其次,ResNet50ResNet系列中的一员,拥有50层深度的残差网络结构。通过引入残差连接,ResNet50解决了深层神经网络中梯度消失的问题,使得网络更易于训练。作为骨干网络,ResNet50提供了强大的特征提取能力,有助于捕捉图像中的高级语义特征。

本文基于 Pytorch 使用 deeplabv3_resnet50 迁移训练自己的图像语义分割模型,数据使用的数据集,最后效果如下所示:

在这里插入图片描述

下面使用的 torch 版本如下:

torch                   1.13.1+cu116
torchaudio              0.13.1+cu116
torchvision             0.14.1+cu116

二、数据集准备

图像数据可以从网上找一些或者自己拍摄,我这里准备了一些 的图片:

在这里插入图片描述

这里构建 VOC 格式数据集,因此需要新建如下结构目录:

VOCdevkitVOC2012AnnotationsImageSetsSegmentationJPEGImagesSegmentationClass

在这里插入图片描述

目录解释如下:

  • Annotations 存放标注后的 xml 文件
  • Segmentation 划分后的训练样本名称和验证集样本名称(只存放名称)
  • JPEGImages 存放收集的图像
  • SegmentationClass 存放语义分割的mask标签图像

将收集的图像放到 JPEGImages 目录下:

在这里插入图片描述

三、图像标注

标注工具使用 labelme ,如果没有安装,使用下面方式引入该依赖:

pip install labelme -i https://pypi.tuna.tsinghua.edu.cn/simple

然后控制台输入:labelme ,即可打开标注工具:

在这里插入图片描述

通过构建一个区域后,需要给该区域一个标签,这里给 cat

在这里插入图片描述

xml 文件保存在 Annotations 下:

在这里插入图片描述

四、生成 mask 标签图像及数据划分

标注完成后,需要将标注数据转为 mask 标签图像:

trans_mask.py

import json
import os
import os.path as osp
import copy
import numpy as np
import PIL.Imagefrom labelme import utilsNAME_LABEL_MAP = {'_background_': 0,"cat": 1,
}def main():annotations = './voc/VOCdevkit/VOC2012/Annotations'segmentationClass = './voc/VOCdevkit/VOC2012/SegmentationClass'list = os.listdir(annotations)for i in range(0, len(list)):path = os.path.join(annotations, list[i])filename = list[i][:-5]if os.path.isfile(path):data = json.load(open(path,encoding="utf-8"))img = utils.image.img_b64_to_arr(data['imageData'])lbl, lbl_names = utils.shape.labelme_shapes_to_label(img.shape, data['shapes'])  # labelme_shapes_to_label# modify labels according to NAME_LABEL_MAPlbl_tmp = copy.copy(lbl)for key_name in lbl_names:old_lbl_val = lbl_names[key_name]new_lbl_val = NAME_LABEL_MAP[key_name]lbl_tmp[lbl == old_lbl_val] = new_lbl_vallbl_names_tmp = {}for key_name in lbl_names:lbl_names_tmp[key_name] = NAME_LABEL_MAP[key_name]# Assign the new label to lbl and lbl_names dictlbl = np.array(lbl_tmp, dtype=np.int8)label_path = osp.join(segmentationClass, '{}.png'.format(filename))PIL.Image.fromarray(lbl.astype(np.uint8)).save(label_path)print('Saved to: %s' % label_path)if __name__ == '__main__':main()

注意修改路径为你的地址,运行后可以在 SegmentationClass 目录下看到 mask 标签图像:

在这里插入图片描述

下面进行数据的划分,这里划分为90%训练集和10%验证集:

split_data.py

import osif __name__ == '__main__':JPEGImages = "./voc/VOCdevkit/VOC2012/JPEGImages"Segmentation = "./voc/VOCdevkit/VOC2012/ImageSets/Segmentation"# 训练集比例 90%training_ratio = 0.9list = os.listdir(JPEGImages)all = len(list)print(all)train_count = int(all * training_ratio)train = list[0:train_count]val = list[train_count:]with open(os.path.join(Segmentation, "train.txt"), "w", encoding="utf-8") as f:for name in train:name = name.split(".")[0]f.write(name + "\n")f.flush()with open(os.path.join(Segmentation, "val.txt"), "w", encoding="utf-8") as f:for name in val:name = name.split(".")[0]f.write(name + "\n")f.flush()

运行后可以在 Segmentation 目录下看到两个文件:

在这里插入图片描述

到这里就已经准备好了 VOC 格式的数据集。

五、模型训练

deeplabv3_resnet50 的复现这里就不重复造轮子了,pytorch 官方的 vision 包已经做好了实现,拉取该工具包:

git clone https://github.com/pytorch/vision.git

可以在 references 下看到不同任务的实现:

在这里插入图片描述

这里我们主要关注 segmentation 中:

在这里插入图片描述

需要修改下 train.py 中的 voc 的分类数,由于我们只是分割出猫,加上背景就是 2 类:

在这里插入图片描述

控制台进入到该目录下,运行 train.py 文件开始训练:

python train.py --data-path ./voc --lr 0.02 --dataset voc --batch-size 2 --epochs 50 --model deeplabv3_resnet50 --device cuda:0 --output-dir model --aux-loss --weights-backbone ResNet50_Weights.IMAGENET1K_V1

如果缺失部分依赖直接 pip 安装即可。

其中参数的解释如下:

  • data-path:上面我们构建的 VOC 数据集的地址。
  • lr:初始学习率。
  • dataset:数据集的格式,这里我们是 voc 格式。
  • batch-size:一个批次的大小,这里我 GPU显存有限设的 2 ,如果显存大可以调大一些。
  • epochs:训练多少个周期。
  • model:训练使用的模型,可选:fcn_resnet50、fcn_resnet101、deeplabv3_resnet50、deeplabv3_resnet101、deeplabv3_mobilenet_v3_large、lraspp_mobilenet_v3_large
  • device:训练使用的设备。
  • output-dir:训练模型输出目录。
  • aux-loss:启用 aux-loss
  • weights-backbonebackbone模型。

更多参数可以打开 train.py 文件查看:

在这里插入图片描述

训练过程:

在这里插入图片描述

这里我训练完后 loss=0.3766, mean IoU= 85.4

在这里插入图片描述

五、模型预测

import os
import torch
import torch.utils.data
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms# 转换输出,将每个标签换成对应的颜色
def decode_segmap(image, num_classes, label_colors):r = np.zeros_like(image).astype(np.uint8)g = np.zeros_like(image).astype(np.uint8)b = np.zeros_like(image).astype(np.uint8)for l in range(0, num_classes):idx = image == lr[idx] = label_colors[l, 0]g[idx] = label_colors[l, 1]b[idx] = label_colors[l, 2]rgb = np.stack([r, g, b], axis=2)return rgbdef main():# 基础模型base_model = "deeplabv3_resnet50"# 训练后的权重model_weights = "./model/model_49.pth"# 使用设备device = "cuda:0"# 预测图像目录地址prediction_path = "./voc/VOCdevkit/VOC2012/JPEGImages"# 分类数num_classes = 2# 标签对应的颜色,0: 背景,1:catlabel_colors = np.array([(0, 0, 0), (255, 255, 255)])device = torch.device(device)print("using {} device.".format(device))# 加载模型model = torchvision.models.get_model(base_model,num_classes=2,)assert os.path.exists(model_weights), "{} file dose not exist.".format(model_weights)model.load_state_dict(torch.load(model_weights, map_location=device)["model"], strict=False)print(model)model.to(device)model.eval()files = os.listdir(prediction_path)for file in files:filename = os.path.join(prediction_path, file)input_image = Image.open(filename).convert('RGB')preprocess = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])input_tensor = preprocess(input_image)input_batch = input_tensor.unsqueeze(0).to(device)with torch.no_grad():output = model(input_batch)['out'][0]output_predictions = output.argmax(0)out = output_predictions.detach().cpu().numpy()rgb = decode_segmap(out, num_classes, label_colors)plt.figure()plt.subplot(1, 2, 1)plt.imshow(input_image)plt.subplot(1, 2, 2)plt.imshow(rgb)plt.show()if __name__ == '__main__':main()

输出结果:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/206235.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

图形数据库的实战应用:如何在 Neo4j 中有效管理复杂关系

关系数据库管理系统( RDBMS ) 代表了最先进的技术,这在一定程度上要归功于其由周边技术、工具和广泛的专业技能组成的完善的生态系统。 在这个涵盖信息技术(IT) 和运营技术(OT) 的技术革命时代,人们普遍认识到性能方面出现了重大挑战,特别是…

【广州华锐互动】Web3D云展编辑器能为展览行业带来哪些便利?

在数字时代中,传统的展览方式正在被全新的技术和工具所颠覆。其中,最具有革新意义的就是Web3D云展编辑器。这种编辑器以其强大的功能和灵活的应用,正在为展览设计带来革命性的变化。 广州华锐互动开发的Web3D云展编辑器是一种专门用于创建、编…

关于网站的favicon.ico图标的设置需要注意的几点

01-必须在网页的head标签中放上对icon图标的说明语句&#xff1a; 比如下面这样的语句&#xff1a; <link rel"shortcut icon" href"/favicon.ico">否则&#xff0c;浏览器虽然能读到图标&#xff0c;但是不会把图标显示在标签上。 02-为了和本地开…

DHCP、ARP、FTP、DNS、VRRP、STP、报文交互流程

目录 一、DHCP 1、DHCP终结 1、DHCP discover 2、DHCP offer 3、DHCP request 4、DHCP ack 5、DHCP request 6、DHCP 续租 2、DHCP终结 二、ARP 1、ARP类型 动态ARP 静态ARP ARP代理 ARP代理的分类&#xff1a;路由式代理、VLAN内的ARP代理、VLAN间的ARP代理。 6…

【Hadoop】分布式文件系统 HDFS

目录 一、介绍二、HDFS设计原理2.1 HDFS 架构2.2 数据复制复制的实现原理 三、HDFS的特点四、图解HDFS存储原理1. 写过程2. 读过程3. HDFS故障类型和其检测方法故障类型和其检测方法读写故障的处理DataNode 故障处理副本布局策略 一、介绍 HDFS &#xff08;Hadoop Distribute…

electron调用dll问题总汇

通过一天的调试安装&#xff0c;electron调用dll成功&#xff0c;先列出当前的环境&#xff1a;node版本: 18.12.0&#xff0c;32位的&#xff08;因为dll为32位的&#xff09; VS2019 python node-gyp 1、首先要查看报错原因&#xff0c;通常在某一行会有提示&#xff0c;常…

C#常见的设计模式-行为型模式

前言 行为型模式是面向对象设计中的一类设计模式&#xff0c;它关注对象之间的通信和相互作用&#xff0c;以实现特定的行为或功能。在C#中&#xff0c;有许多常见的行为型模式&#xff0c;下面将对其中10种行为型模式进行介绍&#xff0c;并给出相应的代码示例。 目录 前言1.…

什么是网络爬虫技术?它的重要用途有哪些?

网络爬虫&#xff08;Web Crawler&#xff09;是一种自动化的网页浏览程序&#xff0c;能够根据一定的规则和算法&#xff0c;从互联网上抓取和收集数据。网络爬虫技术是随着互联网的发展而逐渐成熟的一种技术&#xff0c;它在搜索引擎、数据挖掘、信息处理等领域发挥着越来越重…

线性分组码的奇偶校验矩阵均匀性分析

回顾信道编解码知识&#xff0c;我们知道信道编码要求编码具有检纠错能力&#xff0c;作为FEC&#xff08;forward error correction&#xff09;前向纠错编码的一类&#xff0c;线性分组码表示校验位与信息位的关系能够线性表示。 在这篇文章中&#xff0c;并不是要讨论信道编…

【古月居《ros入门21讲》学习笔记】09_订阅者Subscriber的编程实现

目录 说明&#xff1a; 1. 话题模型 图示 说明 2. 实现过程&#xff08;C&#xff09; 创建订阅者代码&#xff08;C&#xff09; 配置发布者代码编译规则 编译并运行 编译 运行 3. 实现过程&#xff08;Python&#xff09; 创建订阅者代码&#xff08;Python&…

MYSQL索引使用注意事项

索引使用注意事项&#xff1a; 1.索引列运算 不要在索引列上进行运算操作&#xff0c;否则索引将失效&#xff1b; 2.字符串不加引号 字符串类型使用时&#xff0c;不加引号&#xff0c;否则索引将失效&#xff1b; 3.模糊查询 如果仅仅是尾部模糊匹配&#xff0c;索引将不会失…

WSL中安装的Pycharm如何在Windows的开始菜单中新建图标?或WSL中的Pycharm经常花屏

WSL中安装的Pycharm如何在Windows的开始菜单中新建图标&#xff1f;或WSL中的Pycharm经常花屏 ⚙️1.软件环境⚙️&#x1f50d;2.问题描述&#x1f50d;&#x1f421;3.解决方法&#x1f421;&#x1f914;4.结果预览&#x1f914; ⚙️1.软件环境⚙️ Windows10 教育版64位 W…

【云栖 2023】姜伟华:Hologres Serverless 之路——揭秘弹性计算组

云布道师 本文根据 2023 云栖大会演讲实录整理而成&#xff0c;演讲信息如下&#xff1a; 演讲人&#xff1a;姜伟华 | 阿里云计算平台事业部资深技术专家、阿里云实时数仓 Hologres 研发负责人 演讲主题&#xff1a;Hologres Serverless 之路——揭秘弹性计算组 实时化成为…

牛客算法心得——abb(dp)

大家好&#xff0c;我是晴天学长&#xff0c;传智杯的题&#xff0c;我准备写一个题解&#xff0c;需要的小伙伴可以关注支持一下哦&#xff01;后续会继续更新的。&#x1f4aa;&#x1f4aa;&#x1f4aa; 1) .abb leafee 最近爱上了 abb 型语句&#xff0c;比如“叠词词”、…

【物联网与大数据应用】Hadoop数据处理

Hadoop是目前最成熟的大数据处理技术。Hadoop利用分而治之的思想为大数据提供了一整套解决方案&#xff0c;如分布式文件系统HDFS、分布式计算框架MapReduce、NoSQL数据库HBase、数据仓库工具Hive等。 Hadoop的两个核心解决了数据存储问题&#xff08;HDFS分布式文件系统&#…

【Java学习笔记】75 - 算法优化入门 - 马踏棋盘问题

一、意义 1.算法是程序的灵魂&#xff0c;为什么有些程序可以在海量数据计算时&#xff0c;依然保持高速计算? 2.拿老韩实际工作经历来说&#xff0c;在Unix下开发服务器程序&#xff0c;功能是要支持上千万人同时在线&#xff0c;在上线前&#xff0c; 做内测&#xff0c;一…

常用服务注册中心与发现(Eurake、zookeeper、Nacos)笔记(一)基础概念

基础概念 注册中心 在服务治理框架中&#xff0c;通常都会构建一个注册中心&#xff0c;每个服务单元向注册中心登记自己提供的服务&#xff0c;将主机与端口号、版本号、通信协议等一些附加信息告知注册中心&#xff0c;注册中心按照服务名分类组织服务清单&#xff0c;服务…

OpenGL之Mesa3D编译for Ubuntu20.04(三十六)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生从来没有捷径,只有行动才是治疗恐惧和懒惰的唯一良药. 更多原创,欢迎关注:Android…

vue3中的Fragment、Teleport、Suspense新组件

Fragment组件 在Vue2中: 组件必须有一个根标签 在Vue3中: 组件可以没有根标签, 内部会将多个标签包含在一个Fragment虚拟元素中 好处: 减少标签层级, 减小内存占用 <template><div style"font-size: 14px;"><p> 组件可以没有根标签</p&g…

大数据技术之数据安全与网络安全——CMS靶场(文章管理系统)实训

大数据技术之数据安全与网络安全——CMS靶场(文章管理系统)实训 在当今数字化时代&#xff0c;大数据技术的迅猛发展带来了前所未有的数据增长&#xff0c;同时也催生了对数据安全和网络安全的更为迫切的需求。本篇博客将聚焦于大数据技术背景下的数据安全与网络安全&#xff…