onnxruntime推理
使用mmdeploy导出onnx模型:
from mmdeploy.apis import torch2onnx
from mmdeploy.backend.sdk.export_info import export2SDK

# Model definition, weights and deployment configuration.
deploy_cfg = 'mmdeploy/configs/mmdet/detection/detection_onnxruntime_dynamic.py'
model_cfg = 'mmdetection/configs/mask_rcnn/mask-rcnn_r50_fpn_ms-poly-3x_coco.py'
model_checkpoint = 'checkpoints/mask_rcnn_r50_fpn_mstrain-poly_3x_coco_20210524_201154-21b550bb.pth'

# Input image handed to torch2onnx and the locations of the exported files.
img = 'demo.JPEG'
work_dir = './work_dir/onnx/mask_rcnn'
save_file = './end2end.onnx'
device = 'cpu'

# 1. convert model to onnx
torch2onnx(
    img,
    work_dir,
    save_file,
    deploy_cfg,
    model_cfg,
    model_checkpoint,
    device,
)

# 2. extract pipeline info for sdk use (dump-info)
export2SDK(
    deploy_cfg,
    model_cfg,
    work_dir,
    pth=model_checkpoint,
    device=device,
)
可以用netron查看onnx模型的结构:
也可以通过下面的代码尝试手动导出onnx模型:
import cv2
import torch
import torch.nn.functional as F
import torchvision
import numpy as np
from mmdet.apis import init_detector, inference_detector
from mmengine.config import Config
from mmcv.ops import batched_nms
from mmcv.ops.point_sample import bilinear_grid_sample
from mmdet.structures.bbox import bbox2roi
from mmdet.models.layers import multiclass_nms

# Detector config and checkpoint loaded by the MaskRCNN wrapper.
config_file = './configs/mask_rcnn/mask-rcnn_r50_fpn_ms-poly-3x_coco.py'
checkpoint_file = '../checkpoints/mask_rcnn_r50_fpn_mstrain-poly_3x_coco_20210524_201154-21b550bb.pth'

# Fixed image metadata: every shape is 800x800 and the scale factor is 1,
# matching the dummy input used for the export.
img_meta = dict(
    batch_input_shape=(800, 800),
    pad_shape=(800, 800),
    ori_shape=(800, 800),
    scale_factor=(1.0, 1.0),
    img_shape=(800, 800),
)

# Post-processing thresholds used by the bbox and mask branches.
test_cfg = Config(
    dict(
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.5),
        max_per_img=100,
        mask_thr_binary=0.5,
    )
)

# Number of foreground classes used when expanding RoIs per class.
num_classes = 80
# Maximum number of RPN candidates kept per level before NMS; also used as
# the cap on proposals kept after NMS.
nms_pre = 1000


