使用Torchvision框架实现对象检测:从Faster-RCNN模型到自定义数据集,训练模型,完成目标检测任务。

引言

对象检测是一项计算机视觉中的核心任务,其目标是识别图像中的目标并标记它们的位置和类别。在Pytorch生态系统中,Torchvision提供了多种预训练的对象检测模型(如Faster-RCNN、Mask-RCNN等),为开发者快速构建应用提供了便利。

本文将从以下几个方面展开:

  1. Torchvision支持的对象检测模型简介。
  2. Faster-RCNN模型的原理与实现
  3. 自定义数据集的准备与使用

1. Torchvision支持的对象检测模型

Torchvision目前支持以下主流对象检测模型:

  • Faster-RCNN
  • Mask-RCNN
  • RetinaNet

这些模型的特点是提供了预训练权重,可以直接用于COCO等通用场景数据集。它们的输出包括:

  • boxes:目标位置的边界框信息。
  • labels:目标的类别标签。
  • scores:预测的置信分数。

下面通过代码展示如何加载预训练的Faster-RCNN模型并在COCO数据集上进行推理。

示例代码:加载Faster-RCNN模型

import torchvision
from PIL import Image
import torchvision.transforms as T# 加载Faster-RCNN预训练模型
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()# 预处理图像
image_path = "path/to/your/image.jpg"
image = Image.open(image_path).convert("RGB")
transform = T.Compose([T.ToTensor()])
image_tensor = transform(image)# 推理
with torch.no_grad():predictions = model([image_tensor])# 输出结果
for box, label, score in zip(predictions[0]['boxes'], predictions[0]['labels'], predictions[0]['scores']):print(f"位置: {box}, 类别: {label}, 置信度: {score}")

2. Faster-RCNN模型详解

Faster-RCNN是一种经典的两阶段对象检测模型,其主要组成部分包括:

  1. Backbone网络:如ResNet50,用于提取特征。
  2. 区域推荐网络(RPN):生成候选区域。
  3. ROI Pooling:将不同大小的候选区域统一为固定大小。
  4. 分类和回归分支:分类目标并回归边界框。

Faster-RCNN模型的损失函数包括:

  • 分类损失:用于预测目标类别。
  • 位置损失:用于预测边界框的精确位置。

示例代码:在自定义图像上使用Faster-RCNN

import torch# 输入图像
image_tensor = transform(image)# 推理
with torch.no_grad():predictions = model([image_tensor])# 可视化结果
import matplotlib.pyplot as plt
import matplotlib.patches as patchesdef visualize(image, predictions):fig, ax = plt.subplots(1, figsize=(12, 9))ax.imshow(image)for box in predictions[0]['boxes']:x_min, y_min, x_max, y_max = boxrect = patches.Rectangle((x_min, y_min), x_max-x_min, y_max-y_min, linewidth=2, edgecolor='r', facecolor='none')ax.add_patch(rect)plt.show()visualize(image, predictions)

3. 自定义数据集的准备

数据格式

常用的数据集格式包括:

  • Pascal VOC:以XML文件存储标注信息。
  • MS COCO:以JSON文件存储标注信息。

自定义数据集类

为了使用自定义数据集,需继承torch.utils.data.Dataset并实现以下方法:

  • __len__:返回数据集大小。
  • __getitem__:返回单个样本的数据和标注。

示例代码:自定义数据集类

import os
import torch
from PIL import Image
import xml.etree.ElementTree as ETclass CustomDataset(torch.utils.data.Dataset):def __init__(self, root, transforms=None):self.root = rootself.transforms = transformsself.images = list(sorted(os.listdir(os.path.join(root, "images"))))self.annotations = list(sorted(os.listdir(os.path.join(root, "annotations"))))def __getitem__(self, idx):img_path = os.path.join(self.root, "images", self.images[idx])anno_path = os.path.join(self.root, "annotations", self.annotations[idx])img = Image.open(img_path).convert("RGB")tree = ET.parse(anno_path)root = tree.getroot()boxes = []labels = []for obj in root.findall("object"):bbox = obj.find("bndbox")xmin = float(bbox.find("xmin").text)ymin = float(bbox.find("ymin").text)xmax = float(bbox.find("xmax").text)ymax = float(bbox.find("ymax").text)boxes.append([xmin, ymin, xmax, ymax])labels.append(1)  # 假设只有一个类boxes = torch.as_tensor(boxes, dtype=torch.float32)labels = torch.as_tensor(labels, dtype=torch.int64)target = {"boxes": boxes, "labels": labels}if self.transforms:img = self.transforms(img)return img, targetdef __len__(self):return len(self.images)

4. Faster-RCNN对象检测模型选择与训练

选择Faster-RCNN模型,利用迁移学习技术进行训练模型,检测类别只有三个,cat、dog和背景

lr设置为0.005,

lr_scheduler =
torch.optim.lr_scheduler.StepLR(optimi
zer,
step_size=5,
gamma=0.1)

本实验仅进行8次epoch。

