(十八)mmdetection源码解读:forward_train

目录

  • 一、forward_train调用过程
  • 二、forward_train函数详解
    • 2.1、extract_feat
    • 2.2、self.rpn_head.forward_train
    • 2.3、self.roi_head.forward_train

一、forward_train调用过程

书接上文,在函数_call_impl中,最核心的训练过程,self.forward函数,这个函数是在Module子类中实现的,下面是继承关系:
nn.Module->BaseModule->BaseDetector->TwoStageDetector

class Module:def _call_impl(self, *input, **kwargs):if torch._C._get_tracing_state():result = self._slow_forward(*input, **kwargs)else:result = self.forward(*input, **kwargs)

在BaseDetector类中forward函数,实际上是调用了forward_train实现的。

class BaseDetector(BaseModule, metaclass=ABCMeta):def forward(self, img, img_metas, return_loss=True, **kwargs):if torch.onnx.is_in_onnx_export():assert len(img_metas) == 1return self.onnx_export(img[0], img_metas[0])if return_loss:return self.forward_train(img, img_metas, **kwargs)else:return self.forward_test(img, img_metas, **kwargs)

在TwoStageDetector类中forward_train函数,实现了前向传播的整个过程。

class TwoStageDetector(BaseDetector):def forward_train(self,img,img_metas,gt_bboxes,gt_labels,gt_bboxes_ignore=None,gt_masks=None,proposals=None,**kwargs):x = self.extract_feat(img)losses = dict()# RPN forward and lossif self.with_rpn:proposal_cfg = self.train_cfg.get('rpn_proposal',self.test_cfg.rpn)rpn_losses, proposal_list = self.rpn_head.forward_train(x,img_metas,gt_bboxes,gt_labels=None,gt_bboxes_ignore=gt_bboxes_ignore,proposal_cfg=proposal_cfg,**kwargs)losses.update(rpn_losses)else:proposal_list = proposalsroi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,gt_bboxes, gt_labels,gt_bboxes_ignore, gt_masks,**kwargs)losses.update(roi_losses)return losses

二、forward_train函数详解

上面forward_train 函数实现了如下过程:img->backbone->neck->rpn->roi->losses

2.1、extract_feat

    def extract_feat(self, img):"""Directly extract features from the backbone+neck."""x = self.backbone(img)if self.with_neck:x = self.neck(x)return x

2.2、self.rpn_head.forward_train

    def forward_train(self,x,img_metas,gt_bboxes,gt_labels=None,gt_bboxes_ignore=None,proposal_cfg=None,**kwargs):outs = self(x)if gt_labels is None:loss_inputs = outs + (gt_bboxes, img_metas)else:loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)if proposal_cfg is None:return losseselse:proposal_list = self.get_bboxes(*outs, img_metas=img_metas, cfg=proposal_cfg)return losses, proposal_list

2.3、self.roi_head.forward_train

roi_head在BaseRoIHead的子类中有多种实现,这里简单给出其中一个例子。

@HEADS.register_module()
class StandardRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):def forward_train(self,x,img_metas,proposal_list,gt_bboxes,gt_labels,gt_bboxes_ignore=None,gt_masks=None,**kwargs):# assign gts and sample proposalsif self.with_bbox or self.with_mask:num_imgs = len(img_metas)if gt_bboxes_ignore is None:gt_bboxes_ignore = [None for _ in range(num_imgs)]sampling_results = []for i in range(num_imgs):assign_result = self.bbox_assigner.assign(proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],gt_labels[i])sampling_result = self.bbox_sampler.sample(assign_result,proposal_list[i],gt_bboxes[i],gt_labels[i],feats=[lvl_feat[i][None] for lvl_feat in x])sampling_results.append(sampling_result)losses = dict()# bbox head forward and lossif self.with_bbox:bbox_results = self._bbox_forward_train(x, sampling_results,gt_bboxes, gt_labels,img_metas)losses.update(bbox_results['loss_bbox'])# mask head forward and lossif self.with_mask:mask_results = self._mask_forward_train(x, sampling_results,bbox_results['bbox_feats'],gt_masks, img_metas)losses.update(mask_results['loss_mask'])return losses

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/131982.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