class MaskRCNN(torch.nn.Module):
    """Wrapper around an mmdet Mask R-CNN with hand-written post-processing.

    The mmdet ``predict_by_feat`` / ``predict_bbox`` / ``predict_mask`` steps
    are re-implemented here with plain torch / torchvision ops so the whole
    forward pass can be handed to ``torch.onnx.export``. The module relies on
    the module-level ``img_meta``, ``test_cfg``, ``num_classes`` and
    ``nms_pre`` settings.
    """

    def __init__(self):
        super().__init__()
        self.model = init_detector(config_file, checkpoint_file, device='cpu')
        # Anchor generator settings (one entry per pyramid level).
        self.base_sizes = [4, 8, 16, 32, 64]
        self.ratios = torch.tensor([0.5000, 1.0000, 2.0000])
        self.scales = torch.tensor([8.])
        self.strides = [(4, 4), (8, 8), (16, 16), (32, 32), (64, 64)]
        self.base_anchors = self.gen_base_anchors()

    def gen_base_anchors(self):
        """Return the list of base anchors, one tensor per entry of ``base_sizes``."""
        return [
            self.gen_single_level_base_anchors(
                base_size, scales=self.scales, ratios=self.ratios)
            for base_size in self.base_sizes
        ]

    def gen_single_level_base_anchors(self, base_size, scales, ratios):
        """Build zero-centred anchors for one level.

        Args:
            base_size: Side length of the base square anchor.
            scales: 1-D tensor of anchor scales.
            ratios: 1-D tensor of height/width ratios.

        Returns:
            Tensor of shape ``(len(ratios) * len(scales), 4)`` holding
            ``(-w/2, -h/2, w/2, h/2)`` per anchor.
        """
        w = base_size
        h = base_size
        h_ratios = torch.sqrt(ratios)
        w_ratios = 1 / h_ratios
        ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
        return torch.stack([-0.5 * ws, -0.5 * hs, 0.5 * ws, 0.5 * hs], dim=-1)

    def _meshgrid(self, x, y):
        """Return flattened grid coordinates built from 1-D tensors ``x`` and ``y``."""
        xx = x.repeat(y.shape[0])
        yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1)
        return xx, yy

    def single_level_grid_priors(self, featmap_size, level_idx, dtype: torch.dtype = torch.float32):
        """Shift the base anchors of ``level_idx`` over a feature map.

        Args:
            featmap_size: ``(feat_h, feat_w)`` of the feature map.
            level_idx: Index into ``self.base_anchors`` / ``self.strides``.
            dtype: Dtype of the returned anchors.

        Returns:
            Tensor of shape ``(feat_h * feat_w * num_base_anchors, 4)``.
        """
        base_anchors = self.base_anchors[level_idx].to(dtype)
        feat_h, feat_w = featmap_size
        stride_w, stride_h = self.strides[level_idx]
        shift_x = torch.arange(0, feat_w).to(dtype) * stride_w
        shift_y = torch.arange(0, feat_h).to(dtype) * stride_h
        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
        shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
        # Broadcast: every grid position gets every base anchor.
        all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
        return all_anchors.view(-1, 4)

    def grid_priors(self, featmap_sizes, dtype):
        """Return the anchors of every level, in the order of ``featmap_sizes``."""
        return [
            self.single_level_grid_priors(featmap_size, level_idx=level_idx, dtype=dtype)
            for level_idx, featmap_size in enumerate(featmap_sizes)
        ]

    def predict_by_feat(self, cls_scores, bbox_preds):
        """Turn RPN head outputs into proposals (first image of the batch only).

        Per level the ``nms_pre`` best-scoring candidates are kept, boxes are
        decoded with the RPN bbox coder and a level-wise batched NMS with IoU
        threshold 0.7 is applied.

        Args:
            cls_scores: Per-level classification maps from the RPN head.
            bbox_preds: Per-level box regression maps from the RPN head.

        Returns:
            Tuple ``(bboxes, scores, keep_idxs)``; ``bboxes`` and ``scores``
            are truncated to at most ``nms_pre`` entries, ``keep_idxs`` is the
            untruncated NMS index tensor.
        """
        featmap_sizes = [cls_score.shape[-2:] for cls_score in cls_scores]
        mlvl_priors = self.grid_priors(featmap_sizes, dtype=cls_scores[0].dtype)
        cls_score_list = [cls_score[0].detach() for cls_score in cls_scores]
        bbox_pred_list = [bbox_pred[0].detach() for bbox_pred in bbox_preds]

        mlvl_bbox_preds = []
        mlvl_valid_priors = []
        mlvl_scores = []
        level_ids = []
        for level_idx, (cls_score, bbox_pred, priors) in enumerate(
                zip(cls_score_list, bbox_pred_list, mlvl_priors)):
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            cls_score = cls_score.permute(1, 2, 0).reshape(-1, 1)
            scores = cls_score.sigmoid()
            scores = torch.squeeze(scores)
            if 0 < nms_pre < scores.shape[0]:
                # topk is used instead of a full sort; the scores are then
                # gathered through the indices.
                _, topk_inds = scores.topk(nms_pre)
                scores = scores[topk_inds]
                bbox_pred = bbox_pred[topk_inds, :]
                priors = priors[topk_inds]
            mlvl_bbox_preds.append(bbox_pred)
            mlvl_valid_priors.append(priors)
            mlvl_scores.append(scores)
            level_ids.append(
                scores.new_full((scores.size(0), ), level_idx, dtype=torch.long))

        bbox_pred = torch.cat(mlvl_bbox_preds)
        priors = torch.cat(mlvl_valid_priors, dim=0)
        bboxes = self.model.rpn_head.bbox_coder.decode(
            priors, bbox_pred, max_shape=img_meta['img_shape'])
        scores = torch.cat(mlvl_scores)
        level_ids = torch.cat(level_ids)
        # Using the level id as the "class" keeps NMS from suppressing
        # boxes across pyramid levels.
        keep_idxs = torchvision.ops.batched_nms(bboxes, scores, level_ids, 0.7)
        bboxes = bboxes[keep_idxs][:nms_pre]
        scores = scores[keep_idxs][:nms_pre]
        return bboxes, scores, keep_idxs

    def multiclass_nms(self, multi_bboxes, multi_scores, score_thr, nms_cfg, max_num, box_dim):
        """Class-wise NMS implemented with ``torchvision.ops.batched_nms``.

        Args:
            multi_bboxes: Boxes, viewable as ``(num_rois, -1, box_dim)``.
            multi_scores: Scores of shape ``(num_rois, num_classes + 1)``;
                the last column is dropped before NMS.
            score_thr: Candidates with a score not above this are discarded.
            nms_cfg: Mapping providing ``'iou_threshold'``.
            max_num: Maximum number of detections returned.
            box_dim: Number of values per box.

        Returns:
            Tuple ``(dets, labels)`` where ``dets`` has the score appended as
            last column.
        """
        n_cls = multi_scores.size(1) - 1
        bboxes = multi_bboxes.view(multi_scores.size(0), -1, box_dim)
        scores = multi_scores[:, :-1]
        labels = torch.arange(n_cls, dtype=torch.long, device=scores.device)
        labels = labels.view(1, -1).expand_as(scores)

        bboxes = bboxes.reshape(-1, box_dim)
        scores = scores.reshape(-1)
        labels = labels.reshape(-1)

        valid_mask = scores > score_thr
        inds = valid_mask.nonzero(as_tuple=False).squeeze(1)
        bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds]

        # torchvision's batched_nms replaces mmcv.ops.batched_nms here. The
        # threshold comes from the nms_cfg argument rather than the global
        # test_cfg, so the parameter is actually honoured.
        keep = torchvision.ops.batched_nms(bboxes, scores, labels, nms_cfg['iou_threshold'])
        dets = torch.cat([bboxes[keep], torch.unsqueeze(scores[keep], 1)], -1)
        if keep.shape[0] > max_num:
            dets = dets[:max_num]
            keep = keep[:max_num]
        return dets, labels[keep]

    def predict_bbox(self, x, rpn_results_list):
        """Run the bbox branch of the RoI head on the RPN proposals.

        Args:
            x: Feature maps from the neck.
            rpn_results_list: Output of :meth:`predict_by_feat`; only its
                first element (the proposal boxes) is used.

        Returns:
            Tuple ``(bboxes, scores, labels)`` of the final detections.
        """
        proposals = [rpn_results_list[0]]
        rois = bbox2roi(proposals)
        bbox_results = self.model.roi_head._bbox_forward(x, rois)
        cls_scores = bbox_results['cls_score']
        bbox_preds = bbox_results['bbox_pred']
        scores = F.softmax(cls_scores, dim=-1)

        img_shape = img_meta['img_shape']
        num_rois = rois.size(0)
        # Repeat every RoI once per class (stand-in for repeat_interleave).
        roi = rois.repeat(1, num_classes).view(-1, rois.shape[1])
        bbox_pred = bbox_preds.view(-1, 4)
        # roi[..., 1:] drops the leading column that bbox2roi prepends.
        bboxes = self.model.roi_head.bbox_head.bbox_coder.decode(
            roi[..., 1:], bbox_pred, max_shape=img_shape)
        box_dim = bboxes.size(-1)
        bboxes = bboxes.view(num_rois, -1)

        det_bboxes, det_labels = self.multiclass_nms(
            bboxes, scores, test_cfg['score_thr'], test_cfg['nms'],
            test_cfg['max_per_img'], box_dim)
        return det_bboxes[:, :-1], det_bboxes[:, -1], det_labels

    def _do_paste_mask(self, masks, boxes, img_h, img_w):
        """Resample mask predictions into image coordinates.

        Only the first row of ``boxes`` is used for the sampling grid, so the
        method is meant to be called with a single mask/box pair.

        Args:
            masks: Mask predictions of shape ``(N, 1, h, w)``.
            boxes: Boxes of shape ``(N, 4)`` as ``(x0, y0, x1, y1)``.
            img_h: Image height used to clamp the pasted region.
            img_w: Image width used to clamp the pasted region.

        Returns:
            Tuple ``(img_masks, (y_slice, x_slice))``: the resampled masks
            and the image region they cover.
        """
        # Integer bounds of the region covered by the boxes, clamped to the image.
        x0_int, y0_int = torch.clamp(
            boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(dtype=torch.int32)
        x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
        y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
        # Explicit indexing of the first box instead of torch.split(boxes, 1, dim=1).
        x0 = boxes[0][0].reshape(1, 1)
        y0 = boxes[0][1].reshape(1, 1)
        x1 = boxes[0][2].reshape(1, 1)
        y1 = boxes[0][3].reshape(1, 1)

        N = masks.shape[0]
        img_y = torch.arange(y0_int, y1_int, device=masks.device).to(torch.float32) + 0.5
        img_x = torch.arange(x0_int, x1_int, device=masks.device).to(torch.float32) + 0.5
        # Map pixel centres into the [-1, 1] range expected by grid sampling.
        img_y = (img_y - y0) / (y1 - y0) * 2 - 1
        img_x = (img_x - x0) / (x1 - x0) * 2 - 1

        # Degenerate boxes divide by zero above; replace the infinities.
        # (The `torch.onnx.is_in_onnx_export()` guard was disabled by the
        # author, so these checks always run.)
        if torch.isinf(img_x).any():
            inds = torch.where(torch.isinf(img_x))
            img_x[inds] = 0
        if torch.isinf(img_y).any():
            inds = torch.where(torch.isinf(img_y))
            img_y[inds] = 0

        gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
        gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
        grid = torch.stack([gx, gy], dim=3)
        # mmcv's bilinear_grid_sample is used in place of F.grid_sample.
        img_masks = bilinear_grid_sample(masks.to(dtype=torch.float32), grid, align_corners=False)
        spatial_inds = (
            slice(y0_int.to(dtype=torch.int64), y1_int.to(dtype=torch.int64)),
            slice(x0_int.to(dtype=torch.int64), x1_int.to(dtype=torch.int64)),
        )
        return img_masks[:, 0], spatial_inds

    def mask_roi_extractor(self, feats, rois):
        """Extract mask RoI features, routing each RoI to its pyramid level.

        Args:
            feats: Feature maps passed to the RoI layers.
            rois: RoIs as produced by ``bbox2roi``.

        Returns:
            Tensor with one feature block per RoI.
        """
        extractor = self.model.roi_head.mask_roi_extractor
        out_size = extractor.roi_layers[0].output_size
        num_levels = len(feats)
        roi_feats = feats[0].new_zeros(rois.size(0), extractor.out_channels, *out_size)
        target_lvls = extractor.map_roi_levels(rois, num_levels)
        for i in range(num_levels):
            mask = target_lvls == i
            inds = mask.nonzero(as_tuple=False).squeeze(1)
            rois_ = rois[inds]
            roi_feats_t = extractor.roi_layers[i](feats[i], rois_)
            roi_feats[inds] = roi_feats_t
        return roi_feats

    def predict_mask(self, x, img_meta, results_list):
        """Run the mask branch and paste the masks into full-size images.

        Args:
            x: Feature maps from the neck; the first four are used.
            img_meta: Mapping providing ``'scale_factor'`` and ``'ori_shape'``.
            results_list: ``(bboxes, scores, labels)`` from :meth:`predict_bbox`.

        Returns:
            Tuple ``(bboxes, scores, labels, im_mask)`` where ``im_mask`` is a
            boolean tensor of shape ``(num_dets, img_h, img_w)``.
        """
        bboxes = results_list[0]
        scores = results_list[1]
        labels = results_list[2]
        rois = bbox2roi([bboxes])

        mask_feats = self.mask_roi_extractor(x[:4], rois)
        mask_preds = self.model.roi_head.mask_head(mask_feats)
        scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat((1, 2))
        img_h, img_w = img_meta['ori_shape'][:2]
        mask_preds = mask_preds.sigmoid()
        bboxes /= scale_factor

        num_chunks = len(mask_preds)
        # torch.chunk rejects a chunk count of 0, so skip it when nothing
        # was detected; the loop below is then a no-op.
        if num_chunks > 0:
            chunks = torch.chunk(torch.arange(num_chunks), num_chunks)
        else:
            chunks = ()
        threshold = test_cfg['mask_thr_binary']
        im_mask = torch.zeros(num_chunks, img_h, img_w, dtype=torch.bool)
        # Keep only the mask channel of each detection's predicted class.
        mask_preds = mask_preds[range(num_chunks), labels][:, None]
        for inds in chunks:
            masks_chunk, spatial_inds = self._do_paste_mask(
                mask_preds[inds], bboxes[inds], img_h, img_w)
            masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
            im_mask[(inds, ) + spatial_inds] = masks_chunk
        return bboxes, scores, labels, im_mask

    def forward(self, x):
        """Full pipeline: backbone, neck, RPN, bbox branch, mask branch.

        Each hand-written step stands in for the corresponding mmdet call
        (``rpn_head.predict_by_feat``, ``roi_head.predict_bbox`` and
        ``roi_head.predict_mask``).

        Returns:
            Tuple ``(bboxes, scores, labels, im_mask)``.
        """
        x = self.model.backbone(x)
        x = self.model.neck(x)
        outs = self.model.rpn_head(x)
        rpn_results_list = self.predict_by_feat(*outs)
        results_list = self.predict_bbox(x, rpn_results_list)
        results_list = self.predict_mask(x, img_meta, results_list)
        return results_list


model = MaskRCNN().eval()
# Dummy tensor used to trace the model; its 800x800 size matches img_meta.
# It is not called `input`, so the builtin of that name stays usable.
dummy_input = torch.zeros(1, 3, 800, 800, device='cpu')
torch.onnx.export(model, dummy_input, "./mmdetection/maskrcnn.onnx", opset_version=11)

# Optional post-processing of the exported file with onnx-simplifier:
# import onnx
# from onnxsim import simplify
# onnx_model = onnx.load("maskrcnn.onnx")  # load onnx model
# model_simp, check = simplify(onnx_model)
# assert check, "Simplified ONNX model could not be validated"
# onnx.save(model_simp, "maskrcnn_sim.onnx")
导出的模型结构如下:
但是手动导出的模型存在一些问题，如不支持动态输入尺寸、推理速度较慢等，故下面推理代码中使用的模型均为通过mmdeploy导出的模型：
python推理代码:
import cv2
import numpy as np
import onnxruntime
import json
import random
import torch
import torch.nn.functional as F
from pathlib import Path
from copy import deepcopy


# (long edge, short edge) bound used when resizing the input image.
resize_shape = (1333, 800)
confidence_threshold = 0.5
mask_threshold = 0.5


def filter_box(outputs, scale):
    """Drop low-confidence detections and rescale boxes to the original image.

    Args:
        outputs: List of model outputs. ``outputs[0]`` holds the detections
            with the confidence in column 4, ``outputs[1]`` the class ids and
            ``outputs[2]`` the per-detection masks. ``outputs[2]`` is replaced
            in place by the masks of the kept detections with an extra axis
            inserted at position 1.
        scale: ``(height_scale, width_scale)`` from the resized to the
            original image.

    Returns:
        Array with one row per kept detection: the (rescaled) detection
        columns followed by the class id.
    """
    # Remove detections whose confidence is not above confidence_threshold.
    flag = outputs[0][..., 4] > confidence_threshold
    boxes = outputs[0][flag]
    boxes[..., [0, 2]] *= scale[1]
    boxes[..., [1, 3]] *= scale[0]
    class_ids = outputs[1][flag].reshape(-1, 1)
    outputs[2] = np.expand_dims(outputs[2][flag], axis=1)
    return np.concatenate((boxes, class_ids), axis=1)


def resize_keep_ratio(image, img_scale):
    """Resize ``image`` to fit inside ``img_scale`` while keeping its aspect ratio.

    Args:
        image: Image array of shape ``(h, w, ...)``.
        img_scale: Pair of edge lengths; the larger bounds the long edge and
            the smaller bounds the short edge.

    Returns:
        The resized image.
    """
    h, w = image.shape[0], image.shape[1]
    max_long_edge = max(img_scale)
    max_short_edge = min(img_scale)
    scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w))
    scale_w = int(w * float(scale_factor) + 0.5)
    scale_h = int(h * float(scale_factor) + 0.5)
    return cv2.resize(image, (scale_w, scale_h))


