第12章 PyTorch图像分割代码框架-2

模型模块

本书的第5-9章重点介绍了各种2D3D的语义分割和实例分割网络模型,所以在模型模块中,我们需要做的事情就是将要实验的分割网络写在该目录下。有时候我们可能想尝试不同的分割网络结构,所以在该目录下可以存在多个想要实验的网络模型定义文件。对于PASCAL VOC这样的自然数据集,我们可能想实验Deeplab v3+PSPNetRefineNet等网络的训练效果。代码11-3给出了Deeplab v3+网络封装后的主体部分,完整网络搭建代码可参考本书配套代码对应章节。

代码11-3 Deeplab v3+网络的主体部分

# 定义Deeplab V3+类
class DeepLabHeadV3Plus(nn.Module):def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):super(DeepLabHeadV3Plus, self).__init__()self.project = nn.Sequential(nn.Conv2d(low_level_channels, 48, 1, bias=False),nn.BatchNorm2d(48),nn.ReLU(inplace=True),)# ASPPself.aspp = ASPP(in_channels, aspp_dilate)# classifier headself.classifier = nn.Sequential(nn.Conv2d(304, 256, 3, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.Conv2d(256, num_classes, 1))self._init_weight()# forward methoddef forward(self, feature):# print(feature['low_level'].shape)# print(feature['out'].shape)low_level_feature = self.project(feature['low_level'])output_feature = self.aspp(feature['out'])output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False)return self.classifier(torch.cat([low_level_feature, output_feature], dim=1))# weight initilizedef _init_weight(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)

对于复杂网络搭建,一般都是采用自下而上的搭建方法,先搭建底层组件,再逐步向上封装,对于本例中的Deeplab v3+,可以先分别搭建backbone骨干网络、ASPP和编解码结构,最后再进行封装。

工具函数模块

工具函数是为项目完成各项功能所自定义的辅助函数,可以统一定义在utils文件夹下,根据实际项目的不同,工具函数也各不相同。常用的工具函数包括各种损失函数的定义loss.py、训练可视化函数的定义visualize.py、用于记录训练日志的log.py等。代码11-4给出了一个关于Focal loss损失函数的定义,该损失函数作为工具函数可放在loss.py文件中。

代码11-4 工具函数示例:定义一个Focal loss

# 导入相关库
import torch
import torch.nn as nn
import torch.nn.functional as F
# 定义一个Focal loss类
class FocalLoss(nn.Module):def __init__(self, alpha=1, gamma=2):super(FocalLoss, self).__init__()self.alpha = alphaself.gamma = gammadef forward(self, inputs, targets):# Compute cross-entropy lossce_loss = F.cross_entropy(inputs, targets, reduction='none')# Compute the focal losspt = torch.exp(-ce_loss)  focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_lossreturn focal_loss.mean()

配置模块

配置模块是为项目模型训练传入各种参数而进行设置的模块,比如训练数据所在目录、训练所需要的各种参数、训练过程是否需要可视化等。一般来说,我们有两种方式来对项目执行参数进行配置管理,一种是直接在主函数main.py中使用argparse库对参数进行配置,然后再命令行中进行传入;另一种则是单独定义一个config.py或者config.yaml文件来对所有参数进行统一配置。基于argparse库的参数配置管理简单示例如代码11-5所示。

代码11-5 argparser参数配置管理

# 导入argparse库
import argparse
# 创建参数管理器
parser = argparse.ArgumentParser()
# 涉及数据相关的参数管理
parser.add_argument("--data_root", type=str, default='./dataset',help="path to Dataset")
parser.add_argument("--save_root", type=str, default='./',help="path to save result")
parser.add_argument("--dataset", type=str, default='voc',choices=['voc', 'cityscapes', 'ade'], help='Name of dataset')
parser.add_argument("--num_classes", type=int, default=None,help="num classes (default: None)")

在上述代码中,我们基于argparse给出了一小部分参数配置管理代码,涉及训练数据相关的部分参数,包括数据读取路径、存放路径、训练所用数据集、分割类别数量等。

主函数模块

主函数模块main.py是项目的启动模块,该模块将定义好的数据和模型模块进行组装,并结合损失函数、优化器、评估方法和可视化等组件,将config.py中配置好的项目参数传入,根据训练-验证的模式,执行图像分割项目模型训练和验证。代码11-6VOC数据集训练验证部分代码。

代码11-6 主函数模块中的训练迭代部分