import torch
import torchvision
import os
import sys
from faster_rcnn.engine import train_one_epoch
from faster_rcnn.faster_rcnn_pet_dataset import PetDataset
import faster_rcnn.utils as utilsdef main_train():# 检查是否可以利用GPUtrain_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('CUDA is not available.')else:print('CUDA is available!')#cat、dog、and backgroundnum_classes = 3#迁移学习冻结全部层或全链路调优model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, progress=True,trainable_backbone_layers= 5,num_classes=num_classes,pretrained_backbone=True)device = torch.device('cuda:0')  # 注意这里应该是 'cuda:0'model.to(device)dataset = PetDataset("")data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True,collate_fn=utils.collate_fn)test_data = PetDataset("")test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True,collate_fn=utils.collate_fn)params = [p for p in model.parameters() if p.requires_grad]optimizer = torch.optim.SGD(params, lr=0.005,momentum=0.9, weight_decay=0.0005)lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)num_epochs = 8for epoch in range(num_epochs):train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)lr_scheduler.step()evaluate(model, test_data_loader, device)torch.save(model.state_dict(), "faster_rcnn_pet_model.pt")if __name__ == "__main__":main_train()

5 Faster-RCNN对象检测模型使用

import torchvision
import torch
import cv2 as cv
import numpy as npnum_classes = 3coco_names = {'0': 'background', '1': 'dog', '2': 'cat'}model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=num_classes)
model.load_state_dict(torch.load("./faster_rcnn_pet_model.pt"))
model.eval()transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])# 使用GPU
train_on_gpu = torch.cuda.is_available()
if train_on_gpu:model.cuda()def pet_image_detection():image = cv.imread('')blob = transform(image)c, h, w = blob.shapeinput_x = blob.view(1, c, h, w)output = model(input_x.cuda())[0]boxes = output['boxes'].cpu().detach().numpy()scores = output['scores'].cpu().detach().numpy()labels = output['labels'].cpu().detach().numpy()print(boxes.shape, scores.shape, labels.shape)index = 0for x1, y1, x2, y2 in boxes:if scores[index] > 0.5:cv.rectangle(image, (np.int32(x1), np.int32(y1)),(np.int32(x2), np.int32(y2)), (140, 199, 0), 4, 8, 0)label_id = labels[index]label_txt = coco_names[str(label_id)]cv.putText(image, label_txt, (np.int32(x1), np.int32(y1)), cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2, 8)index += 1cv.imshow("Faster-RCNN Pet Detection", image)cv.imwrite("/home/lichuang/project/Opencv/faster_rcnn/pet1_result.png", image)cv.waitKey(0)cv.destroyAllWindows()if __name__ == '__main__':pet_image_detection()

实验结果可视化:

总结

本文结合代码,介绍了Torchvision框架中对象检测的基本使用方式,包括Faster-RCNN模型的加载与推理,以及自定义数据集的准备与使用,通过设计模型训练,并实现验证,完成一项目标检测小实验,通过这些步骤,可以快速上手并应用到自己的项目中,也可以利用Torchvision框架中其他的模型来进行实验。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/4495.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SSM课设-学生管理系统

【课设者】SSM课设-学生管理系统 技术栈: 后端: SpringSpringMVCMybatisMySQLJSP 前端: HtmlCssJavaScriptEasyUIAjax 功能: 学生端: 登陆 学生信息管理 个人信息管理 老师端: 多了教师信息管理 管理员端: 多了班级信息管理 多了年级信息管理 多了系统用户管理

C语言之装甲车库车辆动态监控辅助记录系统

🌟 嗨,我是LucianaiB! 🌍 总有人间一两风,填我十万八千梦。 🚀 路漫漫其修远兮,吾将上下而求索。 C语言之装甲车库车辆动态监控辅助记录系统 目录 一、前言 1.1 (一)…

【STM32-学习笔记-4-】PWM、输入捕获(PWMI)

文章目录 1、PWMPWM配置 2、输入捕获配置3、编码器 1、PWM PWM配置 配置时基单元配置输出比较单元配置输出PWM波的端口 #include "stm32f10x.h" // Device headervoid PWM_Init(void) { //**配置输出PWM波的端口**********************************…

Kinova仿生机械臂Gen3搭载BOTA 力矩传感器SeneOne:彰显机器人触觉 AI 与六维力传感的融合力量

随着工业4.0时代的到来,自动化和智能化成为制造业的趋势。机器人作为实现这一趋势的重要工具,其性能和智能水平直接影响到生产效率和产品质量。然而,传统的机器人系统在应对复杂任务时往往缺乏足够的灵活性和适应性。为了解决这一问题&#x…

有限元分析学习——Anasys Workbanch第一阶段笔记(13)网格单元分类、物理场与自由度概念

目录 0 序言 1 网格单元分类 2 各类单元的应用 3 massage与帮助和查看 4 物理场和自由度 4.1 各种单元自由度 4.2 结构自由度 0 序言 本章主要讲解网格单元的分类及物理场和自由度的相关概念。 1 网格单元分类 按单元的形状分类:实体单元、壳单元和杆梁单元…