def draw_boxes(image, box_data):
    """Draw boxes and ``class<id> <score>`` captions onto ``image`` in place.

    Args:
        image: Image to draw on.
        box_data: Rows of ``(x1, y1, x2, y2, score, class_id)``.
    """
    boxes = box_data[..., :4].astype(np.int32)
    scores = box_data[..., 4]
    classes = box_data[..., 5].astype(np.int32)
    for box, score, cl in zip(boxes, scores, classes):
        # Plain Python ints keep OpenCV's point parsing happy.
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 1)
        cv2.putText(image, 'class{0} {1:.6f}'.format(cl, score), (x1, y1),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 1)


def random_color():
    """Return three random channel values, each within 0..255.

    ``random.randint`` includes both bounds, so the upper bound must be 255
    for the values to fit into a uint8 image.
    """
    return [random.randint(0, 255) for _ in range(3)]


def draw_binary_masks(img, binary_masks, alphas=0.5):
    """Blend every mask onto ``img`` in a random colour and save ``result.jpg``.

    Args:
        img: Image the masks are blended onto.
        binary_masks: Boolean masks of shape ``(num_masks, h, w)``.
        alphas: Blending weight applied to every mask.
    """
    binary_masks = binary_masks.astype('uint8') * 255
    for binary_mask in binary_masks:
        binary_mask_complement = cv2.bitwise_not(binary_mask)
        rgb = np.zeros_like(img)
        rgb[...] = random_color()
        # Colour inside the mask, original image outside of it.
        rgb = cv2.bitwise_and(rgb, rgb, mask=binary_mask)
        img_complement = cv2.bitwise_and(img, img, mask=binary_mask_complement)
        rgb = rgb + img_complement
        img = cv2.addWeighted(img, 1 - alphas, rgb, alphas, 0)
    cv2.imwrite("result.jpg", img)


