将数据类型为int32的输入tensor打包为int4存放,每8个int4数据通过一个int32数据承载,并进行交叠排放。
torch_npu.npu_convert_weight_to_int4pack(Tensor weight, int inner_k_tiles=0) -> Tensor
输出为Tensor类型,代表int4打包后的输出,数据类型为INT32;当输入weight的shape为(k, n)时输出shape为(k, n/8),当输入weight的shape为(n, k)时输出shape为(n, k/8),数据格式支持ND。
# Eager-mode example: pack an int32 weight tensor into int4 (8 int4 values
# carried per int32, interleaved) via npu_convert_weight_to_int4pack, then
# run a weight-quantized matmul with npu_weight_quant_batchmatmul.
import torch
import torch_npu

m = 128
k = 64
n = 32
trans_weight = False  # True: weight stored transposed as (n, k)

cpu_x = torch.randn((m, k), dtype=torch.float16)
if trans_weight:
    # Transposed layout: weight is (n, k); per-channel scale/offset along n.
    cpu_weight = torch.randint(low=-8, high=8, size=(n, k), dtype=torch.int32)
    cpu_antiquantscale = torch.randn((n, 1), dtype=torch.float16)
    cpu_antiquantoffset = torch.randn((n, 1), dtype=torch.float16)
else:
    # Standard layout: weight is (k, n); per-channel scale/offset along n.
    cpu_weight = torch.randint(low=-8, high=8, size=(k, n), dtype=torch.int32)
    cpu_antiquantscale = torch.randn((1, n), dtype=torch.float16)
    cpu_antiquantoffset = torch.randn((1, n), dtype=torch.float16)

# Pack int32 -> int4 on the NPU; the result is already an NPU tensor.
weight_int4 = torch_npu.npu_convert_weight_to_int4pack(cpu_weight.npu())

if trans_weight:
    # Present all weight-side tensors to the matmul in (k, n)-style layout.
    cpu_weight = cpu_weight.transpose(-1, -2)
    weight_int4 = weight_int4.transpose(-1, -2)
    cpu_antiquantscale = cpu_antiquantscale.transpose(-1, -2)
    cpu_antiquantoffset = cpu_antiquantoffset.transpose(-1, -2)

# NOTE: weight_int4 already lives on the NPU, so the redundant .npu() call
# from the original sample is dropped.
npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(),
                                                 weight_int4,
                                                 cpu_antiquantscale.npu(),
                                                 cpu_antiquantoffset.npu())
# Graph-mode (torchair) example: pack an int32 weight into int4 and run
# npu_weight_quant_batchmatmul inside a torch.compile-d module using the
# NPU graph backend.
import torch
import torch_npu
import torchair
from torchair.configs.compiler_config import CompilerConfig

config = CompilerConfig()
npu_backend = torchair.get_npu_backend(compiler_config=config)

m = 16
k = 17
n = 72
trans_weight = False   # True: weight stored transposed as (n, k)
is_weight_nz = False   # True: convert the weight to FRACTAL_NZ format first

cpu_x = torch.randn((m, k), dtype=torch.float16)
if trans_weight:
    # Transposed layout: weight is (n, k); per-channel scale/offset along n.
    cpu_weight = torch.randint(low=-8, high=8, size=(n, k), dtype=torch.int32)
    cpu_antiquantscale = torch.ones((n, 1), dtype=torch.float16)
    cpu_antiquantoffset = torch.zeros((n, 1), dtype=torch.float16)
else:
    # Standard layout: weight is (k, n); per-channel scale/offset along n.
    cpu_weight = torch.randint(low=-8, high=8, size=(k, n), dtype=torch.int32)
    cpu_antiquantscale = torch.ones((1, n), dtype=torch.float16)
    cpu_antiquantoffset = torch.zeros((1, n), dtype=torch.float16)

npu_weight = cpu_weight.npu()
if is_weight_nz:
    # ND -> FRACTAL_NZ (format id 29). npu_weight is already an NPU tensor,
    # so the redundant .npu() call from the original sample is dropped.
    npu_weight = torch_npu.npu_format_cast(npu_weight, 29)
# int32 -> int4pack
weight_int4pack = torch_npu.npu_convert_weight_to_int4pack(npu_weight)


class MyModel(torch.nn.Module):
    """Thin wrapper so npu_weight_quant_batchmatmul can be graph-compiled."""

    def forward(self, x, weight, antiquant_scale, antiquant_offset,
                quant_scale, quant_offset, bias, antiquant_group_size):
        if trans_weight:
            # Restore (k, n)-style layout before the matmul.
            weight = weight.transpose(-1, -2)
            antiquant_scale = antiquant_scale.transpose(-1, -2)
            antiquant_offset = antiquant_offset.transpose(-1, -2)
        return torch_npu.npu_weight_quant_batchmatmul(
            x, weight, antiquant_scale, antiquant_offset,
            quant_scale, quant_offset, bias, antiquant_group_size)


# Move the model to the NPU, then compile it. The original sample assigned
# `model = cpu_model.npu()` and immediately overwrote it by compiling the
# other variable; the dead assignment is removed here.
model = MyModel().npu()
model = torch.compile(model, backend=npu_backend, dynamic=True, fullgraph=True)

npu_out = model(cpu_x.npu(), weight_int4pack,
                cpu_antiquantscale.npu(), cpu_antiquantoffset.npu(),
                None, None, None, 0)