def npu_multiclass_nms()

API接口

def npu_multiclass_nms(multi_bboxes,multi_scores, score_thr=0.05,nms_thr=0.45,max_num=50,score_factors=None):

功能描述

使用NPU API的多类bbox NMS。

参数说明

输出说明

Tuple - (bboxes, labels),shape为(k, 5)和(k, 1)的张量。标签以0为基础。

约束说明

在动态shape条件下,由于NPU op的限制,最多支持20个类别(nmsed_classes)和10000个框(nmsed_boxes)。

示例

调用方式示例:
from torch_npu.contrib.function import npu_multiclass_nms
det_bboxes, det_labels = npu_multiclass_nms(boxes, scores, score_thr=0.3, nms_thr=0.5, max_num=3)
使用示例:
        >>> boxes = torch.randint(1, 255, size=(1000, 4))
        >>> scores = torch.randn(1000, 81)
        >>> boxes = boxes.npu().half()
        >>> scores = scores.npu().half()
        >>> det_bboxes, det_labels = npu_multiclass_nms(boxes, scores, score_thr=0.3, nms_thr=0.5, max_num=3)
        >>> expedt_det_bboxes = torch.tensor([[ 57.0000, 198.8750, 45.9688, 221.8750, 4.1484],[215.0000, 155.0000, 236.8750, 137.0000,3.9023], [208.8750, 221.0000, 228.0000,  17.0000,   3.8867]],dtype=torch.float16)