文档
注册

Conv2dQAT

功能说明

构造Conv2d的QAT算子。

函数原型

直接构造接口:

amct_pytorch.nn.module.quantization.conv2d.Conv2dQAT(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype, config)

基于原生算子构造接口:

amct_pytorch.nn.module.quantization.conv2d.Conv2dQAT.from_float(mod, config)

参数说明

表1 直接构造接口参数说明

参数名

输入/输出

含义

使用限制

in_channels

输入

输入channel个数

数据类型:int

必填

out_channels

输入

输出channel个数

数据类型:int

必填

kernel_size

输入

卷积核大小

数据类型:int/tuple

必填

stride

输入

卷积步长

数据类型:int/tuple

默认值为1

padding

输入

填充大小

数据类型:int/tuple

默认值为0

dilation

输入

kernel元素之间的间距

数据类型:int/tuple

默认值为1

groups

输入

输入和输出的连接关系

数据类型:int

默认值为1

bias

输入

是否开启偏值项参与学习

数据类型:bool

默认值为True

padding_mode

输入

填充方式

仅支持zeros

device

输入

运行设备

默认值:None

dtype

输入

torch数值类型

torch数据类型, 仅支持torch.float32

config

输入

量化配置。

配置参考样例如下,量化配置参数的具体说明请参见量化配置参数说明

config = {
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": -1.0,
        "clip_max": 1.0
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": False
    }
}

数据类型:dict

默认值:None

表2 基于原生算子构造接口

参数名

输入/输出

含义

使用限制

mod

输入

待量化的原生Conv2d算子

数据类型:torch.nn.Module

config

输入

量化配置。

配置参考样例如下,量化配置参数的具体说明请参见量化配置参数说明

config = {
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": -1.0,
        "clip_max": 1.0
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": False
    }
}

数据类型:dict

默认值:None

返回值说明

生成一个Conv2d对应QAT算子,用于后续量化感知训练。

调用示例

直接构造:

from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT

Conv2dQAT(in_channels=1, out_channels=1, kernel_size=1, stride=1,
          padding=0, dilation=1, group=1, bias=True,
          padding_mode='zeros', device=None, dtype=None, config=None)

基于原生算子构造:

import torch

from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT

conv2d_op = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1, stride=1,
                            padding=0, dilation=1, group=1, bias=True,
                            padding_mode='zeros', device=None, dtype=None)
Conv2dQAT.from_float(mod=conv2d_op, config=None)
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词