下载
中文
注册

torch_npu.npu_quant_scatter_

功能描述

先将updates进行量化,然后将updates中的值按指定的轴axis和索引indices更新self中的值,self中的数据被改变。

接口原型

1
torch_npu.npu_quant_scatter_(Tensor(a!) self, Tensor indices, Tensor updates, Tensor quant_scales, Tensor? quant_zero_points=None, int axis=0, int quant_axis=1, str reduce='update') -> Tensor(a!)

参数说明

  • self:Device侧的Tensor类型,必选输入,源数据张量,数据类型支持INT8,数据格式支持ND,支持非连续的Tensor,维数只能是3~8维。
  • indices:Device侧的Tensor类型,必选输入,索引张量,数据类型支持INT32,数据格式支持ND,支持非连续的Tensor。
  • updates:Device侧的Tensor类型,必选输入,更新数据张量,数据格式支持ND,支持非连续的Tensor。
    • Atlas 推理系列产品:数据类型支持FLOAT16。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品:数据类型支持BFLOAT16、FLOAT16。
  • quant_scales:Device侧的Tensor类型,必选输入,量化缩放张量,数据格式支持ND,支持非连续的Tensor。
    • Atlas 推理系列产品:数据类型支持FLOAT32。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品:数据类型支持BFLOAT16、FLOAT32。
  • quant_zero_points:Device侧的Tensor类型,可选输入,量化偏移张量,数据类型支持BFLOAT16、INT32,数据格式支持ND,支持非连续的Tensor。
    • Atlas 推理系列产品:数据类型支持INT32。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品:数据类型支持BFLOAT16、INT32。
  • axis:Host侧的int类型,可选参数,updates上用来更新的轴,默认值为0。
  • quant_axis:Host侧的int类型,可选参数,updates上用来量化的轴,默认值为1。
  • reduce:Host侧的str类型,可选参数,表示数据操作方式;当前只支持‘update’,即更新操作。

输出说明

返回被更新后的self。

约束说明

  • 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
  • indices的维数只能是1维或者2维;如果是2维,其第2维的大小必须是2;不支持索引越界,索引越界不校验;indices映射的self数据段不能重合,若重合则会因为多核并发原因导致多次执行结果不一样。
  • updates的维数需要与self的维数一样;其第1维的大小等于indices的第1维的大小,且不大于self的第1维的大小;其axis轴的大小不大于self的axis轴的大小;其余维度的大小要跟self对应维度的大小相等;其最后一维的大小必须32B对齐。
  • quant_scales的元素个数需要等于updates在quant_axis轴的大小。
  • quant_zero_points的元素个数需要等于updates在quant_axis轴的大小。
  • axis不能为updates的第1维或最后1维。
  • quant_axis只能为updates的最后1维。

支持的型号

  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品
  • Atlas 推理系列产品

