torch_npu.npu_dynamic_quant_asymmetric
功能描述
- 算子功能:对输入的张量进行per-token非对称动态量化。其中输入的最后一个维度对应一个token,每个token作为一组进行量化。
- 计算公式:假设待量化张量为x,
- rowMax、rowMin代表按行取最大值、按行取最小值,此处的“行”对应x最后一个维度的数据,即一个token。
- DST_MAX、DST_MIN分别对应量化后的最大值和最小值,在进行INT8量化时,二者分别对应+127、-128,进行INT4量化时,分别对应+7、-8
- 若使用smooth quant算法,会引入smooth_scales输入,其形状与x最后一个维度大小一致,在进行量化前,会先令x乘以smooth_scales,再按上述公式进行量化
- 若使用smooth quant算法,且在MOE(混合专家模型)场景下,会引入smooth_scales输入和group_index输入,此时smooth_scales中包含多组smooth向量,按group_index中的数值作用到x的不同行上。具体的,假如x包含m个token,smooth_scales有n行,smooth_scales[0]会作用到x[0:group_index[0]]上,smooth_scales[i]会作用到x[group_index[i-1]: group_index[i]]上,i=1,2,...,n-1。
接口原型
torch_npu.npu_dynamic_quant_asymmetric(Tensor x, *, Tensor? smooth_scales=None, Tensor? group_index=None, ScalarType? dst_type=None) -> (Tensor, Tensor, Tensor)
参数说明
- x:Device侧的Tensor类型,需要进行量化的源数据张量,必选输入,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND,支持非连续的Tensor。输入x的维度必须大于1。进行INT4量化时,要求x形状的最后一维是8的整数倍。
- smooth_scales:Device侧的Tensor类型,对x进行平滑缩放的张量,可选输入,数据类型需要与x保持一致,数据格式支持ND,支持非连续的Tensor。
- group_index:Device侧的Tensor类型,在MOE场景下,对smooth_scales进行分组的下标,可选输入,数据类型支持INT32,数据格式支持ND,支持非连续的Tensor。
- dst_type:ScalarType类型,用于选择进行INT8/INT4量化,可选输入,输入值只能是torch.int8和torch.quint4x2,默认为INT8量化。
输出说明
该接口包含三个输出,依次是y、scale、offset,含义如下
- y:量化后的输出Tensor,在进行INT8量化时,y的数据类型为INT8,形状与x一致;在进行INT4量化时,y的数据类型为INT32,形状最后一维为x最后一维除以8,其余维度与x一致,每个INT32元素包含8个INT4结果。
- scale:非对称动态量化过程中计算出的缩放系数Tensor,数据类型为FLOAT32,形状为x的形状剔除最后一维。
- offset:非对称动态量化过程中计算出的偏移系数Tensor,数据类型为FLOAT32,形状为x的形状剔除最后一维。
约束说明
- 该接口仅在推理场景下使用。
- 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
- 使用可选输入smooth_scales、group_index、dst_type时,必须使用关键字传参。
- 使用smooth_scales时:
- 若不使用group_index,smooth_scales必须是一维Tensor,元素数量与x的最后一维大小一致。
- 若使用group_index,smooth_scales必须是二维Tensor,第二维元素数量与x的最后一维大小一致,group_index必须是一维数组,元素数量与smooth_scales第一维一致。group_index中的元素必须是单调递增的,其最后一个元素的值,应等于x的元素数量除以x的最后一个维度。
支持的型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品 Atlas A3 训练系列产品
调用示例
- 单算子模式调用
- 只有一个输入x,进行INT8量化
1 2 3 4 5
import torch import torch_npu x = torch.rand((3, 8), dtype=torch.half).npu() y, scale, offset = torch_npu.npu_dynamic_quant_asymmetric(x) print(y, scale, offset)
- 只有一个输入x,进行INT4量化
1 2 3 4 5
import torch import torch_npu x = torch.rand((3, 8), dtype=torch.half).npu() y, scale, offset = torch_npu.npu_dynamic_quant_asymmetric(x, dst_type=torch.quint4x2) print(y, scale, offset)
- 使用smooth_scales输入,非MOE场景(不使用group_index),进行INT8量化
1 2 3 4 5 6
import torch import torch_npu x = torch.rand((3, 8), dtype=torch.half).npu() smooth_scales = torch.rand((8,), dtype=torch.half).npu() y, scale, offset = torch_npu.npu_dynamic_quant_asymmetric(x, smooth_scales=smooth_scales) print(y, scale, offset)
- 使用smooth_scales输入,MOE场景(使用group_index),进行INT8量化
1 2 3 4 5 6 7
import torch import torch_npu x = torch.rand((3, 8), dtype=torch.half).npu() smooth_scales = torch.rand((2, 8), dtype=torch.half).npu() group_index = torch.Tensor([1, 3]).to(torch.int32).npu() y, scale, offset = torch_npu.npu_dynamic_quant_asymmetric(x, smooth_scales=smooth_scales, group_index=group_index) print(y, scale, offset)
- 只有一个输入x,进行INT8量化
- 图模式调用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
import torch import torch_npu import torchair as tng from torchair.configs.compiler_config import CompilerConfig torch_npu.npu.set_compile_mode(jit_compile=True) config = CompilerConfig() npu_backend = tng.get_npu_backend(compiler_config=config) device=torch.device(f'npu:4') torch_npu.npu.set_device(device) class DynamicQuantModel(torch.nn.Module): def __init__(self): super().__init__() def forward(self, input_tensor, smooth_scales=None, group_index=None, dst_type=None): out, scale, offset = torch_npu.npu_dynamic_quant_asymmetric(input_tensor, smooth_scales=smooth_scales, group_index=group_index, dst_type=dst_type) return out, scale, offset x = torch.randn((2, 4, 6),device='npu',dtype=torch.float16).npu() smooth_scales = torch.randn((6),device='npu',dtype=torch.float16).npu() dynamic_quant_model = DynamicQuantModel().npu() dynamic_quant_model = torch.compile(dynamic_quant_model, backend=npu_backend, dynamic=True) out, scale, offset = dynamic_quant_model(x, smooth_scales=smooth_scales) print(out) print(scale) print(offset)
父主题: torch_npu