(beta)torch_npu.npu_bmmV2
接口原型
torch_npu.npu_bmmV2(self, mat2, output_sizes) -> Tensor
功能描述
将矩阵“a”乘以矩阵“b”,生成“a*b”。支持FakeTensor模式。
参数说明
- self (Tensor) - 2D或更高维度矩阵张量。数据类型:float16、float32、int32。格式:[ND, NHWC, FRACTAL_NZ]。
- mat2 (Tensor) - 2D或更高维度矩阵张量。数据类型:float16、float32、int32。格式:[ND, NHWC, FRACTAL_NZ]。
- output_sizes (ListInt,默认值为[]) - 输出的shape,用于matmul的反向传播。
支持的型号
Atlas 训练系列产品 Atlas A2 训练系列产品 Atlas A3 训练系列产品 Atlas 推理系列产品
调用示例
示例一:
1 2 3 4 5 | >>> mat1 = torch.randn(10, 3, 4).npu() >>> mat2 = torch.randn(10, 4, 5).npu() >>> res = torch_npu.npu_bmmV2(mat1, mat2, []) >>> res.shape torch.Size([10, 3, 5]) |
示例二:
1 2 3 4 5 6 7 8 9 | //FakeTensor模式 >>> from torch._subclasses.fake_tensor import FakeTensorMode >>> with FakeTensorMode(): ... mat1 = torch.randn(10, 3, 4).npu() ... mat2 = torch.randn(10, 4, 5).npu() ... result = torch_npu.npu_bmmV2(mat1, mat2, []) ... >>> result FakeTensor(..., device='npu:0', size=(10, 3, 5)) |
父主题: torch_npu