下载
中文
注册

LinearQAT

功能说明

构造Linear的QAT算子。

函数原型

直接构造接口:

amct_pytorch.nn.module.quantization.linear.LinearQAT(in_features, out_features, bias, device, dtype, config)

基于原生算子构造接口:

amct_pytorch.nn.module.quantization.linear.LinearQAT.from_float(mod, config)

参数说明

表1 直接构造接口参数说明

参数名

输入/输出

含义

使用限制

in_features

输入

输入特征数

数据类型:int

必填

out_features

输入

输出特征数

数据类型:int

必填

bias

输入

是否开启偏值项参与学习

数据类型:bool

默认值为True

device

输入

运行设备

默认值:None

dtype

输入

torch数值类型

torch数据类型, 仅支持torch.float32

config

输入

量化配置,配置参考样例如下,量化配置参数的具体说明请参见量化配置参数说明

config = {
    "retrain_enable":true,
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": -1.0,
        "clip_max": 1.0
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": False
    }
}

数据类型:dict

默认值:None

表2 基于原生算子构造接口

参数名

输入/输出

含义

使用限制

mod

输入

待量化的原生Linear算子

数据类型:torch.nn.Module

config

输入

量化配置。

配置参考样例如下,量化配置参数的具体说明请参见量化配置参数说明

config = {
    "retrain_enable":true,
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": -1.0,
        "clip_max": 1.0
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": False
    }
}

数据类型:dict

默认值:None

返回值说明

生成一个Linear对应QAT算子,用于后续量化感知训练。

调用示例

直接构造:

1
2
3
4
from amct_pytorch.nn.module.quantization.linear import LinearQAT

LinearQAT(in_features=1, out_features=1, bias=True,
          device=None, dtype=None, config=None)

基于原生算子构造:

1
2
3
4
5
6
import torch

from amct_pytorch.nn.module.quantization.linear import LinearQAT

linear_op = torch.nn.Linear(in_features=1, out_features=1, bias=True, device=None, dtype=None)
LinearQAT.from_float(mod=linear_op, config=None)