(beta)torch_npu.contrib.function.npu_single_level_responsible_flags
接口原型
torch_npu.contrib.function.npu_single_level_responsible_flags(featmap_size,gt_bboxes,stride,num_base_anchors):
功能描述
使用NPU OP在单个特征图中生成锚点的responsible flags。
参数说明
- featmap_size (Tuple(Int)) - 模型总维度。
- gt_bboxes (Tensor) - 并行attention heads。
- stride (Tuple(Int),默认值为None) - key的特性总数。
- num_base_anchors (Int,默认值为None) - values的特性总数。
输出说明
torch.Tensor - 单层特征图中每个锚点的有效标志。输出大小为[featmap_size[0] * featmap_size[1] * num_base_anchors]。
支持的型号
- Atlas 训练系列产品
- Atlas A2 训练系列产品
- Atlas 推理系列产品
调用示例
>>> from torch_npu.contrib.function import npu_single_level_responsible_flags >>> featmap_sizes = [[10, 10], [20, 20], [40, 40]] >>> stride = [[32, 32], [16, 16], [8, 8]] >>> gt_bboxes = torch.randint(0, 512, size=(128, 4)) >>> num_base_anchors = 3 >>> featmap_level = len(featmap_sizes) >>> for i in range(featmap_level): ... gt_bboxes = gt_bboxes.npu() >>> out = npu_single_level_responsible_flags(featmap_sizes[i],gt_bboxes,stride[i],num_base_anchors) >>> print(out.shape, out.max(), out.min())
父主题: torch_npu.contrib