【深度学习项目】语义分割-FCN网络(原理、网络架构、基于Pytorch实现FCN网络)

文章目录

  • 介绍
    • 深度学习语义分割的关键特点
    • 主要架构和技术
    • 数据集和评价指标
    • 总结
  • FCN网络
    • FCN 的特点
    • FCN 的工作原理
    • FCN 的变体和发展
    • FCN 的网络结构
    • FCN 的实现(基于Pytorch)
      • 1. 环境配置
      • 2. 文件结构
      • 3. 预训练权重下载地址
      • 4. 数据集,本例程使用的是PASCAL VOC2012数据集
      • 5. 训练方法
      • 6. 注意事项
      • 7. Pytorch官方实现的FCN网络框架图
      • 8. 完整代码
        • 8.1 src文件目录代码
        • 8.2 train_utils文件目录代码
        • 8.3 根目录代码

个人主页:道友老李
欢迎加入社区:道友老李的学习社区

介绍

深度学习语义分割(Semantic Segmentation)是一种计算机视觉任务,它旨在将图像中的每个像素分类为预定义类别之一。与物体检测不同,后者通常只识别和定位图像中的目标对象边界框,语义分割要求对图像的每一个像素进行分类,以实现更精细的理解。这项技术在自动驾驶、医学影像分析、机器人视觉等领域有着广泛的应用。

深度学习语义分割的关键特点

  • 像素级分类:对于输入图像的每一个像素点,模型都需要预测其属于哪个类别。
  • 全局上下文理解:为了正确地分割复杂场景,模型需要考虑整个图像的内容及其上下文信息。
  • 多尺度处理:由于目标可能出现在不同的尺度上,有效的语义分割方法通常会处理多种分辨率下的特征。

主要架构和技术

  1. 全卷积网络 (FCN)

    • FCN是最早的端到端训练的语义分割模型之一,它移除了传统CNN中的全连接层,并用卷积层替代,从而能够接受任意大小的输入并输出相同空间维度的概率图。
  2. 跳跃连接 (Skip Connections)

    • 为了更好地保留原始图像的空间细节,一些模型引入了跳跃连接,即从编码器部分直接传递特征到解码器部分,这有助于恢复细粒度的结构信息。
  3. U-Net

    • U-Net是一个专为生物医学图像分割设计的网络架构,它使用了对称的收缩路径(下采样)和扩展路径(上采样),以及丰富的跳跃连接来捕捉局部和全局信息。
  4. DeepLab系列

    • DeepLab采用了空洞/膨胀卷积(Atrous Convolution)来增加感受野而不减少特征图分辨率,并通过多尺度推理和ASPP模块(Atrous Spatial Pyramid Pooling)增强了对不同尺度物体的捕捉能力。
  5. PSPNet (Pyramid Scene Parsing Network)

    • PSPNet利用金字塔池化机制收集不同尺度的上下文信息,然后将其融合用于最终的预测。
  6. RefineNet

    • RefineNet强调了高分辨率特征的重要性,并通过一系列细化单元逐步恢复细节,确保输出高质量的分割结果。
  7. HRNet (High-Resolution Network)

    • HRNet在整个网络中保持了高分辨率的表示,同时通过多尺度融合策略有效地整合了低分辨率但富含语义的信息。

数据集和评价指标

常用的语义分割数据集包括PASCAL VOC、COCO、Cityscapes等。这些数据集提供了标注好的图像,用于训练和评估模型性能。

评价语义分割模型的标准通常包括:

  • 像素准确率 (Pixel Accuracy):所有正确分类的像素占总像素的比例。
  • 平均交并比 (Mean Intersection over Union, mIoU):这是最常用的评价指标之一,计算每个类别的IoU(交集除以并集),然后取平均值。
  • 频率加权交并比 (Frequency Weighted IoU):考虑每个类别的出现频率,对mIoU进行加权。

总结

随着硬件性能的提升和算法的进步,深度学习语义分割已经取得了显著的进展。现代模型不仅能在速度上满足实时应用的需求,还能提供非常精确的分割结果。未来的研究可能会集中在提高模型效率、增强跨域泛化能力以及探索无监督或弱监督的学习方法等方面。

FCN网络

FCN(Fully Convolutional Networks,全卷积网络)是一种用于计算机视觉任务的神经网络架构,尤其擅长处理像素级别的分类问题,例如语义分割。FCN 的核心思想是将传统的 CNN(卷积神经网络)中的全连接层替换为卷积层,这样可以接受任意大小的输入图像,并输出同样大小的概率图,其中每个像素点对应于该位置属于某个类别的概率。

FCN 的特点

  1. 任意尺寸输入:由于没有全连接层的限制,FCN 可以处理任意尺寸的输入图像。
  2. 端到端训练:FCN 支持从原始像素到最终预测结果的端到端训练,不需要预先提取特征或者分阶段训练。
  3. 多尺度上下文信息:通过在不同层级使用跳跃结构(skip architecture),FCN 能够结合低层的精细空间信息和高层的语义信息,从而提升分割精度。

FCN 的工作原理

  • 编码器部分:通常基于一个预训练好的分类网络(如 VGG、ResNet 等),移除掉最后的全连接层。这个部分负责提取图像特征,随着网络深度增加,感受野也逐渐扩大,能够捕捉到更大范围内的上下文信息。

  • 解码器部分:这部分用来逐步恢复特征图的空间分辨率,直到与输入图像相同大小。这通常是通过上采样操作完成的,比如转置卷积(Deconvolution 或 Fractionally-strided convolution)。在这个过程中,可以加入来自编码器早期层的特征(即跳跃连接),来帮助保持细节。

  • 跳跃连接(Skip Connections):为了保留更多的位置信息,FCN 会将编码器中较早层的特征图与解码器中对应的层进行融合。这种做法有助于改善边界区域的分割效果。

FCN 的变体和发展

自 FCN 提出以来,出现了许多改进版本,包括但不限于:

  • U-Net:一种具有对称的编码器-解码器结构的网络,广泛应用于医学图像分割领域。
  • DeepLab 系列:通过引入空洞卷积(Atrous Convolution)等技术来增强模型捕捉多尺度信息的能力。
  • PSPNet (Pyramid Scene Parsing Network):利用金字塔池化模块获取全局上下文信息。
  • RefineNet:专注于细节恢复,采用多路径优化策略来传递所有分辨率的信息。

这些模型在不同的应用场景中都有所应用,并且根据特定的任务需求不断进化和发展。

首个端到端的针对像素级预测的全卷积网络

在这里插入图片描述

FCN 的网络结构

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

Conv的参数量:77512*4096=102760448

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

FCN 的实现(基于Pytorch)

该项目主要是来自pytorch官方torchvision模块中的源码
https://github.com/pytorch/vision/tree/main/torchvision/models/segmentation

1. 环境配置

  • Python3.6/3.7/3.8
  • Pytorch1.10
  • Ubuntu或Centos(Windows暂不支持多GPU训练)
  • 最好使用GPU训练
  • 详细环境配置见requirements.txt

