CV(10)--目标检测

前言

仅记录学习过程,有问题欢迎讨论

目标检测

object detection,就是在给定的图片中精确找到物体所在位置,并标注出物体的类别;输出的是分类类别label+物体的外框(x, y, width, height)。

目标检测算法:
1、候选区域/框 + 深度学习分类:通过提取候选区域,并对相应区域进行以深度学习方法为主的
分类的方案,如:
• R-CNN(Selective Search + CNN + SVM)
• SPP-net(ROI Pooling)
• Fast R-CNN(Selective Search + CNN + ROI)
• Faster R-CNN(RPN + CNN + ROI)
2、 基于深度学习的回归方法:YOLO/SSD 等方法

IOU:
一个简单的测量标准,测量两个框之间的相似度,目标和实际的交集/并集,值越大越接近。
TP:预测为正样本,实际也为正样本
TN:预测为负样本,实际也为负样本
FP:预测为正样本,实际为负样本
FN:预测为负样本,实际为正样本

precision(精确度)和recall(召回率):精度是找得对,召回是找的全
precision = TP / (TP + FP)
recall = TP / (TP + FN)
F1 = 2 * (precision * recall) / (precision + recall)

边框回归:
目标是寻找一种关系使得原始窗口经过映射接近真实窗口
Input:
P=(Px,Py,Pw,Ph)
(注:训练阶段输入还包括 Ground Truth)
Output:
需要进行的平移变换和尺度缩放 dx,dy,dw,dh ,或者说是Δx,Δy,Sw,Sh 。
有了这四个变换我们就可以直接得到 Ground Truth。

TWO Stage:
Faster R-CNN(RPN + CNN + ROI)

Faster-RCNN

  1. Conv layers:Faster RCNN首先使用一组基础的conv+relu+pooling层提取 image的feature maps。该feature maps被共享用于后续 RPN层和全连接层。

    • 在Faster RCNN Conv layers中对所有的卷积都做了pad处理( pad=1,即填充一圈0),导致原图 变为 (M+2)x(N+2)大小,再做3x3卷积后输出MxN 。正是这种设置,导致Conv layers中的conv层 不改变输入和输出矩阵大小
    • Conv layers中的pooling层kernel_size=2,stride=2。 这样每个经过pooling层的MxN矩阵,都会变为(M/2)x(N/2)大小
    • 一个MxN大小的矩阵经过Conv layers固定变为(M/16)x(N/16)。 这样Conv layers生成的feature map都可以和原图对应起来
  2. Region Proposal Networks(RPN):RPN网络用于生成region proposals。通过softmax判断anchors属于 positive或者negative,再利用bounding box regression 修正anchors获得精确的proposals。

    • 直接使用RPN生成检测框,是Faster R-CNN的巨 大优势,能极大提升检测框的生成速度。

在这里插入图片描述

- 上面一条通过softmax分类anchors(按特征中心点穷举9个框),获得positive和negative分类-二分类;- 下面一条用于计算对于anchors的bounding box regression偏移量,以获得精确的proposal。- 最后的Proposal层则负责综合positive anchors和对应bounding box regression偏移量获取 proposals,同时剔除太小和超出边界的proposals
  1. Roi Pooling:该层收集输入的feature maps和proposals, 综合这些信息后提取proposal feature maps,送入后续全连接层判定目标类别。

    • proposal是对应MN尺度的,所以先使用spatial_scale参数将其映射回(M/16)(N/16)大小 feature map尺度;
    • 再将每个proposal对应的feature map区域水平分为pooled_w * pooled_h的网格;
    • 对网格的每一份都进行max pooling处理。
  2. Classification:利用proposal feature maps计算 proposal的类别,同时再次bounding box regression获得检测框最终的精确位置。

实现Faster-RCNN网络结构