def paste_masks(masks, boxes, img_h, img_w):
    """Resample per-detection masks into full-image boolean masks.

    Args:
        masks: Array of shape ``(num_dets, 1, mask_h, mask_w)``.
        boxes: Array whose first four columns are ``(x0, y0, x1, y1)`` in
            image coordinates, one row per detection.
        img_h: Height of the target image.
        img_w: Width of the target image.

    Returns:
        Boolean array of shape ``(num_dets, img_h, img_w)``; a pixel is set
        when the resampled mask value reaches ``mask_threshold``.
    """
    if boxes.shape[0] == 0:
        # Nothing detected: skip the sampling entirely.
        return np.zeros((0, img_h, img_w), dtype=bool)
    masks = torch.from_numpy(masks)
    x0, y0, x1, y1 = torch.split(torch.from_numpy(boxes[..., :4]), 1, dim=1)
    img_y = torch.arange(0, img_h).to(torch.float32) + 0.5
    img_x = torch.arange(0, img_w).to(torch.float32) + 0.5
    # Map pixel centres into each box's [-1, 1] sampling range.
    img_y = (img_y - y0) / (y1 - y0) * 2 - 1
    img_x = (img_x - x0) / (x1 - x0) * 2 - 1
    gx = img_x[:, None, :].expand(masks.shape[0], img_y.size(1), img_x.size(1))
    gy = img_y[:, :, None].expand(masks.shape[0], img_y.size(1), img_x.size(1))
    grid = torch.stack([gx, gy], dim=3).to(torch.float32)
    img_masks = F.grid_sample(masks.to(torch.float32), grid, align_corners=False)
    masks_chunk = (img_masks[:, 0] >= mask_threshold).to(dtype=torch.bool)
    return masks_chunk.detach().cpu().numpy()