2. 文件结构

  ├── src: 模型的backbone以及FCN的搭建├── train_utils: 训练、验证以及多GPU训练相关模块├── my_dataset.py: 自定义dataset用于读取VOC数据集├── train.py: 以fcn_resnet50(这里使用了Dilated/Atrous Convolution)进行训练├── train_multi_GPU.py: 针对使用多GPU的用户使用├── predict.py: 简易的预测脚本,使用训练好的权重进行预测测试├── validation.py: 利用训练好的权重验证/测试数据的mIoU等指标,并生成record_mAP.txt文件└── pascal_voc_classes.json: pascal_voc标签文件

3. 预训练权重下载地址

  • 注意:官方提供的预训练权重是在COCO上预训练得到的,训练时只针对和PASCAL VOC相同的类别进行了训练,所以类别数是21(包括背景)
  • fcn_resnet50: https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth
  • fcn_resnet101: https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth
  • 注意,下载的预训练权重记得要重命名,比如在train.py中读取的是fcn_resnet50_coco.pth文件,
    不是fcn_resnet50_coco-1167a1af.pth

4. 数据集,本例程使用的是PASCAL VOC2012数据集

  • Pascal VOC2012 train/val数据集下载地址:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar

5. 训练方法

  • 确保提前准备好数据集
  • 确保提前下载好对应预训练模型权重
  • 若要使用单GPU或者CPU训练,直接使用train.py训练脚本
  • 若要使用多GPU训练,使用torchrun --nproc_per_node=8 train_multi_GPU.py指令,nproc_per_node参数为使用GPU数量
  • 如果想指定使用哪些GPU设备可在指令前加上CUDA_VISIBLE_DEVICES=0,3(例如我只要使用设备中的第1块和第4块GPU设备)
  • CUDA_VISIBLE_DEVICES=0,3 torchrun --nproc_per_node=2 train_multi_GPU.py

6. 注意事项

  • 在使用训练脚本时,注意要将’–data-path’(VOC_root)设置为自己存放’VOCdevkit’文件夹所在的根目录
  • 在使用预测脚本时,要将’weights_path’设置为你自己生成的权重路径。
  • 使用validation文件时,注意确保你的验证集或者测试集中必须包含每个类别的目标,并且使用时只需要修改’–num-classes’、‘–aux’、‘–data-path’和’–weights’即可,其他代码尽量不要改动

7. Pytorch官方实现的FCN网络框架图

在这里插入图片描述

8. 完整代码

8.1 src文件目录代码
  • init.py
from .fcn_model import fcn_resnet50, fcn_resnet101
  • backbone.py
import torch
import torch.nn as nndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):"""3x3 convolution with padding"""return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=dilation, groups=groups, bias=False, dilation=dilation)def conv1x1(in_planes, out_planes, stride=1):"""1x1 convolution"""return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)class Bottleneck(nn.Module):# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)# while original implementation places the stride at the first 1x1 convolution(self.conv1)# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.# This variant is also known as ResNet V1.5 and improves accuracy according to# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.expansion = 4def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,base_width=64, dilation=1, norm_layer=None):super(Bottleneck, self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dwidth = int(planes * (base_width / 64.)) * groups# Both self.conv2 and self.downsample layers downsample the input when stride != 1self.conv1 = conv1x1(inplanes, width)self.bn1 = norm_layer(width)self.conv2 = conv3x3(width, width, stride, groups, dilation)self.bn2 = norm_layer(width)self.conv3 = conv1x1(width, planes * self.expansion)self.bn3 = norm_layer(planes * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,groups=1, width_per_group=64, replace_stride_with_dilation=None,norm_layer=None):super(ResNet, self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dself._norm_layer = norm_layerself.inplanes = 64self.dilation = 1if replace_stride_with_dilation is None:# each element in the tuple indicates if we should replace# the 2x2 stride with a dilated convolution insteadreplace_stride_with_dilation = [False, False, False]if len(replace_stride_with_dilation) != 3:raise ValueError("replace_stride_with_dilation should be None ""or a 3-element tuple, got {}".format(replace_stride_with_dilation))self.groups = groupsself.base_width = width_per_groupself.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,bias=False)self.bn1 = norm_layer(self.inplanes)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2,dilate=replace_stride_with_dilation[0])self.layer3 = self._make_layer(block, 256, layers[2], stride=2,dilate=replace_stride_with_dilation[1])self.layer4 = self._make_layer(block, 512, layers[3], stride=2,dilate=replace_stride_with_dilation[2])self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)# Zero-initialize the last BN in each residual branch,# so that the residual branch starts with zeros, and each residual block behaves like an identity.# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677if zero_init_residual:for m in self.modules():if isinstance(m, Bottleneck):nn.init.constant_(m.bn3.weight, 0)def _make_layer(self, block, planes, blocks, stride=1, dilate=False):norm_layer = self._norm_layerdownsample = Noneprevious_dilation = self.dilationif dilate:self.dilation *= stridestride = 1if stride != 1 or self.inplanes != planes * block.expansion:downsample = nn.Sequential(conv1x1(self.inplanes, planes * block.expansion, stride),norm_layer(planes * block.expansion),)layers = []layers.append(block(self.inplanes, planes, stride, downsample, self.groups,self.base_width, previous_dilation, norm_layer))self.inplanes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.inplanes, planes, groups=self.groups,base_width=self.base_width, dilation=self.dilation,norm_layer=norm_layer))return nn.Sequential(*layers)def _forward_impl(self, x):# See note [TorchScript super()]x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return xdef forward(self, x):return self._forward_impl(x)def _resnet(block, layers, **kwargs):model = ResNet(block, layers, **kwargs)return modeldef resnet50(**kwargs):r"""ResNet-50 model from`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_Args:pretrained (bool): If True, returns a model pre-trained on ImageNetprogress (bool): If True, displays a progress bar of the download to stderr"""return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)def resnet101(**kwargs):r"""ResNet-101 model from`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_Args:pretrained (bool): If True, returns a model pre-trained on ImageNetprogress (bool): If True, displays a progress bar of the download to stderr"""return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
  • fcn_model.py
