Conv2dQAT
功能说明
构造Conv2d的QAT算子。
函数原型
直接构造接口:
amct_pytorch.nn.module.quantization.conv2d.Conv2dQAT(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype, config)
基于原生算子构造接口:
amct_pytorch.nn.module.quantization.conv2d.Conv2dQAT.from_float(mod, config)
参数说明
参数名 |
输入/输出 |
含义 |
使用限制 |
---|---|---|---|
in_channels |
输入 |
输入channel个数 |
数据类型:int 必填 |
out_channels |
输入 |
输出channel个数 |
数据类型:int 必填 |
kernel_size |
输入 |
卷积核大小 |
数据类型:int/tuple 必填 |
stride |
输入 |
卷积步长 |
数据类型:int/tuple 默认值为1 |
padding |
输入 |
填充大小 |
数据类型:int/tuple 默认值为0 |
dilation |
输入 |
kernel元素之间的间距 |
数据类型:int/tuple 默认值为1 |
groups |
输入 |
输入和输出的连接关系 |
数据类型:int 默认值为1 |
bias |
输入 |
是否开启偏值项参与学习 |
数据类型:bool 默认值为True |
padding_mode |
输入 |
填充方式 |
仅支持zeros |
device |
输入 |
运行设备 |
默认值:None |
dtype |
输入 |
torch数值类型 |
torch数据类型, 仅支持torch.float32 |
config |
输入 |
量化配置。 配置参考样例如下,量化配置参数的具体说明请参见量化配置参数说明。 config = { "retrain_data_config": { "dst_type": "INT8", "batch_num": 10, "fixed_min": False, "clip_min": -1.0, "clip_max": 1.0 }, "retrain_weight_config": { "dst_type": "INT8", "weights_retrain_algo": "arq_retrain", "channel_wise": False } } |
数据类型:dict 默认值:None |
参数名 |
输入/输出 |
含义 |
使用限制 |
---|---|---|---|
mod |
输入 |
待量化的原生Conv2d算子 |
数据类型:torch.nn.Module |
config |
输入 |
量化配置。 配置参考样例如下,量化配置参数的具体说明请参见量化配置参数说明。 config = { "retrain_data_config": { "dst_type": "INT8", "batch_num": 10, "fixed_min": False, "clip_min": -1.0, "clip_max": 1.0 }, "retrain_weight_config": { "dst_type": "INT8", "weights_retrain_algo": "arq_retrain", "channel_wise": False } } |
数据类型:dict 默认值:None |
返回值说明
生成一个Conv2d对应QAT算子,用于后续量化感知训练。
调用示例
直接构造:
1 2 3 4 5 |
from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT Conv2dQAT(in_channels=1, out_channels=1, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None, config=None) |
基于原生算子构造:
1 2 3 4 5 6 7 8 |
import torch from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT conv2d_op = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None) Conv2dQAT.from_float(mod=conv2d_op, config=None) |