文档
注册

vcmp

功能说明

lhs tensor中的元素与rhs按operation指定的操作进行比较,operation指定的操作包括eq、ne、lt、gt、le、ge,依次表示==、!=、<、>、<=、>=。如果表达式成立,mode为bool时返回True,mode为bit时返回1;如果表达式不成立,mode为bool时返回False,mode为bit时返回0。

以下通过表达式的方式来解释各操作的含义,其中,x表示lhs中的元素,y表示rhs中的元素,z表示结果tensor的元素,n(取值0~7)表示结果tensor的元素的bit位index,表达式如下:

  • mode=='bool':
    • lt: z = True (x < y) or False (x >= y)
    • gt: z = True (x > y) or False (x <= y)
    • le: z = True (x <= y) or False (x > y)
    • ge: z = True (x >= y) or False (x < y)
    • eq: z = True (x == y) or False (x != y)
    • ne: z = True (x != y) or False (x == y)
  • mode=='bit':
    • lt: z[n] = 1 (x < y) or 0 (x >= y)
    • gt: z[n] = 1 (x > y) or 0 (x <= y)
    • le: z[n] = 1 (x <= y) or 0 (x > y)
    • ge: z[n] = 1 (x >= y) or 0 (x < y)
    • eq: z[n] = 1 (x == y) or 0 (x != y)
    • ne: z[n] = 1 (x != y) or 0 (x == y)

函数原型

vcmp(lhs, rhs, operation='lt', mode='bool')

参数说明

  • lhs:进行比较操作的左操作数,tvm.tensor类型。
  • rhs:进行比较操作的右操作数,tvm.tensor或者 scalar类型。
  • operation : 操作类型,支持eq, ne, lt, gt, ge, le,默认为lt。
  • mode:模式类型,支持bool和bit,默认为bool。

其中,进行比较操作的左操作数和右操作数的数据类型需要一致。

  • 昇腾310 AI处理器:支持的数据类型有float16
  • 昇腾910 AI处理器:支持的数据类型有float16、float32、int64
  • 昇腾310P AI处理器:支持的数据类型有float16、float32、int64、uint64

返回值

res_tensor:结果tensor,tvm.tensor类型。mode为bool时数据类型为bool,mode为bit时数据类型为uint8。

约束说明

  • 当mode为bool时,调用tbe.dsl.build接口进行编译时,传入的config参数中需要设置bool_storage_as_1bit为False。否则会出现某些shape的输出结果不符合预期。

    因为bool_storage_as_1bit的默认值为True,表示按照1 bit进行存储。

    编译配置config如下所示:

    with tvm.target.cce():
        schedule = tbe.dsl.auto_schedule(res)
    config = {"name": kernel_name,
              "tensor_list": [data_x, data_y, res],
              "bool_storage_as_1bit": False}
    tbe.dsl.build(schedule, config)
  • 当mode为bit时,左操作数的shape的最后一维必须被8整除。
  • 如果右操作数也为tensor,那么两个tensor的shape需要相同。
  • 针对昇腾310P AI处理器,当数据类型为int64与uint64时,仅支持mode为bool模式。

支持的芯片型号

昇腾310 AI处理器

昇腾910 AI处理器

昇腾310P AI处理器

调用示例

from tbe import tvm
from tbe import dsl
shape = (1024,1024)
input_dtype = "float16"
data1 = tvm.placeholder(shape, name="data1", dtype=input_dtype)
data2 = tvm.placeholder(shape, name="data2", dtype=input_dtype)
res = dsl.vcmp(data1, data2, 'lt', 'bit')
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词