from collections import OrderedDictfrom typing import Dictimport torch
from torch import nn, Tensor
from torch.nn import functional as F
from .backbone import resnet50, resnet101class IntermediateLayerGetter(nn.ModuleDict):"""Module wrapper that returns intermediate layers from a modelIt has a strong assumption that the modules have been registeredinto the model in the same order as they are used.This means that one should **not** reuse the same nn.Moduletwice in the forward if you want this to work.Additionally, it is only able to query submodules that are directlyassigned to the model. So if `model` is passed, `model.feature1` canbe returned, but not `model.feature1.layer2`.Args:model (nn.Module): model on which we will extract the featuresreturn_layers (Dict[name, new_name]): a dict containing the namesof the modules for which the activations will be returned asthe key of the dict, and the value of the dict is the nameof the returned activation (which the user can specify)."""_version = 2__annotations__ = {"return_layers": Dict[str, str],}def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:if not set(return_layers).issubset([name for name, _ in model.named_children()]):raise ValueError("return_layers are not present in model")orig_return_layers = return_layersreturn_layers = {str(k): str(v) for k, v in return_layers.items()}# 重新构建backbone,将没有使用到的模块全部删掉layers = OrderedDict()for name, module in model.named_children():layers[name] = moduleif name in return_layers:del return_layers[name]if not return_layers:breaksuper(IntermediateLayerGetter, self).__init__(layers)self.return_layers = orig_return_layersdef forward(self, x: Tensor) -> Dict[str, Tensor]:out = OrderedDict()for name, module in self.items():x = module(x)if name in self.return_layers:out_name = self.return_layers[name]out[out_name] = xreturn outclass FCN(nn.Module):"""Implements a Fully-Convolutional Network for semantic segmentation.Args:backbone (nn.Module): the network used to compute the features for the model.The backbone should return an OrderedDict[Tensor], with the key being"out" for the last feature map used, and "aux" if an auxiliary classifieris used.classifier (nn.Module): module that takes the "out" element returned fromthe backbone and returns a dense prediction.aux_classifier (nn.Module, optional): auxiliary classifier used during training"""__constants__ = ['aux_classifier']def __init__(self, backbone, classifier, aux_classifier=None):super(FCN, self).__init__()self.backbone = backboneself.classifier = classifierself.aux_classifier = aux_classifierdef forward(self, x: Tensor) -> Dict[str, Tensor]:input_shape = x.shape[-2:]# contract: features is a dict of tensorsfeatures = self.backbone(x)result = OrderedDict()x = features["out"]x = self.classifier(x)# 原论文中虽然使用的是ConvTranspose2d,但权重是冻结的,所以就是一个bilinear插值x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)result["out"] = xif self.aux_classifier is not None:x = features["aux"]x = self.aux_classifier(x)# 原论文中虽然使用的是ConvTranspose2d,但权重是冻结的,所以就是一个bilinear插值x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)result["aux"] = xreturn resultclass FCNHead(nn.Sequential):def __init__(self, in_channels, channels):inter_channels = in_channels // 4layers = [nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),nn.BatchNorm2d(inter_channels),nn.ReLU(),nn.Dropout(0.1),nn.Conv2d(inter_channels, channels, 1)]super(FCNHead, self).__init__(*layers)def fcn_resnet50(aux, num_classes=21, pretrain_backbone=False):# 'resnet50_imagenet': 'https://download.pytorch.org/models/resnet50-0676ba61.pth'# 'fcn_resnet50_coco': 'https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth'backbone = resnet50(replace_stride_with_dilation=[False, True, True])if pretrain_backbone:# 载入resnet50 backbone预训练权重backbone.load_state_dict(torch.load("resnet50.pth", map_location='cpu'))out_inplanes = 2048aux_inplanes = 1024return_layers = {'layer4': 'out'}if aux:return_layers['layer3'] = 'aux'backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)aux_classifier = None# why using aux: https://github.com/pytorch/vision/issues/4292if aux:aux_classifier = FCNHead(aux_inplanes, num_classes)classifier = FCNHead(out_inplanes, num_classes)model = FCN(backbone, classifier, aux_classifier)return modeldef fcn_resnet101(aux, num_classes=21, pretrain_backbone=False):# 'resnet101_imagenet': 'https://download.pytorch.org/models/resnet101-63fe2227.pth'# 'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth'backbone = resnet101(replace_stride_with_dilation=[False, True, True])if pretrain_backbone:# 载入resnet101 backbone预训练权重backbone.load_state_dict(torch.load("resnet101.pth", map_location='cpu'))out_inplanes = 2048aux_inplanes = 1024return_layers = {'layer4': 'out'}if aux:return_layers['layer3'] = 'aux'backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)aux_classifier = None# why using aux: https://github.com/pytorch/vision/issues/4292if aux:aux_classifier = FCNHead(aux_inplanes, num_classes)classifier = FCNHead(out_inplanes, num_classes)model = FCN(backbone, classifier, aux_classifier)return model
8.2 train_utils文件目录代码
  • init.py
from .train_and_eval import train_one_epoch, evaluate, create_lr_scheduler
from .distributed_utils import init_distributed_mode, save_on_master, mkdir
  • distributed_utils.py