# 初始化区间损失
interval_loss = 0
while True:  # 执行训练model.train()cur_epochs += 1for (images, labels) in train_loader:cur_itrs += 1images = images.to(device, dtype=torch.float32)labels = labels.to(device, dtype=torch.long)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()np_loss = loss.detach().cpu().numpy()interval_loss += np_lossif vis is not None:vis.vis_scalar('Loss', cur_itrs, np_loss)# 打印训练信息if (cur_itrs) % opts.print_interval == 0:pass# 保存模型if (cur_itrs) % opts.val_interval == 0:pass# 日志记录logger.info("Save the latest model to %s" % save_path_checkpoints)# 模型验证print("validation...")model.eval()val_score, ret_samples = validate(opts=opts, model=model, loader=val_loader, device=device, metrics=metrics,ret_samples_ids=vis_sample_id)logger.info("Validation performance: %s", val_score)# 保存最优模型if val_score['mean_dice'] > best_score:  best_score = val_score['mean_dice']save_ckpt(os.path.join(save_path_checkpoints, 'best_%s_%s_os%d.pth' %(opts.model, opts.dataset, opts.output_stride)))logger.info("Save best-performance model so far to %s" % save_path_checkpoints)# 训练过程可视化if vis is not None:  vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])vis.vis_table("[Val] Class IoU", val_score['Class IoU'])for k, (img, target, lbl) in enumerate(ret_samples):img = (denorm(img) * 255).astype(np.uint8)target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)concat_img = np.concatenate((img, target, lbl), axis=2)  vis.vis_image('Sample %d' % k, concat_img)scheduler.step()

在代码11-6中,我们展示了一个图像分割项目主函数模块中最核心的训练和验证部分。在训练时,按照指定迭代次数保存模型和对训练过程进行可视化展示。图11-2为训练打印的部分信息。

29f5835d31863a3c0e12336eba35dd8d.png

11-2 VOC训练过程信息

11-3为基于visdom的训练过程可视化展示,包括当前训练配置参数信息,训练损失函数变化曲线、验证集全局准确率、mIoU和类别IoU等指标变化曲线图。

c829adc88b9de5b266170dd6b3b86385.png

11-3 Deeplab v3+训练过程可视化

11-4展示了两组训练过程中验证集的输入图像、标签图像和模型预测图像的对比图。可以看到,基于Deeplab v3+的分割模型在PASCAL VOC 2012上表现还不错。

ea3738bd115f8bb89add3ea4ab2e67b6.png

11-4 验证集模型效果图

后续全书内容和代码将在github上开源,请关注仓库:

https://github.com/luwill/Deep-Learning-Image-Segmentation

(未完待续)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/184022.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Cloud - 通过 Gateway webflux 编程实现网关异常处理

一、webflux 编程实现网关异常处理 我们知道在某一个服务中出现异常,可以通过 ControllerAdvice ExceptionHandler 来统一异常处理,即使是在微服务架构中,我们也可以将上述统一异常处理放入到公共的微服务中,这样哪一个微服务需要…

【Hadoop】YARN容量调度器详解

🦄 个人主页——🎐开着拖拉机回家_Linux,Java基础学习,大数据运维-CSDN博客 🎐✨🍁 🪁🍁🪁🍁🪁🍁🪁🍁 🪁🍁&am…

【Vue.js】Vue3全局配置Axios并解决跨域请求问题

系列文章目录 文章目录 系列文章目录背景一、部署Axios1. npm 安装 axios2. 创建 request.js,创建axios实例3. 在main.js中全局注册axios4. 在页面中使用axios 二、后端解决跨域请求问题方法一 解决单Contoller跨域访问方法二 全局解决跨域问题 背景 对于前后端分离…

[架构之路-254/创业之路-85]:目标系统 - 横向管理 - 源头:信息系统战略规划的常用方法论,为软件工程的实施指明方向!!!

目录 总论: 一、数据处理阶段的方法论 1.1 企业信息系统规划法BSP 1.1.1 概述 1.1.2 原则 1.2 关键成功因素法CSF 1.2.1 概述 1.2.2 常见的企业成功的关键因素 1.3 战略集合转化法SST:把战略目标转化成信息的集合 二、管理信息系统阶段的方法论…

『MySQL快速上手』-④-表的操作

文章目录 1.创建表2.查看表结构3.修改表4.删除表 1.创建表 语法格式如下: CREATE TABLE table_name ( field1 datatype, field2 datatype, field3 datatype ) character set 字符集 collate 校验规则 engine 存储引擎;说明: field 表示列名&#xff1…

大数据毕业设计选题推荐-营业厅营业效能监控平台-Hadoop-Spark-Hive

✨作者主页:IT毕设梦工厂✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…

思维调试:调用ShellExecute后为什么程序没有启动

今天的问题来自我的一位读者: “如果我在命令行下启动我的程序,一切都是正常的。但是,当我在代码中调用 ShellExecuteEx 来启动程序时,好像什么都没有发生,这是为什么?” 在我问下面的第二个能给出答案的…