请求与响应以及REST风格

目录 请求与响应请求参数参数传递 五种类型参数传递普通参数POJO数据类型嵌套POJO类型参数数组类型参数集合类型参数 JSON数据传输参数JSON普通数组JSON对象数据JSON对象数组知识点1:EnableWebMvc知识点2:RequestBodyRequestBody与RequestParam区别日期类…

[SICTF 2023] webmisc

文章目录 webBaby_PHP涉及知识点 我全都要RCE你能跟得上我的speed吗 miscPixel_art攻破这个压缩包&#xff01; web Baby_PHP 涉及知识点 php解析特性apache换行解析漏洞无参RCE 源代码 <?php highlight_file(__FILE__); error_reporting(0);$query $_SERVER[QUERY_ST…

OpenCV Series : Target Box Outline Border

角点 P1 (255, 000, 000) P2 (000, 255, 000) P3 (000, 000, 255) P4 (000, 000, 000)垂直矩形框 rect cv2.minAreaRect(cnt)targetColor roi_colortargetThickness 1targetColor (255, 255, 255)if lineVerbose:if True:cv2.line(ph…

做机器视觉工程师,其实挺没意思的

3.康耐视VisionPro高级脚本系列教程-3.脚本编辑错误和运行错误调试方法&#xff0c;break和Contitinuee的差别_哔哩哔哩_bilibili 其实人生就是“有时有意思&#xff0c;有时没意思”。 心里有太多的不甘心&#xff0c;太多的苦水&#xff0c;是没法再吃学习的苦&#xff0c…

市场调查中的信度和效度分析原理及python实现示例

市场调查中的信度和效度分析 1.量表信度分析1.1 内部一致性信度&#xff1a;克朗巴赫α系数原理1.2 python实现示例 2.量表效度分析2.1 内容效度2.1.1 原理2.1.2 python实现示例 2.2 准则效度2.2.1 原理2.2.2 python实现示例 2.3 结构效度2.3.1 原理2.3.2 python实现示例 3.量表…

[PyTorch][chapter 54][GAN- 1]

前言&#xff1a; GAN playground: Experiment with Generative Adversarial Networks in your browser 生成对抗网络&#xff08;Generative Adversarial Nets&#xff0c;GAN&#xff09;是一种基于对抗学习的深度生成模型&#xff0c;最早由Ian Goodfellow于2014年在《Gener…

selenium.chrome怎么写扩展拦截或转发请求?

Selenium WebDriver 是一组开源 API&#xff0c;用于自动测试 Web 应用程序&#xff0c;利用它可以通过代码来控制chrome浏览器&#xff01; 有时候我们需要mock接口的返回&#xff0c;或者拦截和转发请求&#xff0c;今天就来实现这个功能。 代码已开源&#xff1a; https:/…

使用java连接Libvirtd

基于springboot web 一、依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId>&l…

用c语言编写出三底模型

以下是一个用C语言实现三底模型的示例代码。这个程序通过循环遍历输入的股票数据&#xff0c;判断是否出现三底形态&#xff0c;如果是&#xff0c;则输出买入信号&#xff0c;否则输出卖出信号。 c语言 #include <stdio.h> #include <stdlib.h> // 判断是否出现…

Java项目---图片服务器

图片服务器--->服务器&#xff08;图床&#xff09; 核心功能&#xff1a;上传图片、展示图片等 比如&#xff1a;编写博客时我们会插入图片&#xff0c;本质上是往文章中放了一个链接&#xff08;URL&#xff09;&#xff0c;这个URL资源在另外一个服务器上。 核心知识点…

FPGA 纯VHDL解码 IMX214 MIPI 视频,2路视频拼接输出,提供vivado工程源码和技术支持

目录 1、前言免责声明 2、我这里已有的 MIPI 编解码方案3、本 MIPI CSI2 模块性能及其优越性4、详细设计方案设计原理框图IMX214 摄像头及其配置D-PHY 模块CSI-2-RX 模块Bayer转RGB模块伽马矫正模块VDMA图像缓存Video Scaler 图像缓存HDMI输出 5、vivado工程详解PL端FPGA硬件设…

无涯教程-JavaScript - INFO函数

描述 INFO函数返回有关当前操作环境的信息。 语法 INFO (type_text) 争论 Argument描述Required/OptionalType_text 指定要返回的信息类型的文本。 下表给出了Type_text的值和相应的返回信息。 Required Type_text 返回的信息"目录" 当前目录或文件夹的路径。&qu…

【Proteus仿真】【STM32单片机】四驱寻迹避障小车

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 系统运行后&#xff0c;LCD1602显示红外、超声波检测状态和距离、小车运行状态。可通过K1键可手动切换模式&#xff0c;寻迹、避障、蓝牙遥控&#xff1b;也可通过蓝牙发送指令切换模式&#xff1b; 当处…

系统架构设计之道,论如何构建一个资金账户系统

&#x1f449;导读 资金账户是互联网和金融业务中非常常见的系统&#xff0c;尤其是在电商、支付等业务中必不可少。资金账户系统本身其核心模块的整体架构往往并不复杂&#xff0c;但其对于资金安全和可用性的要求非常高&#xff0c;导致建设好一个资金账户系统并不容易。本文…

【Spring Boot】有这一文就够了

作者简介 前言 作者之前写过一个Spring Boot的系列&#xff0c;包含自动装配原理、MVC、安全、监控、集成数据库、集成Redis、日志、定时任务、异步任务等内容&#xff0c;本文将会一文拉通来总结这所有内容&#xff0c;不骗人&#xff0c;一文快速入门Spring Boot。 专栏地址…

MySQL安装validate_password_policy插件

功能介绍 validate_password_policy 是插件用于验证密码强度的策略。该参数可以设定三种级别&#xff1a;0代表低&#xff0c;1代表中&#xff0c;2代表高。 validate_password_policy 主要影响密码的强度检查级别&#xff1a; 0/LOW&#xff1a;只检查密码长度。 1/MEDIUM&am…

YashanDB:潜心实干,数据库核心技术突破没有捷径可走

都说数据库是三大基础软件中的一块硬骨头&#xff0c;技术门槛高、研发周期长、工程要求高&#xff0c;市场长期被几大巨头所把持。 因此&#xff0c;实现突破一直是中国数据库产业的夙愿。自上个世纪80年代起&#xff0c;中国数据库产业走过艰辛坎坷的四十余载&#xff0c;终…

vue组件库开发,webpack打包,发布npm

做一个像elment-ui一样的vue组件库 那多好啊&#xff01;这是我前几年就想做的 但webpack真的太难用&#xff0c;也许是我功力不够 今天看到一个视频&#xff0c;早上6-13点&#xff0c;终于实现了&#xff0c;呜呜 感谢视频的分享-来龙去脉-大家可以看这个视频&#xff1a;htt…

美东一公司的郁闷面试题

说是题目可以用不同的语言&#xff0c;但是貌似 Java 是多线程的&#xff0c;用 Java 写肯定容易不少。 但&#xff0c;觉得这个题目用多线程简直是有点脱了裤子放屁。 完整题目内容 题目的网站内容如下&#xff1a; Please complete the following challenge in one of th…

【自动驾驶决策规划】POMDP之Introduction

文章目录 前言Markov PropertyMarkov ChainHidden Markov ModelMarkov Decision ProcessPartially Observable Markov Decision ProcessBackground on Solving POMDPsPOMDP Value Iteration Example 推荐阅读与参考 前言 本文是我学习POMDP相关的笔记&#xff0c;由于个人能力…