def main():
    """Run the exported Mask R-CNN on ``bus.jpg`` and write ``result.jpg``."""
    onnx_session = onnxruntime.InferenceSession('end2end.onnx', providers=['CPUExecutionProvider'])
    input_names = [node.name for node in onnx_session.get_inputs()]

    image = cv2.imread("bus.jpg")
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError("bus.jpg")
    image_resize = resize_keep_ratio(image, resize_shape)
    scale = (image.shape[0] / image_resize.shape[0], image.shape[1] / image_resize.shape[1])

    # BGR -> RGB, HWC -> CHW, then per-channel normalization.
    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32).reshape(3, 1, 1)
    std = np.array([58.395, 57.12, 57.375], dtype=np.float32).reshape(3, 1, 1)
    chw = image_resize[:, :, ::-1].transpose(2, 0, 1).astype(dtype=np.float32)
    chw = (chw - mean) / std

    # Pad bottom/right to a multiple of 32 AFTER normalization so the padded
    # area is exactly zero, as in the TensorRT script.
    res_h, res_w = image_resize.shape[0], image_resize.shape[1]
    pad_h = int(np.ceil(res_h / 32) * 32)
    pad_w = int(np.ceil(res_w / 32) * 32)
    blob = np.zeros((1, 3, pad_h, pad_w), dtype=np.float32)
    blob[0, :, :res_h, :res_w] = chw

    outputs = onnx_session.run(None, {name: blob for name in input_names})
    boxes = filter_box(outputs, scale)
    mask_results = paste_masks(outputs[2], boxes, image.shape[0], image.shape[1])

    draw_boxes(image, boxes)
    draw_binary_masks(image, mask_results)


