引言
对象检测是一项计算机视觉中的核心任务,其目标是识别图像中的目标并标记它们的位置和类别。在Pytorch生态系统中,Torchvision提供了多种预训练的对象检测模型(如Faster-RCNN、Mask-RCNN等),为开发者快速构建应用提供了便利。
本文将从以下几个方面展开:
- Torchvision支持的对象检测模型简介。
- Faster-RCNN模型的原理与实现。
- 自定义数据集的准备与使用。
1. Torchvision支持的对象检测模型
Torchvision目前支持以下主流对象检测模型:
- Faster-RCNN
- Mask-RCNN
- RetinaNet
这些模型的特点是提供了预训练权重,可以直接用于COCO等通用场景数据集。它们的输出包括:
boxes
:目标位置的边界框信息。labels
:目标的类别标签。scores
:预测的置信分数。
下面通过代码展示如何加载预训练的Faster-RCNN模型并在COCO数据集上进行推理。
示例代码:加载Faster-RCNN模型
import torchvision
from PIL import Image
import torchvision.transforms as T# 加载Faster-RCNN预训练模型
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()# 预处理图像
image_path = "path/to/your/image.jpg"
image = Image.open(image_path).convert("RGB")
transform = T.Compose([T.ToTensor()])
image_tensor = transform(image)# 推理
with torch.no_grad():predictions = model([image_tensor])# 输出结果
for box, label, score in zip(predictions[0]['boxes'], predictions[0]['labels'], predictions[0]['scores']):print(f"位置: {box}, 类别: {label}, 置信度: {score}")
2. Faster-RCNN模型详解
Faster-RCNN是一种经典的两阶段对象检测模型,其主要组成部分包括:
- Backbone网络:如ResNet50,用于提取特征。
- 区域推荐网络(RPN):生成候选区域。
- ROI Pooling:将不同大小的候选区域统一为固定大小。
- 分类和回归分支:分类目标并回归边界框。
Faster-RCNN模型的损失函数包括:
- 分类损失:用于预测目标类别。
- 位置损失:用于预测边界框的精确位置。
示例代码:在自定义图像上使用Faster-RCNN
import torch# 输入图像
image_tensor = transform(image)# 推理
with torch.no_grad():predictions = model([image_tensor])# 可视化结果
import matplotlib.pyplot as plt
import matplotlib.patches as patchesdef visualize(image, predictions):fig, ax = plt.subplots(1, figsize=(12, 9))ax.imshow(image)for box in predictions[0]['boxes']:x_min, y_min, x_max, y_max = boxrect = patches.Rectangle((x_min, y_min), x_max-x_min, y_max-y_min, linewidth=2, edgecolor='r', facecolor='none')ax.add_patch(rect)plt.show()visualize(image, predictions)
3. 自定义数据集的准备
数据格式
常用的数据集格式包括:
- Pascal VOC:以XML文件存储标注信息。
- MS COCO:以JSON文件存储标注信息。
自定义数据集类
为了使用自定义数据集,需继承torch.utils.data.Dataset
并实现以下方法:
__len__
:返回数据集大小。__getitem__
:返回单个样本的数据和标注。
示例代码:自定义数据集类
import os
import torch
from PIL import Image
import xml.etree.ElementTree as ETclass CustomDataset(torch.utils.data.Dataset):def __init__(self, root, transforms=None):self.root = rootself.transforms = transformsself.images = list(sorted(os.listdir(os.path.join(root, "images"))))self.annotations = list(sorted(os.listdir(os.path.join(root, "annotations"))))def __getitem__(self, idx):img_path = os.path.join(self.root, "images", self.images[idx])anno_path = os.path.join(self.root, "annotations", self.annotations[idx])img = Image.open(img_path).convert("RGB")tree = ET.parse(anno_path)root = tree.getroot()boxes = []labels = []for obj in root.findall("object"):bbox = obj.find("bndbox")xmin = float(bbox.find("xmin").text)ymin = float(bbox.find("ymin").text)xmax = float(bbox.find("xmax").text)ymax = float(bbox.find("ymax").text)boxes.append([xmin, ymin, xmax, ymax])labels.append(1) # 假设只有一个类boxes = torch.as_tensor(boxes, dtype=torch.float32)labels = torch.as_tensor(labels, dtype=torch.int64)target = {"boxes": boxes, "labels": labels}if self.transforms:img = self.transforms(img)return img, targetdef __len__(self):return len(self.images)
4. Faster-RCNN对象检测模型选择与训练
选择Faster-RCNN模型,利用迁移学习技术进行训练模型,检测类别只有三个,cat、dog和背景
lr设置为0.005,
lr_scheduler =
torch.optim.lr_scheduler.StepLR(optimi
zer,
step_size=5,
gamma=0.1)
本实验仅进行8次epoch。
import torch
import torchvision
import os
import sys
from faster_rcnn.engine import train_one_epoch
from faster_rcnn.faster_rcnn_pet_dataset import PetDataset
import faster_rcnn.utils as utilsdef main_train():# 检查是否可以利用GPUtrain_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('CUDA is not available.')else:print('CUDA is available!')#cat、dog、and backgroundnum_classes = 3#迁移学习冻结全部层或全链路调优model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, progress=True,trainable_backbone_layers= 5,num_classes=num_classes,pretrained_backbone=True)device = torch.device('cuda:0') # 注意这里应该是 'cuda:0'model.to(device)dataset = PetDataset("")data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True,collate_fn=utils.collate_fn)test_data = PetDataset("")test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True,collate_fn=utils.collate_fn)params = [p for p in model.parameters() if p.requires_grad]optimizer = torch.optim.SGD(params, lr=0.005,momentum=0.9, weight_decay=0.0005)lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)num_epochs = 8for epoch in range(num_epochs):train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)lr_scheduler.step()evaluate(model, test_data_loader, device)torch.save(model.state_dict(), "faster_rcnn_pet_model.pt")if __name__ == "__main__":main_train()
5 Faster-RCNN对象检测模型使用
import torchvision
import torch
import cv2 as cv
import numpy as npnum_classes = 3coco_names = {'0': 'background', '1': 'dog', '2': 'cat'}model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=num_classes)
model.load_state_dict(torch.load("./faster_rcnn_pet_model.pt"))
model.eval()transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])# 使用GPU
train_on_gpu = torch.cuda.is_available()
if train_on_gpu:model.cuda()def pet_image_detection():image = cv.imread('')blob = transform(image)c, h, w = blob.shapeinput_x = blob.view(1, c, h, w)output = model(input_x.cuda())[0]boxes = output['boxes'].cpu().detach().numpy()scores = output['scores'].cpu().detach().numpy()labels = output['labels'].cpu().detach().numpy()print(boxes.shape, scores.shape, labels.shape)index = 0for x1, y1, x2, y2 in boxes:if scores[index] > 0.5:cv.rectangle(image, (np.int32(x1), np.int32(y1)),(np.int32(x2), np.int32(y2)), (140, 199, 0), 4, 8, 0)label_id = labels[index]label_txt = coco_names[str(label_id)]cv.putText(image, label_txt, (np.int32(x1), np.int32(y1)), cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2, 8)index += 1cv.imshow("Faster-RCNN Pet Detection", image)cv.imwrite("/home/lichuang/project/Opencv/faster_rcnn/pet1_result.png", image)cv.waitKey(0)cv.destroyAllWindows()if __name__ == '__main__':pet_image_detection()
实验结果可视化:
总结
本文结合代码,介绍了Torchvision框架中对象检测的基本使用方式,包括Faster-RCNN模型的加载与推理,以及自定义数据集的准备与使用,通过设计模型训练,并实现验证,完成一项目标检测小实验,通过这些步骤,可以快速上手并应用到自己的项目中,也可以利用Torchvision框架中其他的模型来进行实验。