from collections import defaultdict, deque
import datetime
import time
import torch
import torch.distributed as distimport errno
import osclass SmoothedValue(object):"""Track a series of values and provide access to smoothed values over awindow or the global series average."""def __init__(self, window_size=20, fmt=None):if fmt is None:fmt = "{value:.4f} ({global_avg:.4f})"self.deque = deque(maxlen=window_size)self.total = 0.0self.count = 0self.fmt = fmtdef update(self, value, n=1):self.deque.append(value)self.count += nself.total += value * ndef synchronize_between_processes(self):"""Warning: does not synchronize the deque!"""if not is_dist_avail_and_initialized():returnt = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')dist.barrier()dist.all_reduce(t)t = t.tolist()self.count = int(t[0])self.total = t[1]@propertydef median(self):d = torch.tensor(list(self.deque))return d.median().item()@propertydef avg(self):d = torch.tensor(list(self.deque), dtype=torch.float32)return d.mean().item()@propertydef global_avg(self):return self.total / self.count@propertydef max(self):return max(self.deque)@propertydef value(self):return self.deque[-1]def __str__(self):return self.fmt.format(median=self.median,avg=self.avg,global_avg=self.global_avg,max=self.max,value=self.value)class ConfusionMatrix(object):def __init__(self, num_classes):self.num_classes = num_classesself.mat = Nonedef update(self, a, b):n = self.num_classesif self.mat is None:# 创建混淆矩阵self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)with torch.no_grad():# 寻找GT中为目标的像素索引k = (a >= 0) & (a < n)# 统计像素真实类别a[k]被预测成类别b[k]的个数(这里的做法很巧妙)inds = n * a[k].to(torch.int64) + b[k]self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)def reset(self):if self.mat is not None:self.mat.zero_()def compute(self):h = self.mat.float()# 计算全局预测准确率(混淆矩阵的对角线为预测正确的个数)acc_global = torch.diag(h).sum() / h.sum()# 计算每个类别的准确率acc = torch.diag(h) / h.sum(1)# 计算每个类别预测与真实目标的iouiu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))return acc_global, acc, iudef reduce_from_all_processes(self):if not torch.distributed.is_available():returnif not torch.distributed.is_initialized():returntorch.distributed.barrier()torch.distributed.all_reduce(self.mat)def __str__(self):acc_global, acc, iu = self.compute()return ('global correct: {:.1f}\n''average row correct: {}\n''IoU: {}\n''mean IoU: {:.1f}').format(acc_global.item() * 100,['{:.1f}'.format(i) for i in (acc * 100).tolist()],['{:.1f}'.format(i) for i in (iu * 100).tolist()],iu.mean().item() * 100)class MetricLogger(object):def __init__(self, delimiter="\t"):self.meters = defaultdict(SmoothedValue)self.delimiter = delimiterdef update(self, **kwargs):for k, v in kwargs.items():if isinstance(v, torch.Tensor):v = v.item()assert isinstance(v, (float, int))self.meters[k].update(v)def __getattr__(self, attr):if attr in self.meters:return self.meters[attr]if attr in self.__dict__:return self.__dict__[attr]raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))def __str__(self):loss_str = []for name, meter in self.meters.items():loss_str.append("{}: {}".format(name, str(meter)))return self.delimiter.join(loss_str)def synchronize_between_processes(self):for meter in self.meters.values():meter.synchronize_between_processes()def add_meter(self, name, meter):self.meters[name] = meterdef log_every(self, iterable, print_freq, header=None):i = 0if not header:header = ''start_time = time.time()end = time.time()iter_time = SmoothedValue(fmt='{avg:.4f}')data_time = SmoothedValue(fmt='{avg:.4f}')space_fmt = ':' + str(len(str(len(iterable)))) + 'd'if torch.cuda.is_available():log_msg = self.delimiter.join([header,'[{0' + space_fmt + '}/{1}]','eta: {eta}','{meters}','time: {time}','data: {data}','max mem: {memory:.0f}'])else:log_msg = self.delimiter.join([header,'[{0' + space_fmt + '}/{1}]','eta: {eta}','{meters}','time: {time}','data: {data}'])MB = 1024.0 * 1024.0for obj in iterable:data_time.update(time.time() - end)yield objiter_time.update(time.time() - end)if i % print_freq == 0:eta_seconds = iter_time.global_avg * (len(iterable) - i)eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))if torch.cuda.is_available():print(log_msg.format(i, len(iterable), eta=eta_string,meters=str(self),time=str(iter_time), data=str(data_time),memory=torch.cuda.max_memory_allocated() / MB))else:print(log_msg.format(i, len(iterable), eta=eta_string,meters=str(self),time=str(iter_time), data=str(data_time)))i += 1end = time.time()total_time = time.time() - start_timetotal_time_str = str(datetime.timedelta(seconds=int(total_time)))print('{} Total time: {}'.format(header, total_time_str))def mkdir(path):try:os.makedirs(path)except OSError as e:if e.errno != errno.EEXIST:raisedef setup_for_distributed(is_master):"""This function disables printing when not in master process"""import builtins as __builtin__builtin_print = __builtin__.printdef print(*args, **kwargs):force = kwargs.pop('force', False)if is_master or force:builtin_print(*args, **kwargs)__builtin__.print = printdef is_dist_avail_and_initialized():if not dist.is_available():return Falseif not dist.is_initialized():return Falsereturn Truedef get_world_size():if not is_dist_avail_and_initialized():return 1return dist.get_world_size()def get_rank():if not is_dist_avail_and_initialized():return 0return dist.get_rank()def is_main_process():return get_rank() == 0def save_on_master(*args, **kwargs):if is_main_process():torch.save(*args, **kwargs)def init_distributed_mode(args):if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:args.rank = int(os.environ["RANK"])args.world_size = int(os.environ['WORLD_SIZE'])args.gpu = int(os.environ['LOCAL_RANK'])elif 'SLURM_PROCID' in os.environ:args.rank = int(os.environ['SLURM_PROCID'])args.gpu = args.rank % torch.cuda.device_count()elif hasattr(args, "rank"):passelse:print('Not using distributed mode')args.distributed = Falsereturnargs.distributed = Truetorch.cuda.set_device(args.gpu)args.dist_backend = 'nccl'print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,world_size=args.world_size, rank=args.rank)setup_for_distributed(args.rank == 0)
  • train_and_eval.py
import torch
from torch import nn
import train_utils.distributed_utils as utilsdef criterion(inputs, target):losses = {}for name, x in inputs.items():# 忽略target中值为255的像素,255的像素是目标边缘或者padding填充losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)if len(losses) == 1:return losses['out']return losses['out'] + 0.5 * losses['aux']def evaluate(model, data_loader, device, num_classes):model.eval()confmat = utils.ConfusionMatrix(num_classes)metric_logger = utils.MetricLogger(delimiter="  ")header = 'Test:'with torch.no_grad():for image, target in metric_logger.log_every(data_loader, 100, header):image, target = image.to(device), target.to(device)output = model(image)output = output['out']confmat.update(target.flatten(), output.argmax(1).flatten())confmat.reduce_from_all_processes()return confmatdef train_one_epoch(model, optimizer, data_loader, device, epoch, lr_scheduler, print_freq=10, scaler=None):model.train()metric_logger = utils.MetricLogger(delimiter="  ")metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))header = 'Epoch: [{}]'.format(epoch)for image, target in metric_logger.log_every(data_loader, print_freq, header):image, target = image.to(device), target.to(device)with torch.cuda.amp.autocast(enabled=scaler is not None):output = model(image)loss = criterion(output, target)optimizer.zero_grad()if scaler is not None:scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()else:loss.backward()optimizer.step()lr_scheduler.step()lr = optimizer.param_groups[0]["lr"]metric_logger.update(loss=loss.item(), lr=lr)return metric_logger.meters["loss"].global_avg, lrdef create_lr_scheduler(optimizer,num_step: int,epochs: int,warmup=True,warmup_epochs=1,warmup_factor=1e-3):assert num_step > 0 and epochs > 0if warmup is False:warmup_epochs = 0def f(x):"""根据step数返回一个学习率倍率因子,注意在训练开始之前,pytorch会提前调用一次lr_scheduler.step()方法"""if warmup is True and x <= (warmup_epochs * num_step):alpha = float(x) / (warmup_epochs * num_step)# warmup过程中lr倍率因子从warmup_factor -> 1return warmup_factor * (1 - alpha) + alphaelse:# warmup后lr倍率因子从1 -> 0# 参考deeplab_v2: Learning rate policyreturn (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
8.3 根目录代码
  • pascal_voc_classes.json
{"aeroplane": 1,"bicycle": 2,"bird": 3,"boat": 4,"bottle": 5,"bus": 6,"car": 7,"cat": 8,"chair": 9,"cow": 10,"diningtable": 11,"dog": 12,"horse": 13,"motorbike": 14,"person": 15,"pottedplant": 16,"sheep": 17,"sofa": 18,"train": 19,"tvmonitor": 20
}
  • my_dataset.py
