Contents
- 1. The forward_train call chain
- 2. forward_train in detail
- 2.1 extract_feat
- 2.2 self.rpn_head.forward_train
- 2.3 self.roi_head.forward_train
1. The forward_train call chain
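Picking up from the previous post: inside _call_impl, the core of a training step is the call to self.forward. That method is implemented in subclasses of Module, and the inheritance chain is: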
nn.Module->BaseModule->BaseDetector->TwoStageDetector
class Module:
    def _call_impl(self, *input, **kwargs):
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
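In BaseDetector, forward does no computation of its own. During training (return_loss=True) it dispatches to forward_train; otherwise it dispatches to forward_test.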
class BaseDetector(BaseModule, metaclass=ABCMeta):
    def forward(self, img, img_metas, return_loss=True, **kwargs):
        if torch.onnx.is_in_onnx_export():
            assert len(img_metas) == 1
            return self.onnx_export(img[0], img_metas[0])
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        else:
            return self.forward_test(img, img_metas, **kwargs)
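The dispatch described above can be reproduced without PyTorch. The sketch below is only an illustration: ToyModule and ToyDetector are hypothetical stand-ins, not MMDetection classes, and they keep only the two ideas that matter here, namely that calling the object ends up in forward, and that forward routes on return_loss while passing the remaining keyword arguments through.

```python
class ToyModule:
    """Hypothetical stand-in for nn.Module: calling the object runs forward."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class ToyDetector(ToyModule):
    """Hypothetical stand-in for BaseDetector: forward only dispatches."""

    def forward(self, img, img_metas, return_loss=True, **kwargs):
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        return self.forward_test(img, img_metas, **kwargs)

    def forward_train(self, img, img_metas, **kwargs):
        # Ground-truth arguments arrive here through **kwargs.
        return {"mode": "train", "extra_keys": sorted(kwargs)}

    def forward_test(self, img, img_metas, **kwargs):
        return {"mode": "test", "extra_keys": sorted(kwargs)}


detector = ToyDetector()
train_out = detector("img", [{}], gt_bboxes=[], gt_labels=[])
test_out = detector("img", [{}], return_loss=False)
print(train_out)
print(test_out)
```

Note that return_loss is consumed by forward itself, so it never reaches forward_train or forward_test.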
In TwoStageDetector, forward_train implements the whole forward pass.
class TwoStageDetector(BaseDetector):
    def forward_train(self, img, img_metas, gt_bboxes, gt_labels,
                      gt_bboxes_ignore=None, gt_masks=None, proposals=None,
                      **kwargs):
        x = self.extract_feat(img)
        losses = dict()
        # RPN forward and loss
        if self.with_rpn:
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            rpn_losses, proposal_list = self.rpn_head.forward_train(
                x,
                img_metas,
                gt_bboxes,
                gt_labels=None,
                gt_bboxes_ignore=gt_bboxes_ignore,
                proposal_cfg=proposal_cfg,
                **kwargs)
            losses.update(rpn_losses)
        else:
            proposal_list = proposals
        roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
                                                 gt_bboxes, gt_labels,
                                                 gt_bboxes_ignore, gt_masks,
                                                 **kwargs)
        losses.update(roi_losses)
        return losses
2. forward_train in detail
The forward_train function above implements the following pipeline: img -> backbone -> neck -> rpn -> roi -> losses
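To make the data flow concrete, here is a minimal sketch of the same pipeline with stub functions. Everything in it is an assumption for illustration: the stubs return strings and made-up loss values, and none of the function names belong to MMDetection. What it does mirror is the control flow: features are extracted once, the RPN branch contributes both losses and proposals, and the two loss dicts are merged with dict.update.

```python
def backbone(img):
    # Pretend the backbone returns two feature levels.
    return (f"c1({img})", f"c2({img})")


def neck(feats):
    # Pretend the neck refines each level.
    return tuple(f"fpn[{f}]" for f in feats)


def rpn_forward_train(feats):
    # Returns (losses, proposals); the numbers are made up.
    return {"loss_rpn_cls": 0.7, "loss_rpn_bbox": 0.1}, ["proposal_0", "proposal_1"]


def roi_forward_train(feats, proposals):
    # Returns the second-stage losses; the numbers are made up.
    return {"loss_cls": 1.2, "loss_bbox": 0.4}


def toy_forward_train(img, with_rpn=True, proposals=None):
    x = neck(backbone(img))
    losses = dict()
    if with_rpn:
        rpn_losses, proposal_list = rpn_forward_train(x)
        losses.update(rpn_losses)
    else:
        # Without an RPN, proposals must be supplied by the caller.
        proposal_list = proposals
    losses.update(roi_forward_train(x, proposal_list))
    return losses


losses = toy_forward_train("img")
print(sorted(losses))
```

Because dict.update overwrites existing keys, the two stages need distinct key names, otherwise one stage's loss would silently replace the other's.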
2.1 extract_feat
def extract_feat(self, img):
    """Directly extract features from the backbone+neck."""
    x = self.backbone(img)
    if self.with_neck:
        x = self.neck(x)
    return x
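The optional-neck pattern can be tried in isolation. The sketch below is a toy: ToyExtractor, toy_backbone and toy_neck are hypothetical names, the "features" are plain integers, and the with_neck property is written under the assumption that it simply checks whether a neck attribute exists and is not None.

```python
class ToyExtractor:
    """Hypothetical stand-in showing the optional-neck pattern."""

    def __init__(self, backbone, neck=None):
        self.backbone = backbone
        self.neck = neck

    @property
    def with_neck(self):
        # Assumed definition: a neck is configured and is not None.
        return getattr(self, "neck", None) is not None

    def extract_feat(self, img):
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x


def toy_backbone(img):
    # Three fake feature "levels" derived from the input.
    return (img, img * 2, img * 4)


def toy_neck(feats):
    # Fake refinement applied to every level.
    return tuple(f + 1 for f in feats)


feats_with_neck = ToyExtractor(toy_backbone, toy_neck).extract_feat(1)
feats_without_neck = ToyExtractor(toy_backbone).extract_feat(1)
print(feats_with_neck)     # (2, 3, 5)
print(feats_without_neck)  # (1, 2, 4)
```

Either way the result is a tuple with one entry per feature level, which is what the heads downstream iterate over.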
2.2 self.rpn_head.forward_train
def forward_train(self, x, img_metas, gt_bboxes, gt_labels=None,
                  gt_bboxes_ignore=None, proposal_cfg=None, **kwargs):
    outs = self(x)
    if gt_labels is None:
        loss_inputs = outs + (gt_bboxes, img_metas)
    else:
        loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
    losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
    if proposal_cfg is None:
        return losses
    else:
        proposal_list = self.get_bboxes(
            *outs, img_metas=img_metas, cfg=proposal_cfg)
        return losses, proposal_list
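One detail in this function is easy to miss: loss_inputs is built by tuple concatenation, so the head's output must be a tuple. The sketch below isolates that step. build_loss_inputs is a hypothetical helper written for this illustration, and the strings stand in for real tensors.

```python
def build_loss_inputs(outs, gt_bboxes, img_metas, gt_labels=None):
    """Hypothetical helper mirroring how loss_inputs is assembled."""
    if gt_labels is None:
        # RPN case: class-agnostic, so no labels are passed to the loss.
        return outs + (gt_bboxes, img_metas)
    return outs + (gt_bboxes, gt_labels, img_metas)


# Pretend head output: (cls_scores, bbox_preds), one list entry per level.
outs = (["cls_lvl0"], ["reg_lvl0"])

rpn_inputs = build_loss_inputs(outs, ["gt_boxes"], [{"img_shape": (800, 1333, 3)}])
det_inputs = build_loss_inputs(outs, ["gt_boxes"], [{}], gt_labels=["labels"])

print(len(rpn_inputs))  # 4 positional arguments for the loss
print(len(det_inputs))  # 5 positional arguments for the loss
```

If outs were a list, `outs + (gt_bboxes, img_metas)` would raise a TypeError, since Python does not concatenate a list with a tuple. The return type also depends on proposal_cfg: a bare losses dict when it is None, and a (losses, proposal_list) pair otherwise, which is why TwoStageDetector always passes a proposal_cfg before unpacking two values.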
2.3 self.roi_head.forward_train
roi_head has several implementations among the subclasses of BaseRoIHead. One of them, StandardRoIHead, is shown here as an example.
@HEADS.register_module()
class StandardRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
    def forward_train(self, x, img_metas, proposal_list, gt_bboxes, gt_labels,
                      gt_bboxes_ignore=None, gt_masks=None, **kwargs):
        # assign gts and sample proposals
        if self.with_bbox or self.with_mask:
            num_imgs = len(img_metas)
            if gt_bboxes_ignore is None:
                gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results = []
            for i in range(num_imgs):
                assign_result = self.bbox_assigner.assign(
                    proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
                    gt_labels[i])
                sampling_result = self.bbox_sampler.sample(
                    assign_result,
                    proposal_list[i],
                    gt_bboxes[i],
                    gt_labels[i],
                    feats=[lvl_feat[i][None] for lvl_feat in x])
                sampling_results.append(sampling_result)
        losses = dict()
        # bbox head forward and loss
        if self.with_bbox:
            bbox_results = self._bbox_forward_train(x, sampling_results,
                                                    gt_bboxes, gt_labels,
                                                    img_metas)
            losses.update(bbox_results['loss_bbox'])
        # mask head forward and loss
        if self.with_mask:
            mask_results = self._mask_forward_train(x, sampling_results,
                                                    bbox_results['bbox_feats'],
                                                    gt_masks, img_metas)
            losses.update(mask_results['loss_mask'])
        return losses
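The "assign gts and sample proposals" step is the least obvious part of this function, so here is a rough pure-Python sketch of the idea for a single image. It is not the library's assigner or sampler: the functions iou, assign and sample are hypothetical, the threshold 0.5 and the sampling numbers are illustrative, and background is encoded as -1 purely for simplicity. The assumed technique is IoU-based matching followed by random sampling with a cap on the positive fraction.

```python
import random


def iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def assign(proposals, gt_bboxes, pos_iou_thr=0.5):
    """Per proposal: index of the best-matching gt box, or -1 for background."""
    assigned = []
    for p in proposals:
        ious = [iou(p, g) for g in gt_bboxes]
        best = max(range(len(ious)), key=ious.__getitem__) if ious else -1
        assigned.append(best if best >= 0 and ious[best] >= pos_iou_thr else -1)
    return assigned


def sample(assigned, num=4, pos_fraction=0.25, seed=0):
    """Randomly pick up to `num` proposals, capping the share of positives."""
    rng = random.Random(seed)
    pos = [i for i, a in enumerate(assigned) if a >= 0]
    neg = [i for i, a in enumerate(assigned) if a < 0]
    num_pos = min(len(pos), int(num * pos_fraction))
    num_neg = min(len(neg), num - num_pos)
    return sorted(rng.sample(pos, num_pos)), sorted(rng.sample(neg, num_neg))


gt_bboxes = [(0, 0, 10, 10), (20, 20, 40, 40)]
proposals = [(1, 1, 10, 10), (0, 0, 5, 5), (22, 20, 40, 40),
             (50, 50, 60, 60), (100, 100, 120, 120)]

assigned = assign(proposals, gt_bboxes)
pos_inds, neg_inds = sample(assigned)
print(assigned)  # [0, -1, 1, -1, -1]
print(pos_inds, neg_inds)
```

In the real head this runs once per image inside the loop over num_imgs, and the sampled positives and negatives are what the bbox head (and, for positives, the mask head) is trained on.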