(beta)torch_npu.npu_rotated_box_encode
接口原型
torch_npu.npu_rotated_box_encode(anchor_box, gt_bboxes, weight) -> Tensor
功能描述
旋转标注框编码。
参数说明
- anchor_box (Tensor) - shape为(B,5,N)的3D输入张量,表示锚点框。“B”表示批处理大小数量,“N”表示标注框数量,值“5”表示“x0”、“x1”、“y0”、“y1”和“angle”。
- gt_bboxes (Tensor) - shape为(B,5,N)数据类型为float32 (float16)的3D张量。
- weight (Tensor,默认值为[1.0, 1.0, 1.0, 1.0, 1.0]) - “x0”、“x1”、“y0”、“y1”和“angle”的浮点列表。
调用示例
>>> anchor_boxes = torch.tensor([[[30.69], [32.6], [45.94], [59.88], [-44.53]]], dtype=torch.float16).to("npu") >>> gt_bboxes = torch.tensor([[[30.44], [18.72], [33.22], [45.56], [8.5]]], dtype=torch.float16).to("npu") >>> weight = torch.tensor([1., 1., 1., 1., 1.], dtype=torch.float16).npu() >>> out = torch_npu.npu_rotated_box_encode(anchor_boxes, gt_bboxes, weight) >>> out tensor([[[-0.4253], [-0.5166], [-1.7021], [-0.0162], [ 1.1328]]], device='npu:0', dtype=torch.float16)
父主题: torch_npu