下载
中文
注册

量化感知训练场景下,出现数据类型不一致的解决方法

问题描述

混合精度训练场景下,用户做量化感知训练,调用create_quant_retrain_model接口之后,脚本提示如下信息:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
case_info: ['AMCT7_MINDSPORE_CREATE_QUANT_CONFIG_037', 'cascade_conv', '1,10,32,32', 10, 10, '3,3', '1,1', 'same', '0,0,0,0', '1,1', 1, 1, 0, '-1,1,-1,1,-1,1', 'uniform,uniform,uniform']
in_channels: 10 out_channels: 10 kernel_size: [3, 3] stride: (1, 1) pad_mode: same padding: (0, 0, 0, 0) dilation: (1, 1) group: 1 has_bias: True uniform uniform 0 [-1.0, 1.0, -1.0, 1.0] [1, 10, 32, 32] cascade_conv
(-1.0, 1.0)
(-1.0, 1.0)
(-1.0, 1.0)
(-1.0, 1.0)
(-1.0, 1.0)
(-1.0, 1.0)
(-1.0, 1.0)
(-1.0, 1.0)
case_info: ['AMCT7_MINDSPORE_CREATE_QUANT_CONFIG_037', 'cascade_conv', '1,10,32,32', 10, 10, '3,3', '1,1', 'same', '0,0,0,0', '1,1', 1, 1, 0, '-1,1,-1,1,-1,1', 'uniform,uniform,uniform']
in_channels: 10 out_channels: 10 kernel_size: [3, 3] stride: (1, 1) pad_mode: same padding: (0, 0, 0, 0) dilation: (1, 1) group: 1 has_bias: True uniform uniform 0 [-1.0, 1.0, -1.0, 1.0] [1, 10, 32, 32] cascade_conv
(-1.0, 1.0)
(-1.0, 1.0)
(-1.0, 1.0)
(-1.0, 1.0)
(-1.0, 1.0)
(-1.0, 1.0)
For 'Conv2D', type of `w` should be same as `x`, but `x` with type Tensor[Float16] and `w` with type Ref[Tensor(F32)].

可能原因

调用create_quant_retrain_model接口之后,用户原始网络中的nn.Dense会被替换为DenseQatBlock对象,nn.Conv2D被替换为Conv2dQatBlock对象。在调用量化感知训练脚本时,需要先导入上述两个对象。

解决方法

修改用户量化感知训练的脚本,先导入上述两个对象。

如果用户原始脚本为:

1
2
3
4
5
6
def switch_precision(net, data_type, config):
    if config.platform == "Ascend":
        net.to_float(data_type)
        for _, cell in net.cells_and_names():
            if isinstance(cell, (nn.Dense, )):
                cell.to_float(mstype.float32)

需要修改为:

1
2
3
4
5
6
7
8
from amct_mindspore.cells.dense_qat_cell import DenseQatBlock
from amct_mindspore.cells.conv2d_qat_cell import Conv2dQatBlock
def switch_precision(net, data_type, config):
    if config.platform == "Ascend":
        net.to_float(data_type)
        for _, cell in net.cells_and_names():
            if isinstance(cell, (nn.Dense, DenseQatBlock)):
                cell.to_float(mstype.float32)