Faster RCNN Pytorch 实现 代码级 详解

基本结构:

采用VGG提取特征的Faster RCNN.

self.backbone:提取出特征图->features

self.rpn:选出推荐框->proposals

self.roi heads:根据proposals在features上进行抠图->detections

        
features = self.backbone(images.tensors)proposals, proposal_losses = self.rpn(images, features, targets)
detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)

1.self.backbone

    def forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return out

这里pteaures就是提取的特征图、而台teaures是字思形式,由5张特征图组成,这里就是构成了不同的尺度的要求,特征图越小,所映射原图的范围越大。
注:这里的理解很重要,,其实这里能够完全理解,那对图像检测基本就入门了,

五种featureMap:

[1,256,11,21]1:是pytorch四要求的一般会用于batchsize的功效,多少张图片256:通道数
11:  height  高

21:  weight  宽

2.self.rpn

objectness, pred_bbox_deltas = self.head(features)
anchors = self.anchor_generator(images, features)boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)

self.head(features):

    def forward(self, x):# type: (List[Tensor])logits = []bbox_reg = []for feature in x:t = F.relu(self.conv(feature))logits.append(self.cls_logits(t))  # 对t分类bbox_reg.append(self.bbox_pred(t))  # 对t回归return logits, bbox_reg

x:就是输出的5张特征图features

objectness, pred_bbox_deltas = self.head(features)

锚框是由特征图上一个像素点在原图上得到的不同尺度的锚框,一般fasterrcnn论文里面是9个尺度
在这里是3

anchors = self.anchor_generator(images, features)

boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)

这里的scores是的是前景的概率。(这里一般就是2分类,前景或背景)

top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)

222603—》4693

 for boxes, scores, lvl, img_shape in zip(proposals, objectness, levels, image_shapes):boxes = box_ops.clip_boxes_to_image(boxes, img_shape)keep = box_ops.remove_small_boxes(boxes, self.min_size)boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]# non-maximum suppression, independently done per level# NMS的实现keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)# keep only topk scoring predictions# keep就是最终保留的keep = keep[:self.post_nms_top_n()]boxes, scores = boxes[keep], scores[keep]final_boxes.append(boxes)final_scores.append(scores)return final_boxes, final_scores

4693–》1324

1324–》1000
这里的1000是在faster_rcnn.py中设置的

为什么不是2000是因为训练的时候是2000,这里只是测试

proposals, proposal_losses = self.rpn(images, features, targets)

 roi_heads()

        box_features = self.box_roi_pool(features, proposals, image_shapes)box_features = self.box_head(box_features)class_logits, box_regression = self.box_predictor(box_features)#class_logits: 分类概率 和 box_regression :边界框回归

box_roi_pool 规整,为相同尺度的特征图,便于之后的分类与回归
box_roi_pool:两个FC层

        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)# 映射回原图detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  

detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  

补充:pytorch自带detection模块:

