broadcast
功能说明
把tensor var广播为形状为目标shape(第二个参数)的tensor,结果的数据类型由output_dtype指定。
如下A的shape为(2,1)即2行一列,广播为目标shape(2,3)即2行3列,则会将原来的一列扩展为一模一样的3列。
例如tensor var的shape为(2, 1, 64),目标shape为(2, 128, 64),调用broadcast接口的结果tensor的shape变为(2, 128, 64)。
函数原型
broadcast(var, shape, output_dtype=None)
参数说明
- var:需要broadcast的数据,标量或者tensor类型。
Atlas 200/300/500 推理产品:支持的数据类型有float16、float32、int32。
Atlas 训练系列产品:支持的数据类型有int8,uint8,float16,int16,uint16,float32,int32,uint32,int64和uint64。
Atlas 推理系列产品:支持的数据类型有int8,uint8,float16,int16,uint16,float32,int32,uint32,int64和uint64。
Atlas 200/500 A2推理产品:支持的数据类型有float16、float32、int32、int64、uint64。
Atlas A2训练系列产品/Atlas 800I A2推理产品:支持的数据类型有float16、float32、int32、int64、uint64、complex32、complex64。
- shape:目标shape,进行broadcast操作的目标shape。
- output_dtype:输出数据类型,默认值var.dtype。
返回值
res_tensor:由var扩展后得到的tensor,shape为参数指定的shape,数据类型为output_dtype。
约束说明
输入tensor var的shape长度需要小于或等于目标shape的长度,若var的shape长度小于目标shape,会在var前面自动补1,补齐为与目标shape长度相同。var每个维度的大小要么与目标shape相等,要么为1,为1的维度会被broadcast到与目标shape的对应维度一致。
支持的型号
Atlas 200/300/500 推理产品
Atlas 训练系列产品
Atlas 推理系列产品
Atlas 200/500 A2推理产品
Atlas A2训练系列产品/Atlas 800I A2推理产品
调用示例
如下代码所示,将shape为(1024,1)的tensor通过调用broadcast接口扩展至shape(1024,1024)。
from tbe import tvm from tbe import dsl outshape = (1024,1024) shape = (1024,1) input_dtype = "float16" data = tvm.placeholder(shape, name="data", dtype=input_dtype) res = dsl.broadcast(data, outshape)