python3GUI--仿崩坏三二次元登录页面(附下载地址) By:PyQt5

文章目录 一.前言二.预览三.实现方案1.实现原理1.PyQt52. 具体实现 2.UI设计1.UI组件化、模块化2.UI设计风格思路 3.项目代码结构4.使用方法3.代码分享1.支持跳转网页的QLabel组件2.三角形ICON按钮 四.总结 大小:33.3 …

Pytorch使用教程(12)-如何进行并行训练?

在使用GPU训练大模型时,往往会面临单卡显存不足的情况。这时,通过多卡并行的形式来扩大显存是一个有效的解决方案。PyTorch主要提供了两个类来实现多卡并行:数据并行torch.nn.DataParallel(DP)和模型并行torch.nn.Dist…

电脑换固态硬盘

参考: https://baijiahao.baidu.com/s?id1724377623311611247 一、根据尺寸和缺口可以分为以下几种: 1、M.2 NVME协议的固态 大部分笔记本是22x42MM和22x80MM nvme固态。 在京东直接搜: M.2 2242 M.2 2280 2、msata接口固态 3、NGFF M.…

利用免费GIS工具箱实现高斯泼溅切片,将 PLY 格式转换为 3dtiles

在地理信息系统(GIS)和三维数据处理领域,不同数据格式有其独特应用场景与优势。PLY(Polygon File Format)格式常用于存储多边形网格数据,而 3DTiles 格式在 Web 端三维场景展示等方面表现出色。将 PLY 格式…

【华为路由/交换机的ftp文件操作】

华为路由/交换机的ftp文件操作 PC:10.0.1.1 R1:10.0.1.254 / 10.0.2.254 FTP:10.0.2.1 S1:无配置 在桌面创建FTP-Huawei文件夹,里面创建config/test.txt。 点击上图中的“启动”按钮。 然后ftp到server,…

基于微信小程序的安心陪诊管理系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏:…

利用rsync备份全网服务器数据

一、项目描述 某公司里有一台Web服务器,里面的数据很重要,但是如果硬盘坏了数据就会丢失,现在领导要求把数据做备份,这样Web服务器数据丢失在可以进行恢复,要求如下: 1、备份要求 每天晚上00点整在Web服…

Mysql 主从复制原理及其工作过程,配置一主两从实验

主从原理:MySQL 主从同步是一种数据库复制技术,它通过将主服务器上的数据更改复制到一个或多个从服务器,实现数据的自动同步。 主从同步的核心原理是将主服务器上的二进制日志复制到从服务器,并在从服务器上执行这些日志中的操作…

Ubuntu 24.04 LTS 空闲硬盘挂载到 文件管理器的 other locations

Ubuntu 24.04 LTS 确认硬盘是否被识别 使用 lsblk 查看信息,其中sda这个盘是我找不到的,途中是挂在好的。 分区和格式化硬盘 如果新硬盘没有分区,你需要先分区并格式化它。假设新硬盘为 /dev/sdb,使用 fdisk 或 parted 对硬盘…

调试Hadoop源代码

个人博客地址:调试Hadoop源代码 | 一张假钞的真实世界 Hadoop版本 Hadoop 2.7.3 调试模式下启动Hadoop NameNode 在${HADOOP_HOME}/etc/hadoop/hadoop-env.sh中设置NameNode启动的JVM参数,如下: export HADOOP_NAMENODE_OPTS"-Xdeb…

JSON-stringify和parse

目录 JSON序列化 JSON反序列化 序列化和反序列化转换 深拷贝 JSON.parse接受参数类型错误导致抛出异常 当有子元素的时候,设置父元素样式的方式 防抖问题 JSON序列化 const obj {name: "John",age: 30,city: "New York",};// 基本用法&…

3 前端(中):JavaScript

文章目录 前言:JavaScript简介一、ECMAscript(JavaScript基本语法)1 JavaScript与html结合方式(快速入门)2 基本知识(1)JavaScript注释(和Java注释一样)(2&am…

服务器一次性部署One API + ChatGPT-Next-Web

服务器一次性部署One API ChatGPT-Next-Web One API ChatGPT-Next-Web 介绍One APIChatGPT-Next-Web docker-compose 部署One API ChatGPT-Next-WebOpen API docker-compose 配置ChatGPT-Next-Web docker-compose 配置docker-compose 启动容器 后续配置 同步发布在个人笔记服…

OSI七层协议——分层网络协议

OSI七层协议,顾名思义,分为七层,实际上七层是不存在的,是人为的进行划分,让人更好的理解 七层协议包括,物理层(我),数据链路层(据),网络层(网),传输层(传输),会话层(会),表示层(表),应用层(用)(记忆口诀->我会用表…

【AI论文】生成式视频模型是否通过观看视频学习物理原理?

摘要:AI视频生成领域正经历一场革命,其质量和真实感在迅速提升。这些进步引发了一场激烈的科学辩论:视频模型是否学习了能够发现物理定律的“世界模型”,或者,它们仅仅是复杂的像素预测器,能够在不理解现实…