torch_npu.npu_grouped_matmul
功能描述
- 算子功能:npu_grouped_matmul是一种对多个矩阵乘法(matmul)操作进行分组计算的高效方法。该API实现了对多个矩阵乘法操作的批量处理,通过将具有相同形状或相似形状的矩阵乘法操作组合在一起,减少内存访问开销和计算资源的浪费,从而提高计算效率。
接口原型
- PyTorch 2.1.0及以上版本:
npu_grouped_matmul([torch.Tensor] x, [torch.Tensor] weight, [torch.Tensor] bias=None, [torch.Tensor] scale=None, ([torch.Tensor] offset=None, ([torch.Tensor] antiquant_scale=None, [Tensor] antiquant_offset=None, [torch.int] group_list=None, int split_item=0, torch.dtype output_dtype=None) -> [torch.Tensor]
- PyTorch 1.11.0与2.0.1版本:
npu_grouped_matmul([torch.Tensor] x, [torch.Tensor] weight, [torch.Tensor] bias, [torch.Tensor] scale, [torch.Tensor] offset, [torch.Tensor] antiquant_scale, [torch.Tensor] antiquant_offset, [int] group_list=None, int split_item=0, torch.dtype output_dtype=None) -> [torch.Tensor]
参数说明
- x (List[torch.Tensor]):输入矩阵列表,表示矩阵乘法中的左矩阵。数据类型支持torch.float16、torch.bfloat16和torch.int8。
- 列表最大长度为128。
- 当split_item=0时,张量支持2至6维输入;其他情况下,张量仅支持2维输入。
- weight (List[torch.Tensor]):权重矩阵列表,表示矩阵乘法中的右矩阵。数据类型支持torch.float16、torch.bfloat16和torch.int8。
- 列表最大长度为128。
- 每个张量仅支持2维输入。
- bias (List[torch.Tensor]):每个分组的矩阵乘法输出的独立偏置项。数据类型支持torch.float16、torch.float32和torch.int32。
- 列表最大长度为128。
- 每个张量仅支持1维输入。
- scale (List[torch.Tensor]):用于缩放原数值以匹配量化后的范围值,代表量化参数中的缩放因子,对应公式(2)。数据类型支持torch.int64。
- 列表长度与weight列表长度相同。
- 每个张量仅支持1维输入。
- offset (List[torch.Tensor]):用于调整量化后的数值偏移量,从而更准确地表示原始浮点数值,对应公式(2)。当前仅支持传入None。
- antiquant_scale (List[torch.Tensor]):用于缩放原数值以匹配伪量化后的范围值,代表伪量化参数中的缩放因子,对应公式(3)。数据类型支持torch.float16、torch.bfloat16 。
- 列表长度与weight列表长度相同。
- 每个张量仅支持1维输入。
- antiquant_offset (List[torch.Tensor]):用于调整伪量化后的数值偏移量,从而更准确地表示原始浮点数值,对应公式(3)。数据类型支持torch.float16、torch.bfloat16。
- 列表长度与weight列表长度相同。
- 每个张量仅支持1维输入。
- group_list (List[int]):用于指定分组的索引,表示x的第0维矩阵乘法的索引情况。数据类型支持torch.int64。
- 输入支持1维,最大长度为128。
- 配置值必须为非负递增数列,且列表长度不能为1。
- split_item (int):用于指定切分模式。数据类型支持torch.int32。
- 0/1:输出为多个张量,数量与weight相同。
- 2/3:输出为单个张量。
- output_dtype (torch.dtype):输出数据类型。支持的配置包括:
- None:默认值,表示输出数据类型与输入x的数据类型相同。
- 与输入x数据类型一致的类型。
返回值
List[torch.Tensor]:
- 当split_item为0或1时,返回的张量数量与weight相同。
- 当split_item为2或3时,返回的张量数量为1。
约束说明
- 该API仅在推理场景下使用;
- 该API仅在PyTorch 2.1.0版本下支持图模式;
- 各场景输入与输出数据类型使用约束:
表1 数据类型约束 场景
x
weight
bias
scale
antiquant_scale
antiquant_offset
output_dtype
y
非量化
torch.float16
torch.float16
torch.float16
无需赋值
无需赋值
无需赋值
torch.float16
torch.float16
torch.bfloat16
torch.bfloat16
torch.float32
无需赋值
无需赋值
无需赋值
torch.bfloat16
torch.bfloat16
量化
torch.int8
torch.int8
torch.int32
torch.int64
无需赋值
无需赋值
torch.int8
torch.int8
伪量化
torch.float16
torch.int8
torch.float16
无需赋值
torch.float16
torch.float16
torch.float16
torch.float16
torch.bfloat16
torch.int8
torch.float32
无需赋值
torch.bfloat16
torch.bfloat16
torch.bfloat16
torch.bfloat16
- 根据输入x、权重weight和输出y的张量数量不同,支持以下四种场景。场景中的“单”表示单个张量,“多”表示多个张量。场景顺序为x、weight、y,例如“单多单”表示x为单张量,weight为多张量,y为单张量。
表2 各场景使用限制 场景
说明
场景限制
多多多
x和weight为多张量,y为多张量。每组数据的张量是独立的。
- split_item仅支持0/1。
- x中张量支持2-6维,weight中张量需为2维,y中张量维度与x保持一致。
- x 中张量大于2维时,group_list必须为空。
- x为2维且传递group_list时,group_list的差值需与x中张量的第一维一一对应。
单多单
x为单张量,weight为多张量,y为单张量。
- split_item仅支持2/3。
- 必须传递group_list,且最后一个值需与x中张量的第一维相等。
- x、weight和y中张量需为2维。
- weight中每个张量的N轴必须相等。
单多多
x为单张量,weight为多张量,y为多张量。
- split_item仅支持0/1。
- 必须传递group_list,group_list的差值需与y中张量的第一维一一对应。
- x、weight和y中张量需为2维。
多多单
x和weight为多张量,y为单张量。每组矩阵乘法的结果连续存放在同一个张量中。
- split_item仅支持2/3。
- x、weight和y中张量需为2维。
- weight中每个张量的N轴必须相等。
- 若传递group_list,其差值需与x中张量的第一维一一对应。
支持的型号
调用示例
- 单算子模式调用
- PyTorch 2.1.0及以上的版本
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
import torch import torch_npu x1 = torch.randn(256, 256, device='npu', dtype=torch.float16) x2 = torch.randn(1024, 256, device='npu', dtype=torch.float16) x3 = torch.randn(512, 1024, device='npu', dtype=torch.float16) x = [x1, x2, x3] weight1 = torch.randn(256, 256, device='npu', dtype=torch.float16) weight2 = torch.randn(256, 1024, device='npu', dtype=torch.float16) weight3 = torch.randn(1024, 128, device='npu', dtype=torch.float16) weight = [weight1, weight2, weight3] bias1 = torch.randn(256, device='npu', dtype=torch.float16) bias2 = torch.randn(1024, device='npu', dtype=torch.float16) bias3 = torch.randn(128, device='npu', dtype=torch.float16) bias = [bias1, bias2, bias3] group_list = None split_item = 0 npu_out = torch_npu.npu_grouped_matmul(x, weight, bias=bias, group_list=group_list, split_item=split_item)
- PyTorch 1.11.0与PyTorch 2.0.1版本
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
import torch import torch_npu x1 = torch.randn(256, 256, device='npu', dtype=torch.float16) x2 = torch.randn(1024, 256, device='npu', dtype=torch.float16) x3 = torch.randn(512, 1024, device='npu', dtype=torch.float16) x = [x1, x2, x3] weight1 = torch.randn(256, 256, device='npu', dtype=torch.float16) weight2 = torch.randn(256, 1024, device='npu', dtype=torch.float16) weight3 = torch.randn(1024, 128, device='npu', dtype=torch.float16) weight = [weight1, weight2, weight3] bias1 = torch.randn(256, device='npu', dtype=torch.float16) bias2 = torch.randn(1024, device='npu', dtype=torch.float16) bias3 = torch.randn(128, device='npu', dtype=torch.float16) bias = [bias1, bias2, bias3] group_list = None split_item = 0 npu_out = torch_npu.npu_grouped_matmul(x, weight, bias=bias, scale=[], offset=[], antiquant_scale=[], antiquant_offset=[], group_list=group_list, split_item=split_item)
- PyTorch 2.1.0及以上的版本
- 图模式调用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
import torch import torch.nn as nn import torch_npu import torchair as tng from torchair.configs.compiler_config import CompilerConfig config = CompilerConfig() npu_backend = tng.get_npu_backend(compiler_config=config) class GMMModel(nn.Module): def __init__(self): super().__init__() def forward(self, x, weight): return torch_npu.npu_grouped_matmul(x, weight) def main(): x1 = torch.randn(256, 256, device='npu', dtype=torch.float16) x2 = torch.randn(1024, 256, device='npu', dtype=torch.float16) x3 = torch.randn(512, 1024, device='npu', dtype=torch.float16) x = [x1, x2, x3] weight1 = torch.randn(256, 256, device='npu', dtype=torch.float16) weight2 = torch.randn(256, 1024, device='npu', dtype=torch.float16) weight3 = torch.randn(1024, 128, device='npu', dtype=torch.float16) weight = [weight1, weight2, weight3] model = GMMModel().npu() model = torch.compile(model, backend=npu_backend, dynamic=False) custom_output = model(x, weight) if __name__ == '__main__': main()
父主题: torch_npu