if __name__ == "__main__":
    main()
tensorrt推理
使用mmdeploy导出engine模型:
from mmdeploy.apis import torch2onnx
from mmdeploy.backend.tensorrt.onnx2tensorrt import onnx2tensorrt
from mmdeploy.backend.sdk.export_info import export2SDK
import os

# Model definition, weights and deployment configuration.
deploy_cfg = './configs/mmdet/instance-seg/instance-seg_tensorrt_static-768x1344.py'
model_cfg = '../mmdetection-3.3.0/configs/mask_rcnn/mask-rcnn_r50_fpn_ms.py'
model_checkpoint = 'checkpoints/mask_rcnn_r50_fpn_mstrain-poly_3x_coco_20210524_201154-21b550bb.pth'

# Input image handed to torch2onnx and the locations of the exported files.
img = 'bus.jpg'
work_dir = './work_dir/trt/maskrcnn'
save_file = './end2end.onnx'
device = 'cuda'

# 1. convert model to onnx
torch2onnx(
    img,
    work_dir,
    save_file,
    deploy_cfg,
    model_cfg,
    model_checkpoint,
    device,
)

# 2. convert IR to tensorrt
onnx_model = os.path.join(work_dir, save_file)
save_file = 'end2end.engine'
model_id = 0
onnx2tensorrt(
    work_dir,
    save_file,
    model_id,
    deploy_cfg,
    onnx_model,
    device,
)

