把tensor var广播为形状为目标shape(第二个参数)的tensor,结果的数据类型由output_dtype指定。
如下A的shape为(2,1)即2行一列,广播为目标shape(2,3)即2行3列,则会将原来的一列扩展为一模一样的3列。
例如tensor var的shape为(2, 1, 64),目标shape为(2, 128, 64),调用broadcast接口的结果tensor的shape变为(2, 128, 64)。
您可以在“te/lang/cce/te_compute/broadcast_compute.py”查看接口定义。
输入tensor var的shape长度需要小于或等于目标shape的长度,若var的shape长度小于目标shape,会在var前面自动补1,补齐为与目标shape长度相同。var每个维度的大小要么与目标shape相等,要么为1,为1的维度会被broadcast到与目标shape的对应维度一致。
支持的数据类型:float16、float32、int32。
te.lang.cce.broadcast(var, shape, output_dtype=None)
res_tensor:由var扩展后得到的tensor,shape为参数指定的shape,数据类型为output_dtype。
Atlas 200/300/500 推理产品
Atlas 训练系列产品
如下代码所示,将shape为(1024,1)的tensor通过调用broadcast接口扩展至shape(1024,1024)。
outshape = (1024,1024) shape = (1024,1) input_dtype = "float16" data = tvm.placeholder(shape, name="data", dtype=input_dtype) res = te.lang.cce.broadcast(data, outshape)