broadcast

功能说明

把tensor var广播为形状为目标shape(第二个参数)的tensor,结果的数据类型由output_dtype指定。

如下A的shape为(2,1)即2行一列,广播为目标shape(2,3)即2行3列,则会将原来的一列扩展为一模一样的3列。

例如tensor var的shape为(2, 1, 64),目标shape为(2, 128, 64),调用broadcast接口的结果tensor的shape变为(2, 128, 64)。

您可以在“te/lang/cce/te_compute/broadcast_compute.py”查看接口定义。

约束说明

输入tensor var的shape长度需要小于或等于目标shape的长度,若var的shape长度小于目标shape,会在var前面自动补1,补齐为与目标shape长度相同。var每个维度的大小要么与目标shape相等,要么为1,为1的维度会被broadcast到与目标shape的对应维度一致。

支持的数据类型:float16、float32、int32。

函数原型

te.lang.cce.broadcast(var, shape, output_dtype=None)

参数说明

返回值

res_tensor:由var扩展后得到的tensor,shape为参数指定的shape,数据类型为output_dtype。

支持的芯片型号

Atlas 200/300/500 推理产品

Atlas 训练系列产品

调用示例

如下代码所示,将shape为(1024,1)的tensor通过调用broadcast接口扩展至shape(1024,1024)。

outshape = (1024,1024)
shape = (1024,1)
input_dtype = "float16"
data = tvm.placeholder(shape, name="data", dtype=input_dtype)
res = te.lang.cce.broadcast(data, outshape)