# 3. extract pipeline info for sdk use (dump-info)
export2SDK(
    deploy_cfg,
    model_cfg,
    work_dir,
    pth=model_checkpoint,
    device=device,
)
手动编写tensorrt推理脚本:
import cv2
import ctypes
import numpy as np
import tensorrt as trt
import pycuda.autoinit
import pycuda.driver as cuda
import random
import torch
import torch.nn.functional as F
from pathlib import Path
import pycocotools.mask as mask_util


# Static (height, width) of the engine input; the buffers below and
# set_input_shape all use this size.
input_hw = (768, 1344)
# (width, height) bound for resizing. It equals the engine input size so the
# resized image always fits into the fixed-size input tensor.
resize_shape = (input_hw[1], input_hw[0])
confidence_threshold = 0.5
mask_threshold = 0.5


def filter_box(outputs, scale):
    """Drop low-confidence detections and rescale boxes to the original image.

    Args:
        outputs: List of model outputs. ``outputs[0]`` holds the detections
            with the confidence in column 4, ``outputs[1]`` the class ids and
            ``outputs[2]`` the per-detection masks. ``outputs[2]`` is replaced
            in place by the masks of the kept detections with an extra axis
            inserted at position 1.
        scale: ``(height_scale, width_scale)`` from the resized to the
            original image.

    Returns:
        Array with one row per kept detection: the (rescaled) detection
        columns followed by the class id.
    """
    # Remove detections whose confidence is not above confidence_threshold.
    flag = outputs[0][..., 4] > confidence_threshold
    boxes = outputs[0][flag]
    boxes[..., [0, 2]] *= scale[1]
    boxes[..., [1, 3]] *= scale[0]
    class_ids = outputs[1][flag].reshape(-1, 1)
    outputs[2] = np.expand_dims(outputs[2][flag], axis=1)
    return np.concatenate((boxes, class_ids), axis=1)


def resize_keep_ratio(image, size):
    """Resize ``image`` to fit inside ``size`` while keeping its aspect ratio.

    Args:
        image: Image array of shape ``(h, w, ...)``.
        size: ``(width, height)`` bound.

    Returns:
        The resized image; one of its sides equals the corresponding bound.
    """
    height, width = image.shape[0], image.shape[1]
    width_new, height_new = size
    if width / height >= width_new / height_new:
        return cv2.resize(image, (width_new, round(height * width_new / width)))
    return cv2.resize(image, (round(width * height_new / height), height_new))


def draw_boxes(image, box_data):
    """Draw boxes and ``class<id> <score>`` captions onto ``image`` in place.

    Args:
        image: Image to draw on.
        box_data: Rows of ``(x1, y1, x2, y2, score, class_id)``.
    """
    boxes = box_data[..., :4].astype(np.int32)
    scores = box_data[..., 4]
    classes = box_data[..., 5].astype(np.int32)
    for box, score, cl in zip(boxes, scores, classes):
        # Plain Python ints keep OpenCV's point parsing happy.
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 1)
        cv2.putText(image, 'class{0} {1:.6f}'.format(cl, score), (x1, y1),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 1)


def random_color():
    """Return three random channel values, each within 0..255.

    ``random.randint`` includes both bounds, so the upper bound must be 255
    for the values to fit into a uint8 image.
    """
    return [random.randint(0, 255) for _ in range(3)]


def draw_binary_masks(img, binary_masks, alphas=0.2):
    """Blend every mask onto ``img`` in a random colour and save ``result.jpg``.

    Args:
        img: Image the masks are blended onto.
        binary_masks: Boolean masks of shape ``(num_masks, h, w)``.
        alphas: Blending weight applied to every mask.
    """
    binary_masks = binary_masks.astype('uint8') * 255
    for binary_mask in binary_masks:
        binary_mask_complement = cv2.bitwise_not(binary_mask)
        rgb = np.zeros_like(img)
        rgb[...] = random_color()
        # Colour inside the mask, original image outside of it.
        rgb = cv2.bitwise_and(rgb, rgb, mask=binary_mask)
        img_complement = cv2.bitwise_and(img, img, mask=binary_mask_complement)
        rgb = rgb + img_complement
        img = cv2.addWeighted(img, 1 - alphas, rgb, alphas, 0)
    cv2.imwrite("result.jpg", img)


