vcmp
功能说明
lhs tensor中的元素与rhs按operation指定的操作进行比较,operation指定的操作包括eq、ne、lt、gt、le、ge,依次表示==、!=、<、>、<=、>=。如果表达式成立,mode为bool时返回True,mode为bit时返回1;如果表达式不成立,mode为bool时返回False,mode为bit时返回0。
以下通过表达式的方式来解释各操作的含义,其中,x表示lhs中的元素,y表示rhs中的元素,z表示结果tensor的元素,n(取值0~7)表示结果tensor的元素的bit位index,表达式如下:
- mode=='bool':
- lt: z = True (x < y) or False (x >= y)
- gt: z = True (x > y) or False (x <= y)
- le: z = True (x <= y) or False (x > y)
- ge: z = True (x >= y) or False (x < y)
- eq: z = True (x == y) or False (x != y)
- ne: z = True (x != y) or False (x == y)
- mode=='bit':
- lt: z[n] = 1 (x < y) or 0 (x >= y)
- gt: z[n] = 1 (x > y) or 0 (x <= y)
- le: z[n] = 1 (x <= y) or 0 (x > y)
- ge: z[n] = 1 (x >= y) or 0 (x < y)
- eq: z[n] = 1 (x == y) or 0 (x != y)
- ne: z[n] = 1 (x != y) or 0 (x == y)
函数原型
vcmp(lhs, rhs, operation='lt', mode='bool')
参数说明
- lhs:进行比较操作的左操作数,tvm.tensor类型。
- rhs:进行比较操作的右操作数,tvm.tensor或者 scalar类型。
- operation : 操作类型,支持eq, ne, lt, gt, ge, le,默认为lt。
- mode:模式类型,支持bool和bit,默认为bool。
其中,进行比较操作的左操作数和右操作数的数据类型需要一致。
- Atlas 200/300/500 推理产品:支持的数据类型有float16
- Atlas 训练系列产品:支持的数据类型有float16、float32、int64
- Atlas 推理系列产品(Ascend 310P处理器):支持的数据类型有float16、float32、int64、uint64
- Atlas 200I/500 A2推理产品:支持的数据类型有float16、float32、int64
- Atlas A2训练系列产品/Atlas 800I A2推理产品:支持的数据类型有float16、float32、int64
返回值
res_tensor:结果tensor,tvm.tensor类型。mode为bool时数据类型为bool,mode为bit时数据类型为uint8。
约束说明
- 当mode为bool时,调用tbe.dsl.build接口进行编译时,传入的config参数中需要设置bool_storage_as_1bit为False。否则会出现某些shape的输出结果不符合预期。
因为bool_storage_as_1bit的默认值为True,表示按照1 bit进行存储。
编译配置config如下所示:
with tvm.target.cce(): schedule = tbe.dsl.auto_schedule(res) config = {"name": kernel_name, "tensor_list": [data_x, data_y, res], "bool_storage_as_1bit": False} tbe.dsl.build(schedule, config)
- 当mode为bit时,左操作数的shape的最后一维必须被8整除。
- 如果右操作数也为tensor,那么两个tensor的shape需要相同。
- 针对Atlas 推理系列产品(Ascend 310P处理器),当数据类型为int64与uint64时,仅支持mode为bool模式。
支持的型号
Atlas 200/300/500 推理产品
Atlas 训练系列产品
Atlas 推理系列产品(Ascend 310P处理器)
Atlas 200I/500 A2推理产品
Atlas A2训练系列产品/Atlas 800I A2推理产品
调用示例
from tbe import tvm from tbe import dsl shape = (1024,1024) input_dtype = "float16" data1 = tvm.placeholder(shape, name="data1", dtype=input_dtype) data2 = tvm.placeholder(shape, name="data2", dtype=input_dtype) res = dsl.vcmp(data1, data2, 'lt', 'bit')
父主题: Math计算接口