npu_grouped_matmul([torch.Tensor] x, [torch.Tensor] weight, [torch.Tensor] bias=None, [torch.Tensor] scale=None, [torch.Tensor] offset=None, [torch.Tensor] antiquant_scale=None, [torch.Tensor] antiquant_offset=None, [torch.int] group_list=None, int split_item=0, torch.dtype output_dtype=None) -> [torch.Tensor]
npu_grouped_matmul([torch.Tensor] x, [torch.Tensor] weight, [torch.Tensor] bias, [torch.Tensor] scale, [torch.Tensor] offset, [torch.Tensor] antiquant_scale, [torch.Tensor] antiquant_offset, [int] group_list=None, int split_item=0, torch.dtype output_dtype=None) -> [torch.Tensor]
The supported dtype combinations for each scenario are:

| Scenario | x | weight | bias | scale | antiquant_scale | antiquant_offset | output_dtype | y |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Non-quantized | torch.float16 | torch.float16 | torch.float16 | Not required | Not required | Not required | torch.float16 | torch.float16 |
| Non-quantized | torch.bfloat16 | torch.bfloat16 | torch.float32 | Not required | Not required | Not required | torch.bfloat16 | torch.bfloat16 |
| Quantized | torch.int8 | torch.int8 | torch.int32 | torch.int64 | Not required | Not required | torch.int8 | torch.int8 |
| Pseudo-quantized | torch.float16 | torch.int8 | torch.float16 | Not required | torch.float16 | torch.float16 | torch.float16 | torch.float16 |
| Pseudo-quantized | torch.bfloat16 | torch.int8 | torch.float32 | Not required | torch.bfloat16 | torch.bfloat16 | torch.bfloat16 | torch.bfloat16 |
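As a rough illustration of the pseudo-quantized rows above (floating-point activations with torch.int8 weights), the sketch below passes per-group antiquant_scale and antiquant_offset lists in torch.float16. The per-output-channel shape chosen for those tensors is an assumption made for illustration, not something the table specifies; verify it against the official torch_npu documentation for your version.

```python
import torch
import torch_npu

# Sketch only: pseudo-quantized case (torch.float16 x, torch.int8 weight) from the table above.
# The per-output-channel shape of antiquant_scale/antiquant_offset is an assumption.
x = [torch.randn(256, 256, device='npu', dtype=torch.float16)]
weight = [torch.randint(-128, 128, (256, 1024), device='npu', dtype=torch.int8)]
bias = [torch.randn(1024, device='npu', dtype=torch.float16)]
antiquant_scale = [torch.randn(1024, device='npu', dtype=torch.float16)]   # assumed shape (N,)
antiquant_offset = [torch.randn(1024, device='npu', dtype=torch.float16)]  # assumed shape (N,)

npu_out = torch_npu.npu_grouped_matmul(
    x, weight, bias=bias,
    antiquant_scale=antiquant_scale, antiquant_offset=antiquant_offset,
    group_list=None, split_item=0)
```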
The split_item argument, together with whether x is passed as one tensor or several, selects one of the following input/output splitting scenarios:

| Scenario | Description | Restrictions |
| --- | --- | --- |
| Multi-multi-multi | x and weight are lists of multiple tensors and y is a list of multiple tensors; the tensors of each group are independent. | |
| Single-multi-single | x is a single tensor, weight is a list of multiple tensors, and y is a single tensor. | |
| Single-multi-multi | x is a single tensor, weight is a list of multiple tensors, and y is a list of multiple tensors. | |
| Multi-multi-single | x and weight are lists of multiple tensors and y is a single tensor; the result of each group's matrix multiplication is stored contiguously in the same output tensor. | |
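The examples that follow all exercise the multi-multi-multi case with split_item=0. For the single-tensor-output scenarios, x is one stacked tensor that group_list splits along dim 0. The sketch below is not taken from the table above: it assumes group_list holds cumulative row counts and that split_item=3 selects single-tensor output, so check both against the official torch_npu documentation for your version.

```python
import torch
import torch_npu

# Sketch only: single x / multiple weight / single y.
# Assumptions (verify against the official docs): group_list holds cumulative
# row counts of x per group, and split_item=3 selects single-tensor output.
x = [torch.randn(1792, 256, device='npu', dtype=torch.float16)]   # 256 + 1024 + 512 rows stacked
weight = [torch.randn(256, 1024, device='npu', dtype=torch.float16) for _ in range(3)]
group_list = [256, 1280, 1792]   # assumed cumulative split points along dim 0 of x

npu_out = torch_npu.npu_grouped_matmul(x, weight, group_list=group_list, split_item=3)
# Under the assumptions above, npu_out holds one (1792, 1024) float16 tensor.
```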
```python
import torch
import torch_npu

# Eager mode: non-quantized case, multiple x / multiple weight / multiple y (split_item=0)
x1 = torch.randn(256, 256, device='npu', dtype=torch.float16)
x2 = torch.randn(1024, 256, device='npu', dtype=torch.float16)
x3 = torch.randn(512, 1024, device='npu', dtype=torch.float16)
x = [x1, x2, x3]

weight1 = torch.randn(256, 256, device='npu', dtype=torch.float16)
weight2 = torch.randn(256, 1024, device='npu', dtype=torch.float16)
weight3 = torch.randn(1024, 128, device='npu', dtype=torch.float16)
weight = [weight1, weight2, weight3]

bias1 = torch.randn(256, device='npu', dtype=torch.float16)
bias2 = torch.randn(1024, device='npu', dtype=torch.float16)
bias3 = torch.randn(128, device='npu', dtype=torch.float16)
bias = [bias1, bias2, bias3]

group_list = None
split_item = 0
npu_out = torch_npu.npu_grouped_matmul(x, weight, bias=bias, group_list=group_list, split_item=split_item)
```
```python
import torch
import torch_npu

# Same eager-mode call, explicitly passing empty lists for the unused quantization arguments
x1 = torch.randn(256, 256, device='npu', dtype=torch.float16)
x2 = torch.randn(1024, 256, device='npu', dtype=torch.float16)
x3 = torch.randn(512, 1024, device='npu', dtype=torch.float16)
x = [x1, x2, x3]

weight1 = torch.randn(256, 256, device='npu', dtype=torch.float16)
weight2 = torch.randn(256, 1024, device='npu', dtype=torch.float16)
weight3 = torch.randn(1024, 128, device='npu', dtype=torch.float16)
weight = [weight1, weight2, weight3]

bias1 = torch.randn(256, device='npu', dtype=torch.float16)
bias2 = torch.randn(1024, device='npu', dtype=torch.float16)
bias3 = torch.randn(128, device='npu', dtype=torch.float16)
bias = [bias1, bias2, bias3]

group_list = None
split_item = 0
npu_out = torch_npu.npu_grouped_matmul(x, weight, bias=bias, scale=[], offset=[], antiquant_scale=[], antiquant_offset=[], group_list=group_list, split_item=split_item)
```
```python
import torch
import torch.nn as nn
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig

# Graph mode: compile the call through the torchair NPU backend
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)

class GMMModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, weight):
        return torch_npu.npu_grouped_matmul(x, weight)

def main():
    x1 = torch.randn(256, 256, device='npu', dtype=torch.float16)
    x2 = torch.randn(1024, 256, device='npu', dtype=torch.float16)
    x3 = torch.randn(512, 1024, device='npu', dtype=torch.float16)
    x = [x1, x2, x3]

    weight1 = torch.randn(256, 256, device='npu', dtype=torch.float16)
    weight2 = torch.randn(256, 1024, device='npu', dtype=torch.float16)
    weight3 = torch.randn(1024, 128, device='npu', dtype=torch.float16)
    weight = [weight1, weight2, weight3]

    model = GMMModel().npu()
    model = torch.compile(model, backend=npu_backend, dynamic=False)
    custom_output = model(x, weight)

if __name__ == '__main__':
    main()
```