import os
import time
import torch.nn as nn
import torch
import random
import numpy as np
import torchvision.transforms as transforms
import torchvision
from PIL import Image
import torch.nn.functional as F
from tools.my_dataset import PennFudanDataset
#from tools.common_tools import set_seed
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import functional as F#set_seed(1)  # 设置随机种子BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# classes_coco
COCO_INSTANCE_CATEGORY_NAMES = ['__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus','train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign','parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow','elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A','handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball','kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket','bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl','banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza','donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table','N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone','microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book','clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]def vis_bbox(img, output, classes, max_vis=40, prob_thres=0.4):fig, ax = plt.subplots(figsize=(12, 12))ax.imshow(img, aspect='equal')out_boxes = output_dict["boxes"].cpu()out_scores = output_dict["scores"].cpu()out_labels = output_dict["labels"].cpu()num_boxes = out_boxes.shape[0]for idx in range(0, min(num_boxes, max_vis)):score = out_scores[idx].numpy()bbox = out_boxes[idx].numpy()class_name = classes[out_labels[idx]]if score < prob_thres:continueax.add_patch(plt.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1], fill=False,edgecolor='red', linewidth=3.5))ax.text(bbox[0], bbox[1] - 2, '{:s} {:.3f}'.format(class_name, score), bbox=dict(facecolor='blue', alpha=0.5),fontsize=14, color='white')plt.show()plt.close()class Compose(object):def __init__(self, transforms):self.transforms = transformsdef __call__(self, image, target):for t in self.transforms:image, target = t(image, target)return image, targetclass RandomHorizontalFlip(object):def __init__(self, prob):self.prob = probdef __call__(self, image, target):if random.random() < self.prob:height, width = image.shape[-2:]image = image.flip(-1)bbox = target["boxes"]bbox[:, [0, 2]] = width - bbox[:, [2, 0]]target["boxes"] = bboxreturn image, targetclass ToTensor(object):def __call__(self, image, target):image = F.to_tensor(image)return image, targetif __name__ == "__main__":# configLR = 0.001num_classes = 2batch_size = 1start_epoch, max_epoch = 0, 30train_dir = os.path.join(BASE_DIR, "data", "PennFudanPed")train_transform = Compose([ToTensor(), RandomHorizontalFlip(0.5)])# step 1: datatrain_set = PennFudanDataset(data_dir=train_dir, transforms=train_transform)# 收集batch data的函数def collate_fn(batch):return tuple(zip(*batch))train_loader = DataLoader(train_set, batch_size=batch_size, collate_fn=collate_fn)# step 2: modelmodel = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)in_features = model.roi_heads.box_predictor.cls_score.in_featuresmodel.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) # replace the pre-trained head with a new onemodel.to(device)# step 3: loss# in lib/python3.6/site-packages/torchvision/models/detection/roi_heads.py# def fastrcnn_loss(class_logits, box_regression, labels, regression_targets)# step 4: optimizer schedulerparams = [p for p in model.parameters() if p.requires_grad]optimizer = torch.optim.SGD(params, lr=LR, momentum=0.9, weight_decay=0.0005)lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)# step 5: Iterationfor epoch in range(start_epoch, max_epoch):model.train()for iter, (images, targets) in enumerate(train_loader):images = list(image.to(device) for image in images)targets = [{k: v.to(device) for k, v in t.items()} for t in targets]# if torch.cuda.is_available():#     images, targets = images.to(device), targets.to(device)loss_dict = model(images, targets)  # images is list; targets is [ dict["boxes":**, "labels":**], dict[] ]losses = sum(loss for loss in loss_dict.values())print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} ".format(epoch, max_epoch, iter + 1, len(train_loader), losses.item()))optimizer.zero_grad()losses.backward()optimizer.step()lr_scheduler.step()# testmodel.eval()# configvis_num = 5vis_dir = os.path.join(BASE_DIR, "data", "PennFudanPed", "PNGImages")img_names = list(filter(lambda x: x.endswith(".png"), os.listdir(vis_dir)))random.shuffle(img_names)preprocess = transforms.Compose([transforms.ToTensor(), ])for i in range(0, vis_num):path_img = os.path.join(vis_dir, img_names[i])# preprocessinput_image = Image.open(path_img).convert("RGB")img_chw = preprocess(input_image)# to deviceif torch.cuda.is_available():img_chw = img_chw.to('cuda')model.to('cuda')# forwardinput_list = [img_chw]with torch.no_grad():tic = time.time()print("input img tensor shape:{}".format(input_list[0].shape))output_list = model(input_list)output_dict = output_list[0]print("pass: {:.3f}s".format(time.time() - tic))# visualizationvis_bbox(input_image, output_dict, COCO_INSTANCE_CATEGORY_NAMES, max_vis=20, prob_thres=0.5)  # for 2 epoch for nms

欢迎点赞   收藏   关注

参考:Pytorch实现Faster-RCNN_pytorch faster rcnn-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/41962.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LabVIEW时间触发协议

介绍了基于LabVIEW开发的时间触发协议应用&#xff0c;通过实例解析了FlexRay总线的设计与优化。通过技术细节、系统构建和功能实现等方面&#xff0c;探讨了LabVIEW在现代工业通信系统中的应用效能&#xff0c;特别是在提高通信可靠性和实时性方面的贡献。 ​ 项目背景 在工…

Linux 进程3-fork创建子进程继承父进程缓冲区验证

目录 1. fork创建子进程继承父进程缓冲区验证 1.1 write向标准输出&#xff08;终端屏幕&#xff09;写入数据验证 1.1.1 write函数写入数据不带行缓冲刷新 \n 1.1.2 write函数写入数据带行缓冲刷新 \n 1.2 fork创建前执行printf函数 1.2.1 fork创建前执行printf函数带\n…

Celery 全面指南:Python 分布式任务队列详解

Celery 全面指南&#xff1a;Python 分布式任务队列详解 Celery 是一个强大的分布式任务队列/异步任务队列系统&#xff0c;基于分布式消息传递&#xff0c;专注于实时处理&#xff0c;同时也支持任务调度。本文将全面介绍 Celery 的核心功能、应用场景&#xff0c;并通过丰富…

excel 列单元格合并(合并列相同行)

代码 首先自定义注解CellMerge&#xff0c;用于标记哪些属性需要合并&#xff0c;哪个是主键**&#xff08;这里做了一个优化&#xff0c;可以标记多个主键&#xff09;** import org.dromara.common.excel.core.CellMergeStrategy;import java.lang.annotation.*;/*** excel…

mac m4 Homebrew安装MySQL 8.0

1.使用Homebrew安装MySQL8 在终端中输入以下命令来安装MySQL8&#xff1a; brew install mysql8.0 安装完成后&#xff0c;您可以通过以下命令来验证MySQL是否已成功安装&#xff1a; 2.配置mysql环境变量 find / -name mysql 2>/dev/null #找到mysql的安装位置 cd /op…

Wi-SUN技术,强势赋能智慧城市构筑海量IoT网络节点

在智慧城市领域中&#xff0c;当一个智慧路灯项目因信号盲区而被迫增设数百个网关时&#xff0c;当一个传感器网络因入网设备数量爆增而导致系统通信失效时&#xff0c;当一个智慧交通系统因基站故障而导致交通瘫痪时&#xff0c;星型网络拓扑与蜂窝网络拓扑在构建广覆盖与高节…

FALL靶机攻略

1.下载靶机&#xff0c;导入靶机 下载地址&#xff1a;https://download.vulnhub.com/digitalworld/FALL.7z 开启靶机。 2. 靶机、kali设置NAT网卡模式 3. kali扫描NAT网卡段的主机 kali主机 nmap扫描&#xff1a;nmap 192.168.92.1/24 判断出靶机ip是192.168.92.133。开启…

蓝桥杯高频考点——二分(含C++源码)

二分 基本框架整数查找&#xff08;序列二分的模版题 建议先做&#xff09;满分代码及思路solution 子串简写满分代码及思路solution 1&#xff08;暴力 模拟双指针70分&#xff09;solution 2&#xff08;二分 AC&#xff09; 管道满分代码及思路样例解释与思路分析solution 最…

Rust vs. Go: 性能测试(2025)

本内容是对知名性能评测博主 Anton Putra Rust vs. Go (Golang): Performance 2025 内容的翻译与整理, 有适当删减, 相关数据和结论以原作结论为准。 再次对比 Rust 和 Go&#xff0c;但这次我们使用的是最具性能优势的 HTTP 服务器库---Hyper&#xff0c;它基于 Tokio 异步运…

【NUUO 摄像头】(弱口令登录漏洞)

漏洞简介&#xff1a;NUUO 是NUUO公司的一款小型网络硬盘录像机设备。 NUUO NVRMini2 3.0.8及之前版本中存在后门调试文件。远程攻击者可通过向后门文件handle_site_config.php发送特定的请求利用该漏洞执行任意命令。 1.Fofa搜索语句&#xff1a; 在Fofa网站&#xff0c;搜索&…

PyQt6实例_批量下载pdf工具_exe使用方法

目录 前置&#xff1a; 工具使用方法&#xff1a; step one 获取工具 step two 安装 step three 使用 step four 卸载 链接 前置&#xff1a; 1 批量下载pdf工具是基于博文 python_巨潮年报pdf下载-CSDN博客 &#xff0c;将这个需求创建成界面应用&#xff0c;达到可…

matlab 模拟 闪烁体探测器全能峰

clc;clear;close all %% 参数设置 num_events 1e5; % 模拟事件数 E 662e3; % γ射线能量&#xff08;eV&#xff09; Y 38000; % 光产额&#xff08;photon/MeV&#xff0c;NaI(Tl)&#xff09; eta 0.2; % 量子效率 G 1e6; …

启扬RK3568开发板已成功适配OpenHarmony4.0版本

启扬智能IAC-RK3568-Kit开发板支持Debian、Android等常见开源操作系统&#xff0c;目前已完成OpenHarmony4.0开源国产操作系统的适配工作&#xff0c;满足国产化开源操作系统客户的需求。 启扬智能IAC-RK3568-Kit开发板基于瑞芯微RK3568处理器设计&#xff0c;主频最高可达2.0G…

蓝桥与力扣刷题(蓝桥 山)

题目&#xff1a;这天小明正在学数数。 他突然发现有些止整数的形状像一挫 “山”, 比㓚 123565321、145541123565321、145541, 它 们左右对称 (回文) 且数位上的数字先单调不减, 后单调不增。 小朋数了衣久也没有数完, 他惒让你告诉他在区间 [2022,2022222022] 中有 多少个数…

WinDbg. From A to Z! 笔记(一)

原文链接: WinDbg. From A to Z! 文章目录 为什么使用WinDbg为什么通过本书学习底层原理简述Windows的调试工具一览dbghelp.dll -- Windows 调试助手dbgeng.dll -- 调试引擎接口 调试符号 (Debug Symbols)有哪些调试信息生成调试信息匹配调试信息调用堆栈 侵入式与非侵入式异常…

Axure RP 9.0教程: 基于动态面板的元件跟随来实现【音量滑块】

文章目录 引言I 音量滑块的实现步骤添加底层边框添加覆盖层基于覆盖层创建动态面板添加滑块按钮设置滑块拖动效果引言 音量滑块在播放器类APP应用场景相对较广,例如调节视频的亮度、声音等等。 I 音量滑块的实现步骤 添加底层边框 在画布中添加一个矩形框:500 x 32,圆…

Eclipse IDE for ModusToolbox™ 3.4环境通过JLINK调试CYT4BB

使用JLINK在Eclipse IDE for ModusToolbox™ 3.4环境下调试CYT4BB&#xff0c;配置是难点。总结一下在IDE中配置JLINK调试中遇到的坑&#xff0c;以及如何一步一步解决遇到的问题。 1. JFLASH能够正常下载程序 首先要保证通过JFLASH(我使用的J-Flash V7.88c版本)能够通过JLIN…

黑马点评项目

遇到问题&#xff1a; 登录流程 session->JWT->SpringSession->tokenRedis &#xff08;不需要改进为SpringSession&#xff0c;token更广泛&#xff0c;移动端或者前后端分离都可以用&#xff09; SpringSession配置为redis模式后&#xff0c;redis相当于分布式se…

wgcloud怎么实现服务器或者主机的远程关机、重启操作吗

可以&#xff0c;WGCLOUD的指令下发模块可以实现远程关机和重启 使用指令下发模块&#xff0c;重启主机&#xff0c;远程关机&#xff0c;重启agent程序- WGCLOUD

深度解析Spring Boot可执行JAR的构建与启动机制

一、Spring Boot应用打包架构演进 1.1 传统JAR包与Fat JAR对比 传统Java应用的JAR包在依赖管理上存在明显短板&#xff0c;依赖项需要单独配置classpath。Spring Boot创新的Fat JAR&#xff08;又称Uber JAR&#xff09;解决方案通过spring-boot-maven-plugin插件实现了"…