调用示例

  • 单算子模式调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    import torch
    import torch_npu
    import numpy as np
    
    data_var = np.random.uniform(0, 1, [24, 4096, 128]).astype(np.int8)
    var = torch.from_numpy(data_var).to(torch.int8).npu()
    
    data_indices = np.random.uniform(0, 1, [24]).astype(np.int32)
    indices = torch.from_numpy(data_indices).to(torch.int32).npu()
    
    data_updates = np.random.uniform(1, 2, [24, 1, 128]).astype(np.float16)
    updates = torch.from_numpy(data_updates).to(torch.bfloat16).npu()
    
    data_quant_scales = np.random.uniform(0, 1, [1, 1, 128]).astype(np.float16)
    quant_scales = torch.from_numpy(data_quant_scales).to(torch.bfloat16).npu()
    
    data_quant_zero_points = np.random.uniform(0, 1, [1, 1, 128]).astype(np.float16)
    quant_zero_points = torch.from_numpy(data_quant_zero_points).to(torch.bfloat16).npu()
    
    axis = -2
    quant_axis = -1
    reduce = "update"
    
    torch_npu.npu_quant_scatter_(var, indices, updates, quant_scales, quant_zero_points, axis=axis, quant_axis=quant_axis, reduce=reduce)
    
  • 图模式调用
      1
      2
      3
      4
      5
      6
      7
      8
      9
     10
     11
     12
     13
     14
     15
     16
     17
     18
     19
     20
     21
     22
     23
     24
     25
     26
     27
     28
     29
     30
     31
     32
     33
     34
     35
     36
     37
     38
     39
     40
     41
     42
     43
     44
     45
     46
     47
     48
     49
     50
     51
     52
     53
     54
     55
     56
     57
     58
     59
     60
     61
     62
     63
     64
     65
     66
     67
     68
     69
     70
     71
     72
     73
     74
     75
     76
     77
     78
     79
     80
     81
     82
     83
     84
     85
     86
     87
     88
     89
     90
     91
     92
     93
     94
     95
     96
     97
     98
     99
    100
    101
    102
    103
    104
    105
    106
    # 入图方式
    import torch
    import torch_npu
    import math
    import torchair as tng
    import tensorflow as tf
    import numpy as np
    from torchair.configs.compiler_config import CompilerConfig
    import torch._dynamo
    TORCHDYNAMO_VERBOSE=1
    TORCH_LOGS="+dynamo"
    
    # 支持入图的打印宏
    import logging
    from torchair.core.utils import logger
    logger.setLevel(logging.DEBUG)
    config = CompilerConfig()
    config.debug.graph_dump.type = "pbtxt"
    npu_backend = tng.get_npu_backend(compiler_config=config)
    from torch.library import Library, impl
    
    # 数据生成
    np_bf16 = tf.bfloat16.as_numpy_dtype
    dtype_list2 =["fp16","int8","int32","fp32","fp16"]
    dtype_list =[np.float16,np.int8,np.int32,np.float32]
    updates_shape =[1,11,1,32]
    var_shape =[1,11,1,32]
    indices_shape =[1,2]
    quant_scales_shape =[1,1,1,32]
    quant_zero_points_shape =[1,1,1,32]
    axis =-2
    quant_axis=-1
    reduce = "update"
    updates_data = np.random.uniform(-1,1,size=updates_shape)
    var_data = np.random.uniform(1,2,size=var_shape).astype(dtype_list[1])
    quant_scales_data = np.random.uniform(1,2,size=quant_scales_shape)
    indices_data = np.random.uniform(0,1,size=indices_shape).astype(dtype_list[2])
    quant_zero_points_data = np.random.uniform(0,1,size=quant_zero_points_shape)
    updates_npu = torch.from_numpy(updates_data).npu().to(torch.bfloat16).npu()
    var_npu = torch.from_numpy(var_data).npu()
    quant_scales_npu = torch.from_numpy(quant_scales_data).npu().to(torch.bfloat16).npu()
    quant_zero_points_npu = torch.from_numpy(quant_zero_points_data).to(torch.bfloat16).npu()
    indices_npu = torch.from_numpy(indices_data).npu()
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
        def forward(self):
            return torch_npu.npu_quant_scatter_(var_npu, indices_npu, updates_npu, quant_scales_npu, quant_zero_points_npu, axis=axis, quant_axis=quant_axis, reduce=reduce)
    def MetaInfershape():
        with torch.no_grad():
            model = Model()
            model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
            graph_output = model()
        single_op = torch_npu.npu_quant_scatter_(var_npu, indices_npu, updates_npu, quant_scales_npu, quant_zero_points_npu, axis=axis, quant_axis=quant_axis, reduce=reduce)
        print("single op output with mask:", single_op[0], single_op[0].shape)
        print("graph output with mask:", graph_output[0], graph_output[0].shape)
    if __name__ == "__main__":
        MetaInfershape()
    
    # 执行上述代码的输出类似如下
    single op output with mask: tensor([[[ 0,  0,  1,  1,  1,  0,  1,  0,  1,  1,  0,  0,  0,  1,  0,  1,  0,
               1,  1,  1,  0,  0,  0,  0,  0,  1,  1,  1,  0,  1,  1,  1]],
            [[ 0,  0,  1,  0,  1,  0,  0,  1,  0,  0,  1,  0,  0,  1,  1,  1,  0,
               1,  1,  0,  1,  1,  0,  0, -1,  0,  1,  0,  1,  0,  1,  0]],
            [[ 0,  1,  1,  1,  1,  1,  1,  1,  0,  1,  0,  0,  1,  1,  1,  1,  0,
               1,  0,  1,  0,  1,  1,  0,  0,  0,  0,  0,  1,  1,  1,  1]],
            [[ 0,  0,  1,  1,  1,  1,  1,  1,  1,  0,  1,  1,  1,  1,  0,  0,  1,
               1,  0,  1,  1,  1,  1,  1,  1,  1,  0,  0,  1,  0,  0,  1]],
            [[ 0,  0,  1,  1,  1,  0,  1,  1,  0,  0,  0,  0,  1,  1,  1,  2,  0,
               1,  1,  0,  1,  1,  1,  1, -1,  0,  0,  0,  1,  1,  1,  0]],
            [[ 0,  1,  1,  0,  1,  0,  0,  1,  0,  1,  0,  1,  1,  0,  1,  1,  0,
               1,  1,  1,  0,  0,  1,  0, -1,  0,  0,  0,  1,  1,  1,  0]],
            [[ 0, -1,  1,  1,  1,  0,  0,  1,  1,  0,  0,  1,  0,  1,  2,  1,  0,
               1,  1,  1,  1,  1,  0,  1,  1,  1,  1,  0,  0,  0,  0,  0]],
            [[ 1,  0,  0,  1,  1,  0,  1,  0,  0,  1,  0,  0,  0,  2,  0,  1,  0,
               1,  1,  1,  0,  1,  0,  0,  1,  0,  0,  0,  1,  1,  1,  1]],
            [[ 0,  0,  1,  0,  1,  1,  0,  1,  0,  1,  0,  0,  1,  2,  1,  1,  0,
               0,  0,  0,  0,  0,  0,  0,  0,  1,  0,  1,  1,  0,  1,  0]],
            [[ 0,  0,  0,  1,  1,  0,  0,  1,  0,  0,  0, -1,  0,  1,  1,  0,  1,
               1,  1,  1,  1,  1,  0,  0,  0,  1,  0,  0,  1,  1,  1,  0]],
            [[ 0,  1,  0,  0,  1,  0,  1,  0,  0,  1,  1,  0,  1,  1,  1,  1,  0,
               1,  1,  0,  1,  1,  0,  0,  0,  0,  0,  0,  1,  0,  1,  1]]],
           device='npu:0', dtype=torch.int8) torch.Size([11, 1, 32])
    graph output with mask: tensor([[[ 0,  0,  1,  1,  1,  0,  1,  0,  1,  1,  0,  0,  0,  1,  0,  1,  0,
               1,  1,  1,  0,  0,  0,  0,  0,  1,  1,  1,  0,  1,  1,  1]],
            [[ 0,  0,  1,  0,  1,  0,  0,  1,  0,  0,  1,  0,  0,  1,  1,  1,  0,
               1,  1,  0,  1,  1,  0,  0, -1,  0,  1,  0,  1,  0,  1,  0]],
            [[ 0,  1,  1,  1,  1,  1,  1,  1,  0,  1,  0,  0,  1,  1,  1,  1,  0,
               1,  0,  1,  0,  1,  1,  0,  0,  0,  0,  0,  1,  1,  1,  1]],
            [[ 0,  0,  1,  1,  1,  1,  1,  1,  1,  0,  1,  1,  1,  1,  0,  0,  1,
               1,  0,  1,  1,  1,  1,  1,  1,  1,  0,  0,  1,  0,  0,  1]],
            [[ 0,  0,  1,  1,  1,  0,  1,  1,  0,  0,  0,  0,  1,  1,  1,  2,  0,
               1,  1,  0,  1,  1,  1,  1, -1,  0,  0,  0,  1,  1,  1,  0]],
            [[ 0,  1,  1,  0,  1,  0,  0,  1,  0,  1,  0,  1,  1,  0,  1,  1,  0,
               1,  1,  1,  0,  0,  1,  0, -1,  0,  0,  0,  1,  1,  1,  0]],
            [[ 0, -1,  1,  1,  1,  0,  0,  1,  1,  0,  0,  1,  0,  1,  2,  1,  0,
               1,  1,  1,  1,  1,  0,  1,  1,  1,  1,  0,  0,  0,  0,  0]],
            [[ 1,  0,  0,  1,  1,  0,  1,  0,  0,  1,  0,  0,  0,  2,  0,  1,  0,
               1,  1,  1,  0,  1,  0,  0,  1,  0,  0,  0,  1,  1,  1,  1]],
            [[ 0,  0,  1,  0,  1,  1,  0,  1,  0,  1,  0,  0,  1,  2,  1,  1,  0,
               0,  0,  0,  0,  0,  0,  0,  0,  1,  0,  1,  1,  0,  1,  0]],
            [[ 0,  0,  0,  1,  1,  0,  0,  1,  0,  0,  0, -1,  0,  1,  1,  0,  1,
               1,  1,  1,  1,  1,  0,  0,  0,  1,  0,  0,  1,  1,  1,  0]],
            [[ 0,  1,  0,  0,  1,  0,  1,  0,  0,  1,  1,  0,  1,  1,  1,  1,  0,
               1,  1,  0,  1,  1,  0,  0,  0,  0,  0,  0,  1,  0,  1,  1]]],
           device='npu:0', dtype=torch.int8) torch.Size([11, 1, 32])