def paste_masks(masks, boxes, img_h, img_w):
    """Resample per-detection masks into full-image boolean masks.

    Args:
        masks: Array of shape ``(num_dets, 1, mask_h, mask_w)``.
        boxes: Array whose first four columns are ``(x0, y0, x1, y1)`` in
            image coordinates, one row per detection.
        img_h: Height of the target image.
        img_w: Width of the target image.

    Returns:
        Boolean array of shape ``(num_dets, img_h, img_w)``; a pixel is set
        when the resampled mask value reaches ``mask_threshold``.
    """
    if boxes.shape[0] == 0:
        # Nothing detected: skip the sampling entirely.
        return np.zeros((0, img_h, img_w), dtype=bool)
    masks = torch.from_numpy(masks)
    x0, y0, x1, y1 = torch.split(torch.from_numpy(boxes[..., :4]), 1, dim=1)
    img_y = torch.arange(0, img_h).to(torch.float32) + 0.5
    img_x = torch.arange(0, img_w).to(torch.float32) + 0.5
    # Map pixel centres into each box's [-1, 1] sampling range.
    img_y = (img_y - y0) / (y1 - y0) * 2 - 1
    img_x = (img_x - x0) / (x1 - x0) * 2 - 1
    gx = img_x[:, None, :].expand(masks.shape[0], img_y.size(1), img_x.size(1))
    gy = img_y[:, :, None].expand(masks.shape[0], img_y.size(1), img_x.size(1))
    grid = torch.stack([gx, gy], dim=3).to(torch.float32)
    img_masks = F.grid_sample(masks.to(torch.float32), grid, align_corners=False)
    masks_chunk = (img_masks[:, 0] >= mask_threshold).to(dtype=torch.bool)
    return masks_chunk.detach().cpu().numpy()


def main():
    """Run the TensorRT engine on ``bus.jpg`` and write ``result.jpg``."""
    # Custom mmdeploy TensorRT plugins must be loaded before deserialization.
    ctypes.cdll.LoadLibrary('mmdeploy_tensorrt_ops.dll')
    logger = trt.Logger(trt.Logger.WARNING)
    with open("end2end.engine", "rb") as f, trt.Runtime(logger) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())

    in_h, in_w = input_hw
    # Page-locked host buffers and matching device buffers.
    h_input = cuda.pagelocked_empty(trt.volume((1, 3, in_h, in_w)), dtype=np.float32)
    h_output0 = cuda.pagelocked_empty(trt.volume((1, 100, 5)), dtype=np.float32)
    h_output1 = cuda.pagelocked_empty(trt.volume((1, 100)), dtype=np.int32)
    h_output2 = cuda.pagelocked_empty(trt.volume((1, 100, 28, 28)), dtype=np.float32)
    d_input = cuda.mem_alloc(h_input.nbytes)
    d_output0 = cuda.mem_alloc(h_output0.nbytes)
    d_output1 = cuda.mem_alloc(h_output1.nbytes)
    d_output2 = cuda.mem_alloc(h_output2.nbytes)
    stream = cuda.Stream()

    image = cv2.imread('bus.jpg')
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('bus.jpg')
    image_resize = resize_keep_ratio(image, resize_shape)
    scale = (image.shape[0] / image_resize.shape[0], image.shape[1] / image_resize.shape[1])

    # BGR -> RGB, HWC -> CHW, then per-channel normalization.
    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32).reshape(3, 1, 1)
    std = np.array([58.395, 57.12, 57.375], dtype=np.float32).reshape(3, 1, 1)
    chw = image_resize[:, :, ::-1].transpose(2, 0, 1).astype(dtype=np.float32)
    chw = (chw - mean) / std

    # Zero-pad bottom/right to the exact engine input size. A tensor of any
    # other shape would be misread by the engine as a 768x1344 image.
    blob = np.zeros((1, 3, in_h, in_w), dtype=np.float32)
    blob[0, :, :chw.shape[1], :chw.shape[2]] = chw
    # Copy INTO the page-locked buffer instead of rebinding the name to a
    # pageable array.
    np.copyto(h_input, blob.ravel())

    with engine.create_execution_context() as context:
        context.set_input_shape("input", (1, 3, in_h, in_w))
        cuda.memcpy_htod_async(d_input, h_input, stream)
        context.execute_async_v2(
            bindings=[int(d_input), int(d_output0), int(d_output1), int(d_output2)],
            stream_handle=stream.handle)
        cuda.memcpy_dtoh_async(h_output0, d_output0, stream)
        cuda.memcpy_dtoh_async(h_output1, d_output1, stream)
        cuda.memcpy_dtoh_async(h_output2, d_output2, stream)
        stream.synchronize()

    h_output = [
        h_output0.reshape(1, 100, 5),
        h_output1.reshape(1, 100),
        h_output2.reshape(1, 100, 28, 28),
    ]
    boxes = filter_box(h_output, scale)
    mask_results = paste_masks(h_output[2], boxes, image.shape[0], image.shape[1])

    draw_boxes(image, boxes)
    draw_binary_masks(image, mask_results)


if __name__ == "__main__":
    main()