"""
实现Faster-RCNN
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2# 定义骨干网络,这里使用 ResNet
class ResNetBackbone(nn.Module):def __init__(self):super(ResNetBackbone, self).__init__()resnet = torchvision.models.resnet50(pretrained=True)self.features = nn.Sequential(*list(resnet.children())[:-2])def forward(self, x):x = self.features(x)return x# 区域生成网络 (RPN)
class RPN(nn.Module):def __init__(self, in_channels, num_anchors):super(RPN, self).__init__()self.conv = nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1)#  2表示每个锚点有两种可能的类别:正负样本。通过这层卷积,网络将对每个锚点预测其概率得分。self.cls_layer = nn.Conv2d(512, num_anchors * 2, kernel_size=1, stride=1)# 4 表示对于每个锚点,要预测其边界框的 4 个参数self.reg_layer = nn.Conv2d(512, num_anchors * 4, kernel_size=1, stride=1)def forward(self, x):x = F.relu(self.conv(x))cls_scores = self.cls_layer(x)bbox_preds = self.reg_layer(x)cls_scores = cls_scores.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 2)bbox_preds = bbox_preds.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 4)return cls_scores, bbox_preds# RoI 池化层
class RoIPooling(nn.Module):def __init__(self, output_size):super(RoIPooling, self).__init__()self.output_size = output_sizedef forward(self, features, rois):roi_features = []for i in range(features.size(0)):# 包含了感兴趣区域的信息roi = rois[i]# features = (batch_size, channels, height, width),池化到output_size的统一size大小roi_feature = torchvision.ops.roi_pool(features[i].unsqueeze(0), [roi], self.output_size)roi_features.append(roi_feature)roi_features = torch.cat(roi_features, dim=0)return roi_features# Faster R-CNN 模型
class FasterRCNN(nn.Module):def __init__(self, num_classes):super(FasterRCNN, self).__init__()self.backbone = ResNetBackbone()self.rpn = RPN(2048, 9)  # 假设使用 9 个锚点# 池化到 7*7self.roi_pooling = RoIPooling((7, 7))self.fc1 = nn.Linear(2048 * 7 * 7, 1024)self.fc2 = nn.Linear(1024, 1024)self.cls_layer = nn.Linear(1024, num_classes)self.reg_layer = nn.Linear(1024, num_classes * 4)def forward(self, x, rois=None):features = self.backbone(x)cls_scores, bbox_preds = self.rpn(features)if rois is not None:roi_features = self.roi_pooling(features, rois)roi_features = roi_features.view(roi_features.size(0), -1)fc1 = F.relu(self.fc1(roi_features))fc2 = F.relu(self.fc2(fc1))cls_preds = self.cls_layer(fc2)reg_preds = self.reg_layer(fc2)return cls_preds, reg_preds, cls_scores, bbox_predselse:return cls_scores, bbox_preds# 自定义数据集类
class CustomDataset(Dataset):def __init__(self, image_paths, target_paths, transform=None):self.image_paths = image_pathsself.target_paths = target_pathsself.transform = transformdef __len__(self):return len(self.image_paths)def __getitem__(self, idx):image = cv2.imread(self.image_paths[idx])image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)target = np.load(self.target_paths[idx], allow_pickle=True)if self.transform:image = self.transform(image)return image, target# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 训练函数
def train(model, dataloader, optimizer, criterion_cls, criterion_reg):model.train()total_loss = 0for images, targets in dataloader:images = images.to(device)targets = [t.to(device) for t in targets]optimizer.zero_grad()cls_preds, reg_preds, cls_scores, bbox_preds = model(images, targets)# 计算分类和回归损失,这里假设 targets 包含真实类别和边界框信息cls_loss = criterion_cls(cls_preds, targets)reg_loss = criterion_reg(reg_preds, targets)loss = cls_loss + reg_lossloss.backward()optimizer.step()total_loss += loss.item()return total_loss / len(dataloader)# 评估函数
def evaluate(model, dataloader):model.eval()correct = 0total = 0with torch.no_grad():for images, targets in dataloader:images = images.to(device)targets = [t.to(device) for t in targets]cls_preds, reg_preds, _, _ = model(images)# 计算评估指标,这里可根据具体需求实现# 例如计算 mAP 等return correct / totalif __name__ == "__main__":# 假设的图像和标注文件路径image_paths = ['img/street.jpg', 'img/street.jpg']target_paths = ['target1.npy', 'target2.npy']dataset = CustomDataset(image_paths, target_paths, transform)dataloader = DataLoader(dataset, batch_size=2, shuffle=True)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')num_classes = 2  # 包括背景类model = FasterRCNN(num_classes).to(device)optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)criterion_cls = nn.CrossEntropyLoss()criterion_reg = nn.SmoothL1Loss()num_epochs = 10for epoch in range(num_epochs):loss = train(model, dataloader, optimizer, criterion_cls, criterion_reg)print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {loss}')# 评估accuracy = evaluate(model, dataloader)print(f'Accuracy: {accuracy}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/2566.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Nginx 如何设置 Upgrade-Insecure-Requests 报头 ?

Upgrade-Insecure-Requests 报头是一种 web 浏览器向服务器发出信号的机制,它倾向于接收安全 (HTTPS) 资源。添加此报头有助于在受支持的浏览器上将不安全的请求升级为安全的请求。 Step 1: 定位 Nginx 配置 主 nginx 配置文件通常位于 /etc/nginx/nginx.conf特定…

3.Qt Quick-QML地图引擎之v4.3版本(新增动态轨迹线/海图/天地图街道/天地图卫星)

在上个版本Qt Quick-QML地图引擎之v4版本(新增多模型切换/3D模型欧拉角模拟)_qt加载3d地图-CSDN博客更新了3D模拟功能,在4.3版本增加动态轨迹线、三个地图(海图/天地图街道/天地图卫星)。 4.3版本已经支持qt6 cmake版本,而4.3版本以下支持qt5版本&#x…

Linux-----线程操作(创建)

目录 创建线程 示例&#xff1a; 创建线程 #include <pthread.h>/*** 创建一个新线程* * pthread_t *thread: 指向线程标识符的指针,线程创建成功时,用于存储新创建线程的线程标识符* const pthread_attr_t *attr: pthead_attr_t结构体,这个参数可以用来设置线程的属性…

我要成为算法高手-DFS篇

目录 题目1&#xff1a;计算布尔二叉树的值题目2&#xff1a;求根节点到叶子结点数字之和题目3&#xff1a;二叉树剪枝题目4&#xff1a;验证二叉搜索树题目4&#xff1a;二叉搜索树中第 K 小的元素题目5&#xff1a;二叉树的所有路径 题目1&#xff1a;计算布尔二叉树的值 23…

学习threejs,使用FlyControls相机控制器

&#x1f468;‍⚕️ 主页&#xff1a; gis分享者 &#x1f468;‍⚕️ 感谢各位大佬 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍⚕️ 收录于专栏&#xff1a;threejs gis工程师 文章目录 一、&#x1f340;前言1.1 ☘️THREE.FlyControls 相机控制…

Life Long Learning(李宏毅)机器学习 2023 Spring HW14 (Boss Baseline)

1. 终身学习简介 神经网络的典型应用场景是,我们有一个固定的数据集,在其上训练并获得模型参数,然后将模型应用于特定任务而无需进一步更改模型参数。 然而,在许多实际工程应用中,常见的情况是系统可以不断地获取新数据,例如 Web 应用程序中的新用户数据或自动驾驶中的…

Multi-Agent如何设计

文章小结 研究背景和目的 在单一大语言模型长期主导人工智能领域的背景下&#xff0c;多智能体系统在对话任务解决中逐渐崭露头角。 虽然先前的研究已经展示了多智能体系统在推理任务和创造性工作中的潜力&#xff0c;但对于其在对话范式方面的局限性以及单个智能体的影响&am…

(即插即用模块-Attention部分) 四十四、(ICIP 2022) HWA 半小波注意力

文章目录 1、Half Wavelet Attention2、代码实现 paper&#xff1a;HALFWAVELET ATTENTION ON M-NET FOR LOW-LIGHT IMAGE ENHANCEMENT Code&#xff1a;https://github.com/FanChiMao/HWMNet 1、Half Wavelet Attention 传统的图像增强方法主要关注图像在空间域的特征信息&am…

SpringBoot+Lombok项目实体属性名xXxx格式,前端接收不到

问题解析 今天发现后端传给前端的实体类中&#xff0c;有属性为xXxxx格式的&#xff0c;前端也使用相同名称接收&#xff0c;结果却不显示值&#xff01;研究了一会发现接口请求回来后&#xff0c;原xXxxx的属性名&#xff0c;会被转为全小写。具体原因为&#xff1a;使用Lombo…

Spring Boot教程之五十五:Spring Boot Kafka 消费者示例

Spring Boot Kafka 消费者示例 Spring Boot 是 Java 编程语言中最流行和使用最多的框架之一。它是一个基于微服务的框架&#xff0c;使用 Spring Boot 制作生产就绪的应用程序只需很少的时间。Spring Boot 可以轻松创建独立的、生产级的基于 Spring 的应用程序&#xff0c;您可…

网络安全——常用语及linux系统

一、网络安全概念及法规 网络安全&#xff1a;网络空间安全 cyber security 信息系统&#xff1a;由计算机硬件、网络和通信设备、计算机软件、信息资源、信息用户和规章制度组成的已处理信息流为目的的人机一体化系统 信息系统安全三要素&#xff08;CIA&#xff09; 保密…

Windows 正确配置android adb调试的方法

下载适用于 Windows 的 SDK Platform-Tools https://developer.android.google.cn/tools/releases/platform-tools?hlzh-cn 设置系统变量&#xff0c;路径为platform-tools文件夹的绝对路径 点击Path添加环境变量 %adb%打开终端输入adb shell 这就成功了&#xff01;

如何优化Elasticsearch大文档查询?

记录一次业务复杂场景下DSL优化的过程 背景 B端商城业务有一个场景就是客户可见的产品列表是需要N多闸口及各种其它逻辑组合过滤的&#xff0c;各种闸口数据及产品数据都是存储在ES的(有的是独立索引&#xff0c;有的是作为产品属性存储在产品文档上)。 在实际使用的过程中&a…

使用 WPF 和 C# 将纹理应用于三角形

此示例展示了如何将纹理应用于三角形,以使场景比覆盖纯色的场景更逼真。以下是为三角形添加纹理的基本步骤。 创建一个MeshGeometry3D对象。像往常一样定义三角形的点和法线。通过向网格的TextureCoordinates集合添加值来设置三角形的纹理坐标。创建一个使用想要显示的纹理的 …

探索 Transformer²:大语言模型自适应的新突破

目录 一、来源&#xff1a; 论文链接&#xff1a;https://arxiv.org/pdf/2501.06252 代码链接&#xff1a;SakanaAI/self-adaptive-llms 论文发布时间&#xff1a;2025年1月14日 二、论文概述&#xff1a; 图1 Transformer 概述 图2 训练及推理方法概述 图3 基于提示的…

Android Studio历史版本包加载不出来,怎么办?

为什么需要下载历史版本呢&#xff1f; 虽然官网推荐使用最新版本&#xff0c;但是最新版本如果自己碰到问题&#xff0c;根本找不到答案&#xff0c;所以博主这里推荐使用历史版本&#xff01;&#xff01;&#xff01; Android Studio历史版本包加载不出来&#xff1f; 下…

citrix netscaler13.1 重写负载均衡响应头(基础版)

在 Citrix NetScaler 13.1 中&#xff0c;Rewrite Actions 用于对负载均衡响应进行修改&#xff0c;包括替换、删除和插入 HTTP 响应头。这些操作可以通过自定义策略来完成&#xff0c;帮助你根据需求调整请求内容。以下是三种常见的操作&#xff1a; 1. Replace (替换响应头)…

STM32 FreeRTOS移植

目录 FreeRTOS源码结构介绍 获取源码 1、 官网下载 2、 Github下载 源码结构介绍 源码整体结构 FreeRTOS文件夹结构 Source文件夹结构如下 portable文件夹结构 RVDS文件夹 MemMang文件夹 FreeRTOS在基于寄存器项目中移植步骤 目录添加源码文件 工程添加源码文件 …

[Qt]常用控件介绍-按钮类控件-QPushButton、QRedioButton、QCheckBox、QToolButton控件

目录 1.QPushButton按钮 介绍 属性 Demo&#xff1a;键盘方向键控制人物移动 2.Redio Button按钮 属性 clicked、pressed、released、toggled区别 单选按钮的分组 Demo&#xff1a;点餐小程序 3.CheckBox按钮 属性 Demo&#xff1a;获取今天的形成计划 4.ToolBu…

SpringBoot链接Kafka

一、SpringBoot生产者 &#xff08;1&#xff09;修改SpringBoot核心配置文件application.propeties, 添加生产者相关信息 # 连接 Kafka 集群 spring.kafka.bootstrap-servers192.168.134.47:9093# SASL_PLAINTEXT 和 SCRAM-SHA-512 认证配置 spring.kafka.properties.securi…