基于springboot实现致远汽车租赁平台管理系统项目【项目源码+论文说明】

基于springboot实现致远汽车租赁平台系统演示 摘要 首先,论文一开始便是清楚的论述了系统的研究内容。其次,剖析系统需求分析,弄明白“做什么”,分析包括业务分析和业务流程的分析以及用例分析,更进一步明确系统的需求。然后在明白了系统的需求基础上需要进一步地设计系统,主要…

视频特效编辑软件 After Effects 2022 mac中文版介绍 (ae 2022)

After Effects 2022 mac是一款视频特效编辑软件,被称为AE,拥有强大的特效工具,旋转,用于2D和3D合成、动画制作和视觉特效等,效果创建电影级影片字幕、片头和过渡,是一款可以帮助您高效且精确地创建无数种引…

django REST框架- Django-ninja

Django 是我学习的最早的web框架,大概在2014年,当时选他原因也很简单就是网上资料比较丰富,自然是遇到问题更容易找答案,直到 2018年真正开始拿django做项目,才对他有了更全面的了解。他是一个入门有门槛,学…

基于Java+SpringBoot+Mybaties-plus+Vue+ElementUI 失物招领小程序 设计与实现

一.项目介绍 失物招领小程序 用户登录、忘记密码、退出系统 发布失物 和 发布招领 查看我发布的失物和招领信息 失捡物品模块可以查看和搜索所有用户发布的信息。 二.环境需要 1.运行环境:java jdk1.8 2.ide环境:IDEA、Eclipse、Myeclipse都可以&#…

【Proteus仿真】【51单片机】水质监测报警系统设计

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真51单片机控制器,使用按键、LED、蜂鸣器、LCD1602、PCF8591 ADC、PH传感器、浑浊度传感器、DS18B20温度传感器、继电器模块等。 主要功能: 系统运行后&…

如何使用 Loadgen 来简化 HTTP API 请求的集成测试

引言 在编写 HTTP 服务的过程中,集成测试 1 是保证程序正确性的重要一环,如下图所示,其基本的流程就是不断向服务发起请求然后校验响应的状态和数据等: 为大量的 API 和用例编写测试是一件繁琐的工作,而 Loadgen 2 正…

新版onenet平台安全鉴权的确定与使用

根据onenet官方更新的文档:平台提供开放的API接口,用户可以通过HTTP/HTTPS调用,进行设备管理,数据查询,设备命令交互等操作,在API的基础上,根据自己的个性化需求搭建上层应用。 为提高API访问安…

draw.io与项目管理——如何利用流程图工具提高项目管理效率

draw.io 是一款强大的图形绘制工具,用于创建各种类型的图表、流程图、组织结构图、网络图和平面设计等。它提供了丰富的绘图工具和预定义的图形库,使用户能够轻松创建专业水平的图形作品。 draw.io具有直观的界面和简单易用的功能,适合各种用…

C++:STL第一篇vector

目录 1.vector 的介绍及使用 1.1 vector的介绍 1.2 vector的使用 1.2.1 vector的定义 1.2.2 vector iterator (迭代器)的使用 1.2.3 vector空间增长问题 1.2.4 vector的增删改查 1.2.5 vector 迭代器失效问题。(重点) 2.vector 深度刨析及模拟实…

零代码编程:用ChatGPT批量提取flash动画swf文件中的mp3

文件夹:C:\迅雷下载\有声绘本_flash[淘宝-珍奥下载]\有声绘本 flash,里面有多个flash文件,怎么转换成mp3文件呢? 可以使用swfextract工具从Flash动画中提取音频,下载地址是http://www.swftools.org/download.html,也…

学习Opencv(蝴蝶书/C++)相关——2.用clang++或g++命令行编译程序

文章目录 1. c/cpp程序的执行1.1 cpp程序的编译过程1.2 预处理指令1.3 编译过程的细节2. macOS下使用Clang看cpp程序的编译过程2.1 示例2.1.1 第一步 预处理器-preprocessor2.1.2 第二步 编译器-compiler2.1.3 第三步 汇编器-assembler2.1.4 第四步 链接器-linker2.1.5 链接其他…

两数之和(哈希解法)

题目描述 给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target 的那 两个 整数,并返回它们的数组下标。 你可以假设每种输入只会对应一个答案。但是,数组中同一个元素在答案里不能重复出现。 你可以按任…

Redis中的Zset类型

目录 Zset的相关命令 zadd zrange zcard zcount zrevrange zrangebyscore zpopmax bzpopmax zpopmin和bzpopmin zrank zrevrank zscore zrem zremrangebyrank zremrangebyscore 操作集合间的命令 zinterstore和zunionstore 内部编码 Zset的应用场景 Zset表…