import osimport torch.utils.data as data
from PIL import Imageclass VOCSegmentation(data.Dataset):def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):super(VOCSegmentation, self).__init__()assert year in ["2007", "2012"], "year must be in ['2007', '2012']"root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")assert os.path.exists(root), "path '{}' does not exist.".format(root)image_dir = os.path.join(root, 'JPEGImages')mask_dir = os.path.join(root, 'SegmentationClass')txt_path = os.path.join(root, "ImageSets", "Segmentation", txt_name)assert os.path.exists(txt_path), "file '{}' does not exist.".format(txt_path)with open(os.path.join(txt_path), "r") as f:file_names = [x.strip() for x in f.readlines() if len(x.strip()) > 0]self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]assert (len(self.images) == len(self.masks))self.transforms = transformsdef __getitem__(self, index):"""Args:index (int): IndexReturns:tuple: (image, target) where target is the image segmentation."""img = Image.open(self.images[index]).convert('RGB')target = Image.open(self.masks[index])if self.transforms is not None:img, target = self.transforms(img, target)return img, targetdef __len__(self):return len(self.images)@staticmethoddef collate_fn(batch):images, targets = list(zip(*batch))batched_imgs = cat_list(images, fill_value=0)batched_targets = cat_list(targets, fill_value=255)return batched_imgs, batched_targetsdef cat_list(images, fill_value=0):# 计算该batch数据中,channel, h, w的最大值max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))batch_shape = (len(images),) + max_sizebatched_imgs = images[0].new(*batch_shape).fill_(fill_value)for img, pad_img in zip(images, batched_imgs):pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)return batched_imgs# dataset = VOCSegmentation(voc_root="/data/", transforms=get_transform(train=True))
# d1 = dataset[0]
# print(d1)
  • validation.py
import os
import torchfrom src import fcn_resnet50
from train_utils import evaluate
from my_dataset import VOCSegmentation
import transforms as Tclass SegmentationPresetEval:def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):self.transforms = T.Compose([T.RandomResize(base_size, base_size),T.ToTensor(),T.Normalize(mean=mean, std=std),])def __call__(self, img, target):return self.transforms(img, target)def main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")assert os.path.exists(args.weights), f"weights {args.weights} not found."# segmentation nun_classes + backgroundnum_classes = args.num_classes + 1# VOCdevkit -> VOC2012 -> ImageSets -> Segmentation -> val.txtval_dataset = VOCSegmentation(args.data_path,year="2012",transforms=SegmentationPresetEval(520),txt_name="val.txt")num_workers = 8val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=1,num_workers=num_workers,pin_memory=True,collate_fn=val_dataset.collate_fn)model = fcn_resnet50(aux=args.aux, num_classes=num_classes)model.load_state_dict(torch.load(args.weights, map_location=device)['model'])model.to(device)confmat = evaluate(model, val_loader, device=device, num_classes=num_classes)print(confmat)def parse_args():import argparseparser = argparse.ArgumentParser(description="pytorch fcn training")parser.add_argument("--data-path", default="/data/", help="VOCdevkit root")parser.add_argument("--weights", default="./save_weights/model_29.pth")parser.add_argument("--num-classes", default=20, type=int)parser.add_argument("--aux", default=True, type=bool, help="auxilier loss")parser.add_argument("--device", default="cuda", help="training device")parser.add_argument('--print-freq', default=10, type=int, help='print frequency')args = parser.parse_args()return argsif __name__ == '__main__':args = parse_args()if not os.path.exists("./save_weights"):os.mkdir("./save_weights")main(args)
  • transforms.py
import numpy as np
import randomimport torch
from torchvision import transforms as T
from torchvision.transforms import functional as Fdef pad_if_smaller(img, size, fill=0):# 如果图像最小边长小于给定size,则用数值fill进行paddingmin_size = min(img.size)if min_size < size:ow, oh = img.sizepadh = size - oh if oh < size else 0padw = size - ow if ow < size else 0img = F.pad(img, (0, 0, padw, padh), fill=fill)return imgclass Compose(object):def __init__(self, transforms):self.transforms = transformsdef __call__(self, image, target):for t in self.transforms:image, target = t(image, target)return image, targetclass RandomResize(object):def __init__(self, min_size, max_size=None):self.min_size = min_sizeif max_size is None:max_size = min_sizeself.max_size = max_sizedef __call__(self, image, target):size = random.randint(self.min_size, self.max_size)# 这里size传入的是int类型,所以是将图像的最小边长缩放到size大小image = F.resize(image, size)# 这里的interpolation注意下,在torchvision(0.9.0)以后才有InterpolationMode.NEAREST# 如果是之前的版本需要使用PIL.Image.NEARESTtarget = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)return image, targetclass RandomHorizontalFlip(object):def __init__(self, flip_prob):self.flip_prob = flip_probdef __call__(self, image, target):if random.random() < self.flip_prob:image = F.hflip(image)target = F.hflip(target)return image, targetclass RandomCrop(object):def __init__(self, size):self.size = sizedef __call__(self, image, target):image = pad_if_smaller(image, self.size)target = pad_if_smaller(target, self.size, fill=255)crop_params = T.RandomCrop.get_params(image, (self.size, self.size))image = F.crop(image, *crop_params)target = F.crop(target, *crop_params)return image, targetclass CenterCrop(object):def __init__(self, size):self.size = sizedef __call__(self, image, target):image = F.center_crop(image, self.size)target = F.center_crop(target, self.size)return image, targetclass ToTensor(object):def __call__(self, image, target):image = F.to_tensor(image)target = torch.as_tensor(np.array(target), dtype=torch.int64)return image, targetclass Normalize(object):def __init__(self, mean, std):self.mean = meanself.std = stddef __call__(self, image, target):image = F.normalize(image, mean=self.mean, std=self.std)return image, target
  • train.py
import os
import time
import datetimeimport torchfrom src import fcn_resnet50
from train_utils import train_one_epoch, evaluate, create_lr_scheduler
from my_dataset import VOCSegmentation
import transforms as Tclass SegmentationPresetTrain:def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):min_size = int(0.5 * base_size)max_size = int(2.0 * base_size)trans = [T.RandomResize(min_size, max_size)]if hflip_prob > 0:trans.append(T.RandomHorizontalFlip(hflip_prob))trans.extend([T.RandomCrop(crop_size),T.ToTensor(),T.Normalize(mean=mean, std=std),])self.transforms = T.Compose(trans)def __call__(self, img, target):return self.transforms(img, target)class SegmentationPresetEval:def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):self.transforms = T.Compose([T.RandomResize(base_size, base_size),T.ToTensor(),T.Normalize(mean=mean, std=std),])def __call__(self, img, target):return self.transforms(img, target)def get_transform(train):base_size = 520crop_size = 480return SegmentationPresetTrain(base_size, crop_size) if train else SegmentationPresetEval(base_size)def create_model(aux, num_classes, pretrain=True):model = fcn_resnet50(aux=aux, num_classes=num_classes)if pretrain:weights_dict = torch.load("./src/fcn_resnet50_coco.pth", map_location='cpu')if num_classes != 21:# 官方提供的预训练权重是21类(包括背景)# 如果训练自己的数据集,将和类别相关的权重删除,防止权重shape不一致报错for k in list(weights_dict.keys()):if "classifier.4" in k:del weights_dict[k]missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)if len(missing_keys) != 0 or len(unexpected_keys) != 0:print("missing_keys: ", missing_keys)print("unexpected_keys: ", unexpected_keys)return modeldef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")batch_size = args.batch_size# segmentation nun_classes + backgroundnum_classes = args.num_classes + 1# 用来保存训练以及验证过程中信息results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))# VOCdevkit -> VOC2012 -> ImageSets -> Segmentation -> train.txttrain_dataset = VOCSegmentation(args.data_path,year="2012",transforms=get_transform(train=True),txt_name="train.txt")# VOCdevkit -> VOC2012 -> ImageSets -> Segmentation -> val.txtval_dataset = VOCSegmentation(args.data_path,year="2012",transforms=get_transform(train=False),txt_name="val.txt")num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,num_workers=num_workers,shuffle=True,pin_memory=True,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=1,num_workers=num_workers,pin_memory=True,collate_fn=val_dataset.collate_fn)model = create_model(aux=args.aux, num_classes=num_classes)model.to(device)params_to_optimize = [{"params": [p for p in model.backbone.parameters() if p.requires_grad]},{"params": [p for p in model.classifier.parameters() if p.requires_grad]}]if args.aux:params = [p for p in model.aux_classifier.parameters() if p.requires_grad]params_to_optimize.append({"params": params, "lr": args.lr * 10})optimizer = torch.optim.SGD(params_to_optimize,lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)scaler = torch.cuda.amp.GradScaler() if args.amp else None# 创建学习率更新策略,这里是每个step更新一次(不是每个epoch)lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs, warmup=True)if args.resume:checkpoint = torch.load(args.resume, map_location='cpu')model.load_state_dict(checkpoint['model'])optimizer.load_state_dict(checkpoint['optimizer'])lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])args.start_epoch = checkpoint['epoch'] + 1if args.amp:scaler.load_state_dict(checkpoint["scaler"])start_time = time.time()for epoch in range(args.start_epoch, args.epochs):mean_loss, lr = train_one_epoch(model, optimizer, train_loader, device, epoch,lr_scheduler=lr_scheduler, print_freq=args.print_freq, scaler=scaler)confmat = evaluate(model, val_loader, device=device, num_classes=num_classes)val_info = str(confmat)print(val_info)# write into txtwith open(results_file, "a") as f:# 记录每个epoch对应的train_loss、lr以及验证集各指标train_info = f"[epoch: {epoch}]\n" \f"train_loss: {mean_loss:.4f}\n" \f"lr: {lr:.6f}\n"f.write(train_info + val_info + "\n\n")save_file = {"model": model.state_dict(),"optimizer": optimizer.state_dict(),"lr_scheduler": lr_scheduler.state_dict(),"epoch": epoch,"args": args}if args.amp:save_file["scaler"] = scaler.state_dict()torch.save(save_file, "save_weights/model_{}.pth".format(epoch))total_time = time.time() - start_timetotal_time_str = str(datetime.timedelta(seconds=int(total_time)))print("training time {}".format(total_time_str))def parse_args():import argparseparser = argparse.ArgumentParser(description="pytorch fcn training")parser.add_argument("--data-path", default="/data/", help="VOCdevkit root")parser.add_argument("--num-classes", default=20, type=int)parser.add_argument("--aux", default=True, type=bool, help="auxilier loss")parser.add_argument("--device", default="cuda", help="training device")parser.add_argument("-b", "--batch-size", default=4, type=int)parser.add_argument("--epochs", default=30, type=int, metavar="N",help="number of total epochs to train")parser.add_argument('--lr', default=0.0001, type=float, help='initial learning rate')parser.add_argument('--momentum', default=0.9, type=float, metavar='M',help='momentum')parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,metavar='W', help='weight decay (default: 1e-4)',dest='weight_decay')parser.add_argument('--print-freq', default=10, type=int, help='print frequency')parser.add_argument('--resume', default='', help='resume from checkpoint')parser.add_argument('--start-epoch', default=0, type=int, metavar='N',help='start epoch')# Mixed precision training parametersparser.add_argument("--amp", default=False, type=bool,help="Use torch.cuda.amp for mixed precision training")args = parser.parse_args()return argsif __name__ == '__main__':args = parse_args()if not os.path.exists("./save_weights"):os.mkdir("./save_weights")main(args)
  • train_multi_GPU.py
import time
import os
import datetimeimport torchfrom src import fcn_resnet50
from train_utils import train_one_epoch, evaluate, create_lr_scheduler, init_distributed_mode, save_on_master, mkdir
from my_dataset import VOCSegmentation
import transforms as Tclass SegmentationPresetTrain:def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):min_size = int(0.5 * base_size)max_size = int(2.0 * base_size)trans = [T.RandomResize(min_size, max_size)]if hflip_prob > 0:trans.append(T.RandomHorizontalFlip(hflip_prob))trans.extend([T.RandomCrop(crop_size),T.ToTensor(),T.Normalize(mean=mean, std=std),])self.transforms = T.Compose(trans)def __call__(self, img, target):return self.transforms(img, target)class SegmentationPresetEval:def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):self.transforms = T.Compose([T.RandomResize(base_size, base_size),T.ToTensor(),T.Normalize(mean=mean, std=std),])def __call__(self, img, target):return self.transforms(img, target)def get_transform(train):base_size = 520crop_size = 480return SegmentationPresetTrain(base_size, crop_size) if train else SegmentationPresetEval(base_size)def create_model(aux, num_classes):model = fcn_resnet50(aux=aux, num_classes=num_classes)weights_dict = torch.load("./fcn_resnet50_coco.pth", map_location='cpu')if num_classes != 21:# 官方提供的预训练权重是21类(包括背景)# 如果训练自己的数据集,将和类别相关的权重删除,防止权重shape不一致报错for k in list(weights_dict.keys()):if "classifier.4" in k:del weights_dict[k]missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)if len(missing_keys) != 0 or len(unexpected_keys) != 0:print("missing_keys: ", missing_keys)print("unexpected_keys: ", unexpected_keys)return modeldef main(args):init_distributed_mode(args)print(args)device = torch.device(args.device)# segmentation nun_classes + backgroundnum_classes = args.num_classes + 1# 用来保存coco_info的文件results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))VOC_root = args.data_path# check voc rootif os.path.exists(os.path.join(VOC_root, "VOCdevkit")) is False:raise FileNotFoundError("VOCdevkit dose not in path:'{}'.".format(VOC_root))# load train data set# VOCdevkit -> VOC2012 -> ImageSets -> Segmentation -> train.txttrain_dataset = VOCSegmentation(args.data_path,year="2012",transforms=get_transform(train=True),txt_name="train.txt")# load validation data set# VOCdevkit -> VOC2012 -> ImageSets -> Segmentation -> val.txtval_dataset = VOCSegmentation(args.data_path,year="2012",transforms=get_transform(train=False),txt_name="val.txt")print("Creating data loaders")if args.distributed:train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)test_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)else:train_sampler = torch.utils.data.RandomSampler(train_dataset)test_sampler = torch.utils.data.SequentialSampler(val_dataset)train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,sampler=train_sampler, num_workers=args.workers,collate_fn=train_dataset.collate_fn, drop_last=True)val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1,sampler=test_sampler, num_workers=args.workers,collate_fn=train_dataset.collate_fn)print("Creating model")# create model num_classes equal background + 20 classesmodel = create_model(aux=args.aux, num_classes=num_classes)model.to(device)if args.sync_bn:model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)model_without_ddp = modelif args.distributed:model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])model_without_ddp = model.moduleparams_to_optimize = [{"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},{"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},]if args.aux:params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]params_to_optimize.append({"params": params, "lr": args.lr * 10})optimizer = torch.optim.SGD(params_to_optimize,lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)scaler = torch.cuda.amp.GradScaler() if args.amp else None# 创建学习率更新策略,这里是每个step更新一次(不是每个epoch)lr_scheduler = create_lr_scheduler(optimizer, len(train_data_loader), args.epochs, warmup=True)# 如果传入resume参数,即上次训练的权重地址,则接着上次的参数训练if args.resume:# If map_location is missing, torch.load will first load the module to CPU# and then copy each parameter to where it was saved,# which would result in all processes on the same machine using the same set of devices.checkpoint = torch.load(args.resume, map_location='cpu')  # 读取之前保存的权重文件(包括优化器以及学习率策略)model_without_ddp.load_state_dict(checkpoint['model'])optimizer.load_state_dict(checkpoint['optimizer'])lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])args.start_epoch = checkpoint['epoch'] + 1if args.amp:scaler.load_state_dict(checkpoint["scaler"])if args.test_only:confmat = evaluate(model, val_data_loader, device=device, num_classes=num_classes)val_info = str(confmat)print(val_info)returnprint("Start training")start_time = time.time()for epoch in range(args.start_epoch, args.epochs):if args.distributed:train_sampler.set_epoch(epoch)mean_loss, lr = train_one_epoch(model, optimizer, train_data_loader, device, epoch,lr_scheduler=lr_scheduler, print_freq=args.print_freq, scaler=scaler)confmat = evaluate(model, val_data_loader, device=device, num_classes=num_classes)val_info = str(confmat)print(val_info)# 只在主进程上进行写操作if args.rank in [-1, 0]:# write into txtwith open(results_file, "a") as f:# 记录每个epoch对应的train_loss、lr以及验证集各指标train_info = f"[epoch: {epoch}]\n" \f"train_loss: {mean_loss:.4f}\n" \f"lr: {lr:.6f}\n"f.write(train_info + val_info + "\n\n")if args.output_dir:# 只在主节点上执行保存权重操作save_file = {'model': model_without_ddp.state_dict(),'optimizer': optimizer.state_dict(),'lr_scheduler': lr_scheduler.state_dict(),'args': args,'epoch': epoch}if args.amp:save_file["scaler"] = scaler.state_dict()save_on_master(save_file,os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))total_time = time.time() - start_timetotal_time_str = str(datetime.timedelta(seconds=int(total_time)))print('Training time {}'.format(total_time_str))if __name__ == "__main__":import argparseparser = argparse.ArgumentParser(description=__doc__)# 训练文件的根目录(VOCdevkit)parser.add_argument('--data-path', default='/data/', help='dataset')# 训练设备类型parser.add_argument('--device', default='cuda', help='device')# 检测目标类别数(不包含背景)parser.add_argument('--num-classes', default=20, type=int, help='num_classes')# 每块GPU上的batch_sizeparser.add_argument('-b', '--batch-size', default=4, type=int,help='images per gpu, the total batch size is $NGPU x batch_size')parser.add_argument("--aux", default=True, type=bool, help="auxilier loss")# 指定接着从哪个epoch数开始训练parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')# 训练的总epoch数parser.add_argument('--epochs', default=20, type=int, metavar='N',help='number of total epochs to run')# 是否使用同步BN(在多个GPU之间同步),默认不开启,开启后训练速度会变慢parser.add_argument('--sync_bn', type=bool, default=False, help='whether using SyncBatchNorm')# 数据加载以及预处理的线程数parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',help='number of data loading workers (default: 4)')# 训练学习率,这里默认设置成0.0001,如果效果不好可以尝试加大学习率parser.add_argument('--lr', default=0.0001, type=float,help='initial learning rate')# SGD的momentum参数parser.add_argument('--momentum', default=0.9, type=float, metavar='M',help='momentum')# SGD的weight_decay参数parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,metavar='W', help='weight decay (default: 1e-4)',dest='weight_decay')# 训练过程打印信息的频率parser.add_argument('--print-freq', default=20, type=int, help='print frequency')# 文件保存地址parser.add_argument('--output-dir', default='./multi_train', help='path where to save')# 基于上次的训练结果接着训练parser.add_argument('--resume', default='', help='resume from checkpoint')# 不训练,仅测试parser.add_argument("--test-only",dest="test_only",help="Only test the model",action="store_true",)# 分布式进程数parser.add_argument('--world-size', default=1, type=int,help='number of distributed processes')parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')# Mixed precision training parametersparser.add_argument("--amp", default=False, type=bool,help="Use torch.cuda.amp for mixed precision training")args = parser.parse_args()# 如果指定了保存文件地址,检查文件夹是否存在,若不存在,则创建if args.output_dir:mkdir(args.output_dir)main(args)
  • predict.py
import os
import time
import jsonimport torch
from torchvision import transforms
import numpy as np
from PIL import Imagefrom src import fcn_resnet50def time_synchronized():torch.cuda.synchronize() if torch.cuda.is_available() else Nonereturn time.time()def main():aux = False  # inference time not need aux_classifierclasses = 20weights_path = "./save_weights/model_2.pth"img_path = "./test.jpeg"palette_path = "./palette.json"assert os.path.exists(weights_path), f"weights {weights_path} not found."assert os.path.exists(img_path), f"image {img_path} not found."assert os.path.exists(palette_path), f"palette {palette_path} not found."with open(palette_path, "rb") as f:pallette_dict = json.load(f)pallette = []for v in pallette_dict.values():pallette += v# get devicesdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))# create modelmodel = fcn_resnet50(aux=aux, num_classes=classes+1)# delete weights about aux_classifierweights_dict = torch.load(weights_path, map_location='cpu')['model']for k in list(weights_dict.keys()):if "aux" in k:del weights_dict[k]# load weightsmodel.load_state_dict(weights_dict)model.to(device)# load imageoriginal_img = Image.open(img_path)# from pil image to tensor and normalizedata_transform = transforms.Compose([transforms.Resize(520),transforms.ToTensor(),transforms.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225))])img = data_transform(original_img)# expand batch dimensionimg = torch.unsqueeze(img, dim=0)model.eval()  # 进入验证模式with torch.no_grad():# init modelimg_height, img_width = img.shape[-2:]init_img = torch.zeros((1, 3, img_height, img_width), device=device)model(init_img)t_start = time_synchronized()output = model(img.to(device))t_end = time_synchronized()print("inference+NMS time: {}".format(t_end - t_start))prediction = output['out'].argmax(1).squeeze(0)prediction = prediction.to("cpu").numpy().astype(np.uint8)mask = Image.fromarray(prediction)mask.putpalette(pallette)mask.save("test_result.png")if __name__ == '__main__':main()

测试图片:
在这里插入图片描述

预测结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/5097.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024年博客之星主题创作|从零到一:我的技术成长与创作之路

2024年博客之星主题创作&#xff5c;从零到一&#xff1a;我的技术成长与创作之路 个人简介个人主页个人成就热门专栏 历程回顾初来CSDN&#xff1a;怀揣憧憬&#xff0c;开启创作之旅成长之路&#xff1a;从平凡到榜一的蜕变持续分享&#xff1a;打卡基地与成长复盘四年历程&a…

【整体介绍】

ODO&#xff1a;汽车总行驶里程 Chime: 例如安全带没系的报警声音 多屏交互就是中控屏的信息会同步到主驾驶的仪表盘上 面试问题&#xff1a;蓝牙电话协议HFP 音乐协议A2DP 三方通话测试的逻辑

PyTorch使用教程(13)-一文搞定模型的可视化和训练过程监控

一、简介 在现代深度学习的研究和开发中&#xff0c;模型的可视化和监控是不可或缺的一部分。PyTorch&#xff0c;作为一个流行的深度学习框架&#xff0c;通过其丰富的生态系统提供了多种工具来满足这一需求。其中&#xff0c;torch.utils.tensorboard 是一个强大的接口&…

2025寒假备战蓝桥杯01---朴素二分查找的学习

文章目录 1.暴力方法的引入2.暴力解法的思考 与改进3.朴素二分查找的引入4.朴素二分查找的流程5.朴素二分查找的细节6.朴素二分查找的题目 1.暴力方法的引入 对于下面的这个有序的数据元素的组合&#xff0c;我们的暴力解法就是挨个进行遍历操作&#xff0c;一直找到和我们的这…

Qt按钮美化教程

前言 Qt按钮美化主要有三种方式&#xff1a;QSS、属性和自绘 QSS 字体大小 font-size: 18px;文字颜色 color: white;背景颜色 background-color: rgb(10,88,163); 按钮边框 border: 2px solid rgb(114,188,51);文字对齐 text-align: left;左侧内边距 padding-left: 10…

ESP32下FreeRTOS实时操作系统使用

ESP32下FreeRTOS实时操作系统使用 文章目录 ESP32下FreeRTOS实时操作系统使用一、概述二、为什么要使用实时操作系统RTOS&#xff1f;三、FreeRTOS任务3.1 什么是 FreeRTOS 任务&#xff1f;3.2 FreeRTOS 任务的特点3.3 FreeRTOS 任务的生命周期3.4 FreeRTOS 任务的状态3.5 Fre…

包文件分析器 Webpack Bundle Analyzer

webpack-bundle-analyzer 是一个非常有用的工具&#xff0c;用于可视化和分析 Webpack 打包生成的文件。这使得开发者能够更好地理解应用的依赖关系、包的大小&#xff0c;以及优化打包的机会。以下是关于 webpack-bundle-analyzer 的详细介绍&#xff0c;包括它的安装、使用以…

BEVFusion论文阅读

1. 简介 融合激光雷达和相机的信息已经变成了3D目标检测的一个标准&#xff0c;当前的方法依赖于激光雷达传感器的点云作为查询&#xff0c;以利用图像空间的特征。然而&#xff0c;人们发现&#xff0c;这种基本假设使得当前的融合框架无法在发生 LiDAR 故障时做出任何预测&a…

二十七、资源限制-LimitRange

LimitRange生产必备 在调度的时候 requests 比较重要,在运行时 limits 比较重要。 一、产生原因 生产中只有ResourceQuota是不够的 只配置ResourceQuotas的情况下,pod的yaml文件没有配置resources配置,都是0的话,就可以无限配置,永远达不到limit LimitRange做了什么 如…

计算机网络 (54)系统安全:防火墙与入侵检测

前言 计算机网络系统安全是确保网络通信和数据不受未经授权访问、泄露、破坏或篡改的关键。防火墙和入侵检测系统&#xff08;IDS&#xff09;是维护网络系统安全的两大核心组件。 一、防火墙 定义与功能 防火墙是一种用来加强网络之间访问控制的特殊网络互联设备&#xff0c;它…

鸿蒙Harmony json转对象(1)

案例1 运行代码如下 上图的运行结果如下: 附加1 Json_msg interface 案例2 import {JSON } from kit.ArkTS; export interface commonRes {status: numberreturnJSON: ESObject;time: string } export interface returnRes {uid: stringuserType: number; }Entry Component …

光谱相机在智能冰箱的应用原理与优势

食品新鲜度检测 详细可点击查看汇能感知团队实验报告&#xff1a;高光谱成像技术检测食物新鲜度 检测原理&#xff1a;不同新鲜程度的食品&#xff0c;其化学成分和结构会有所不同&#xff0c;在光谱下的反射、吸收等特性也存在差异。例如新鲜肉类和蔬菜中的水分、蛋白质、叶…

BottomNavigationBar组件的用法

文章目录 1 概念介绍2 使用方法3 示例代码 我们在上一章回中介绍了TextField Widget,本章回中将介绍BottomNavigationBar Widget。闲话休提&#xff0c;让我们一起Talk Flutter吧。 1 概念介绍 我们在本章回中将介绍一个新的Widget:BottomNavigationBar&#xff0c;它就是我们…

总结5..

#include<stdio.h> struct nb {//结构体列队 int x, y;//x为横坐标&#xff0c;y为纵坐标 int s, f;//s为步数&#xff0c;//f为方向 }link[850100]; int n, m, x, y, p, q, f; int hard 1, tail 1; int a[52][52], b[52][52], book[52][52][91]; int main() { …

媒体新闻发稿价格怎么算?移动端发稿价格低的原因有哪些?

对于有过一定发稿经历的朋友&#xff0c;面对不同媒体新闻渠道的发稿价格肯定有所疑惑。尤其同一家媒体&#xff0c;移动端经常比网页端投放渠道的价格要低。到底有哪些方面的原因&#xff0c;导致了这一情况&#xff1f;就让小编来分享下自己的发稿经验。 一、内容展示效果 考…

【Linux系统编程】—— 从零开始实现一个简单的自定义Shell

文章目录 什么是自主shell命令行解释器&#xff1f;实现shell的基础认识全局变量的配置初始化环境变量实现内置命令&#xff08;如 cd 和 echo&#xff09;cd命令&#xff1a;echo命令&#xff1a; 构建命令行提示符获取并解析用户输入的命令执行内置命令与外部命令Shell的主循…

html,css,js的粒子效果

这段代码实现了一个基于HTML5 Canvas的高级粒子效果&#xff0c;用户可以通过鼠标与粒子进行交互。下面是对代码的详细解析&#xff1a; HTML部分 使用<!DOCTYPE html>声明文档类型。<html>标签内包含了整个网页的内容。<head>部分定义了网页的标题&#x…

.Net Core微服务入门系列(一)——项目搭建

系列文章目录 1、.Net Core微服务入门系列&#xff08;一&#xff09;——项目搭建 2、.Net Core微服务入门全纪录&#xff08;二&#xff09;——Consul-服务注册与发现&#xff08;上&#xff09; 3、.Net Core微服务入门全纪录&#xff08;三&#xff09;——Consul-服务注…

【JavaSE】(8) String 类

一、String 类常用方法 1、构造方法 常用的这4种构造方法&#xff1a;直接法&#xff0c;或者传参字符串字面量、字符数组、字节数组。 在 JDK1.8 中&#xff0c;String 类的字符串实际存储在 char 数组中&#xff1a; String 类也重写了 toString 方法&#xff0c;所以可以直…

Linux-C/C++--深入探究文件 I/O (下)(文件共享、原子操作与竞争冒险、系统调用、截断文件)

经过上一章内容的学习&#xff0c;了解了 Linux 下空洞文件的概念&#xff1b;open 函数的 O_APPEND 和 O_TRUNC 标志&#xff1b;多次打开同一文件&#xff1b;复制文件描述符&#xff1b;等内容 本章将会接着探究文件IO&#xff0c;讨论如下主题内容。  文件共享介绍&…