数据类型转换
方法介绍
由于float16数据类型本身存在精度限制,详情请参见Link,所以在进行算子实现时,可将入参中float16类型的数据转换成float32类型进行计算,用高精度数据进行中间计算来提升精度,从而提升最终结果的精度。当中间计算过程较为复杂时,此方法精度提升效果比较明显。
示例:实现res=x * y。
在算子计算之前,若输入数据x、y为fp16的数据类型,则将其转换为fp32类型,计算完成后,将结果数据res转换成原始数据类型fp16,算子代码示例如下所示:
dtype = data_x.dtype if dtype == "float16": data_x = tbe.dsl.cast_to(data_x, "float32") data_y = tbe.dsl.cast_to(data_y, "float32") res = tbe.dsl.vmul(data_x, data_y) if dtype == "float16": res = tbe.dsl.cast_to(res, "float16")
同时,在进行精度验证的时候,对比数据也要转换成fp32计算。
注意事项
进行数据类型的转换需要注意以下两点:
- 在最后输出结果时要将数据类型转换成原数据类型。
- float16转成float32计算会导致运行性能降低,因此如果使用float16的数据类型进行计算的精度在可允许范围内,尽量不要转换数据类型。
父主题: 精度优化方法