Preface
These are notes from my own learning process; feel free to raise questions for discussion.
Image Segmentation
- Semantic segmentation does not distinguish individual object instances; instance segmentation does.
Deconvolution / Transposed Convolution:
- It is not the exact inverse of the forward convolution. A transposed convolution is a special kind of forward convolution: the input is first enlarged by zero-padding at a fixed ratio, the kernel is rotated, and then an ordinary forward convolution is applied. It can only restore the original spatial size (not the original pixel values), while helping recover finer spatial detail.
- Drawbacks: it produces a lot of useless information (the padded zeros), and it is computationally expensive.
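As a quick illustration of the size arithmetic (a minimal sketch; the channel counts and the 16×16 input are arbitrary example values), a stride-2 transposed convolution in PyTorch doubles the spatial resolution of a feature map:

```python
import torch
import torch.nn as nn

# Transposed convolution: output size = (in - 1) * stride - 2 * padding + kernel_size,
# so a 2x2 kernel with stride 2 maps a 16x16 input exactly to 32x32.
deconv = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=2, stride=2)
x = torch.randn(1, 64, 16, 16)  # example feature map
y = deconv(x)
print(y.shape)  # torch.Size([1, 32, 32, 32])
```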
Semantic Segmentation – FCN (pixel-level prediction; later reused for instance segmentation)
- FCN replaces the fully connected layers at the end of a conventional CNN with convolutional layers, so the network outputs a heatmap rather than a single class label. To undo the shrinking caused by convolution and pooling, it uses upsampling to restore the spatial size.
- It classifies the image at the pixel level: every pixel of the upsampled feature map is classified.
- Deconvolution (deconv) layers enlarge the feature maps, enabling fairly fine-grained output (preserving a degree of precision); a sketch of such a head follows below.
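A minimal sketch of this idea (TinyFCNHead and its channel counts are illustrative assumptions, not the actual FCN-8s architecture with its skip connections):

```python
import torch.nn as nn

# FCN-style head: a 1x1 conv replaces the fully connected classifier and
# produces a heatmap with one channel per class; a deconv then upsamples
# it 32x back to the input resolution for per-pixel classification.
class TinyFCNHead(nn.Module):
    def __init__(self, in_channels=2048, num_classes=21):
        super().__init__()
        self.score = nn.Conv2d(in_channels, num_classes, kernel_size=1)
        self.upsample = nn.ConvTranspose2d(num_classes, num_classes,
                                           kernel_size=64, stride=32, padding=16)

    def forward(self, x):                # x: (N, 2048, H/32, W/32) backbone features
        heatmap = self.score(x)          # (N, num_classes, H/32, W/32)
        return self.upsample(heatmap)    # (N, num_classes, H, W)
```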
Instance Segmentation – Mask R-CNN
- Must detect the location of each object and segment it at the same time: object detection + semantic segmentation.
Differences from Faster R-CNN:
1) Uses a ResNet network as the backbone.
2) Replaces the RoI Pooling layer with RoIAlign (pooling quantizes coordinates, and that error is magnified after deconvolution, hence the replacement).
- RoIAlign uses bilinear interpolation instead of rounding, sampling at exact fractional positions, which improves accuracy; see the sketch after this list.
3) Adds a parallel Mask branch.
- The mask head is a stack of convolutions that predicts a per-class mask from the RoIAlign output.
4) Introduces FPN and FCN.
- FPN: extracts multi-scale features (builds a feature pyramid containing feature maps at several scales), improving detection performance.
- FCN: produces pixel-level predictions, used for the instance segmentation masks.
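To see the RoIAlign difference concretely, here is a small comparison of torchvision's roi_align and roi_pool on a toy feature map (the 4×4 values and the fractional box are made up for illustration):

```python
import torch
import torchvision

# roi_align samples the feature map with bilinear interpolation at exact
# (fractional) box coordinates; roi_pool first rounds the box to whole cells.
features = torch.arange(16.0).reshape(1, 1, 4, 4)  # toy 4x4 feature map
boxes = [torch.tensor([[0.7, 0.7, 2.3, 2.3]])]     # fractional box (x1, y1, x2, y2)

aligned = torchvision.ops.roi_align(features, boxes, output_size=(2, 2))
pooled = torchvision.ops.roi_pool(features, boxes, output_size=(2, 2))
print(aligned)  # smooth bilinear samples
print(pooled)   # quantized: snapped to integer grid cells
```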
Implementing the Mask R-CNN Network Structure
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2

# Define the backbone network, using ResNet here
class ResNetBackbone(nn.Module):
    def __init__(self):
        super(ResNetBackbone, self).__init__()
        resnet = torchvision.models.resnet50(pretrained=True)
        # Drop the final average-pool and fully connected layers
        self.features = nn.Sequential(*list(resnet.children())[:-2])

    def forward(self, x):
        x = self.features(x)
        return x
# Region Proposal Network (RPN)
class RPN(nn.Module):
    def __init__(self, in_channels, num_anchors):
        super(RPN, self).__init__()
        self.conv = nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1)
        # Per anchor: 2 objectness scores and 4 box-regression offsets
        self.cls_layer = nn.Conv2d(512, num_anchors * 2, kernel_size=1, stride=1)
        self.reg_layer = nn.Conv2d(512, num_anchors * 4, kernel_size=1, stride=1)

    def forward(self, x):
        x = F.relu(self.conv(x))
        cls_scores = self.cls_layer(x)
        bbox_preds = self.reg_layer(x)
        # Flatten the spatial grid: (N, H*W*num_anchors, 2) and (N, H*W*num_anchors, 4)
        cls_scores = cls_scores.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 2)
        bbox_preds = bbox_preds.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 4)
        return cls_scores, bbox_preds
# RoI Align layer (a thin wrapper around torchvision.ops.roi_align)
class RoIAlign(nn.Module):
    def __init__(self, output_size):
        super(RoIAlign, self).__init__()
        self.output_size = output_size

    def forward(self, features, rois):
        # rois: one (num_rois, 4) box tensor per image in the batch
        roi_features = []
        for i in range(features.size(0)):
            roi = rois[i]
            roi_feature = torchvision.ops.roi_align(
                features[i].unsqueeze(0), [roi], self.output_size)
            roi_features.append(roi_feature)
        roi_features = torch.cat(roi_features, dim=0)
        return roi_features
# Mask branch: four 3x3 convs, a 2x deconv upsampling, then per-class masks
class MaskBranch(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(MaskBranch, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 256, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.deconv = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2)
        self.mask_layer = nn.Conv2d(256, num_classes, kernel_size=1, stride=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.deconv(x))
        mask_preds = self.mask_layer(x)
        return mask_preds
# Mask R-CNN model: backbone + RPN + RoIAlign + box head + mask head
class MaskRCNN(nn.Module):
    def __init__(self, num_classes):
        super(MaskRCNN, self).__init__()
        self.backbone = ResNetBackbone()
        self.rpn = RPN(2048, 9)  # assume 9 anchors per location
        self.roi_align = RoIAlign((14, 14))  # RoI Align to 14x14
        self.fc1 = nn.Linear(2048 * 14 * 14, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.cls_layer = nn.Linear(1024, num_classes)
        self.reg_layer = nn.Linear(1024, num_classes * 4)
        self.mask_branch = MaskBranch(2048, num_classes)

    def forward(self, x, rois=None):
        features = self.backbone(x)
        cls_scores, bbox_preds = self.rpn(features)
        if rois is not None:
            roi_features = self.roi_align(features, rois)
            # Box head: two fully connected layers, then classification and regression
            roi_features_fc = roi_features.view(roi_features.size(0), -1)
            fc1 = F.relu(self.fc1(roi_features_fc))
            fc2 = F.relu(self.fc2(fc1))
            cls_preds = self.cls_layer(fc2)
            reg_preds = self.reg_layer(fc2)
            # Mask head runs in parallel on the same RoI features
            mask_preds = self.mask_branch(roi_features)
            return cls_preds, reg_preds, mask_preds, cls_scores, bbox_preds
        else:
            return cls_scores, bbox_preds
# Custom dataset: loads an image and its annotation (.npy file)
class CustomDataset(Dataset):
    def __init__(self, image_paths, target_paths, transform=None):
        self.image_paths = image_paths
        self.target_paths = target_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = cv2.imread(self.image_paths[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        target = np.load(self.target_paths[idx], allow_pickle=True)
        if self.transform:
            image = self.transform(image)
        return image, target
# Preprocessing: convert to tensor and normalize with ImageNet statistics
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Train for one epoch; the loss wiring here is deliberately simplified
def train(model, dataloader, optimizer, criterion_cls, criterion_reg, criterion_mask):
    model.train()
    total_loss = 0
    for images, targets in dataloader:
        images = images.to(device)
        targets = [t.to(device) for t in targets]
        optimizer.zero_grad()
        cls_preds, reg_preds, mask_preds, cls_scores, bbox_preds = model(images, targets)
        # Classification, box-regression, and mask losses (placeholder targets;
        # a real pipeline would first match proposals to ground truth)
        cls_loss = criterion_cls(cls_preds, targets)
        reg_loss = criterion_reg(reg_preds, targets)
        mask_loss = criterion_mask(mask_preds, targets)
        loss = cls_loss + reg_loss + mask_loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)
# Evaluation function
def evaluate(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, targets in dataloader:
            images = images.to(device)
            targets = [t.to(device) for t in targets]
            cls_scores, bbox_preds = model(images)  # rois=None: RPN outputs only
            # Compute evaluation metrics here as needed, e.g. mAP
    return correct / total if total > 0 else 0.0

if __name__ == "__main__":
    # Example image and annotation file paths
    image_paths = ['img/street.jpg', 'img/street.jpg']
    target_paths = ['target1.npy', 'target2.npy']
    dataset = CustomDataset(image_paths, target_paths, transform)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_classes = 2  # including the background class
    model = MaskRCNN(num_classes).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    criterion_cls = nn.CrossEntropyLoss()
    criterion_reg = nn.SmoothL1Loss()
    criterion_mask = nn.BCEWithLogitsLoss()  # loss function for the masks
    num_epochs = 10
    for epoch in range(num_epochs):
        loss = train(model, dataloader, optimizer, criterion_cls, criterion_reg, criterion_mask)
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {loss}')
    # Evaluation
    accuracy = evaluate(model, dataloader)
    print(f'Accuracy: {accuracy}')
```
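The skeleton above leaves out anchor generation, proposal matching, and loss-target assignment. For comparison, torchvision ships a complete, pretrained Mask R-CNN (ResNet-50 + FPN) that handles all of this internally; a minimal usage sketch:

```python
import torch
import torchvision

# Off-the-shelf Mask R-CNN with a ResNet-50 + FPN backbone
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model.eval()

image = torch.rand(3, 480, 640)  # dummy RGB image with values in [0, 1]
with torch.no_grad():
    outputs = model([image])     # one dict per input image
print(outputs[0].keys())         # dict_keys(['boxes', 'labels', 'scores', 'masks'])
```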