文档
注册

broadcast

功能说明

把tensor var广播为形状为目标shape(第二个参数)的tensor,结果的数据类型由output_dtype指定。

如下A的shape为(2,1)即2行一列,广播为目标shape(2,3)即2行3列,则会将原来的一列扩展为一模一样的3列。

例如tensor var的shape为(2, 1, 64),目标shape为(2, 128, 64),调用broadcast接口的结果tensor的shape变为(2, 128, 64)。

函数原型

broadcast(var, shape, output_dtype=None)

参数说明

  • var:需要broadcast的数据,标量或者tensor类型。

    昇腾310 AI处理器:支持的数据类型有float16、float32、int32。

    昇腾910 AI处理器:支持的数据类型有int8,uint8,float16,int16,uint16,float32,int32,uint32,int64和uint64。

    昇腾310P AI处理器:支持的数据类型有int8,uint8,float16,int16,uint16,float32,int32,uint32,int64和uint64。

  • shape:目标shape,进行broadcast操作的目标shape。
  • output_dtype:输出数据类型,默认值var.dtype。

返回值

res_tensor:由var扩展后得到的tensor,shape为参数指定的shape,数据类型为output_dtype。

约束说明

输入tensor var的shape长度需要小于或等于目标shape的长度,若var的shape长度小于目标shape,会在var前面自动补1,补齐为与目标shape长度相同。var每个维度的大小要么与目标shape相等,要么为1,为1的维度会被broadcast到与目标shape的对应维度一致。

支持的芯片型号

昇腾310 AI处理器

昇腾910 AI处理器

昇腾310P AI处理器

调用示例

如下代码所示,将shape为(1024,1)的tensor通过调用broadcast接口扩展至shape(1024,1024)。

from tbe import tvm
from tbe import dsl
outshape = (1024,1024)
shape = (1024,1)
input_dtype = "float16"
data = tvm.placeholder(shape, name="data", dtype=input_dtype)
res = dsl.broadcast(data, outshape)
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词