下载
中文
注册

broadcast关系

广播概念

broadcast(广播)描述了算子在运算期间如何处理不同形状的张量(或数组)。大部分情况下,允许不同形状的张量(或数组)在进行逐元素操作时自动扩展其形状,使它们的维度相互兼容,通常形状较小的张量(或数组)会被“广播”为较大的张量(或数组)。

目前许多算子API支持参数shape广播,通过该技术可以提高算法效率,但有时会导致内存使用效率降低,从而减缓计算速度。更详细的广播技术介绍参考NumPy官网。

广播规则

一般进行广播计算时,需要理解以下规则:

  • 规则1:如果两个数组的维度数不一致,短形状数组向长形状数组看齐,通过在短形状数组的左侧填充1,直至维度数相同。

    说明1:维度数(Number of Dimensions)是指张量(或数组)对应shape的维度数,比如.shape=(1, 1, 2, 4),对应维度数是4。

    说明2:比如计算a+b,其中a.shape=(2, 2, 3)、b.shape=(2, 3),那么数组b将被broadcast为b.shape=(1, 2, 3)。

  • 规则2:如果两个数组的维度数一致,且某个数组的某一维度为1,则该维度为1的数组将被拉伸以匹配另一个数组对应维度形状。

    说明:本场景下,只需保证在某一个维度做broadcast即可。比如计算a+b,其中a.shape=(1, 3)、b.shape=(3, 1),那么两个数组会broadcast为a.shape=(3, 3)、b.shape=(3, 3)。

  • 规则3:如果两个数组的维度数不一致,且均没有等于1的维度,则会报错。

基于上述规则,广播机制一般先按规则1进行扩维,再按规则2进行形状拉伸,举例如下:

假设a.shape=(2,2,3),取值形如:
[[[1 2 3],[4 5 6]],
 [[1 2 3],[4 5 6]]]
假设b.shape=(2,3),取值形如:
[[1 2 3],
 [-1 -2 -3]]
根据规则1扩展维度,b.shape=(1,2,3),取值如下:
[[[1 2 3],
  [-1 -2 -3]]]
根据规则2拉伸形状,b.shape=(2,2,3),取值如下:
[[[1 2 3],[-1 -2 -3]],
 [[1 2 3],[-1 -2 -3]]]
计算a+b,实际结果如下:
 [[[2 4 6],[3 3 3]],
  [[2 4 6],[3 3 3]]]

限制

当满足broadcast关系的两个输入a和b的数据类型或者推导后的类型在(COMPLEX64、COMPLEX128、DOUBLE、INT16、UINT16、UINT64)中时,除了要满足上述广播规则,还需满足如下规则,否则广播失败,从而导致算子执行报错:

规则:连续的需要广播的轴和连续的不需广播的轴合并之后的维度要求小于6。

举例:

  • 当a.shape=(5, 1, 5, 1, 5, 1),b.shape=(5, 5, 5, 5, 5, 5),没有需要合并的轴,最后维度为6,广播报错。
  • 当a.shape=(5, 1, 5, 5, 1, 1),b.shape=(5, 5, 5, 5, 5, 5),在第2和3维都不需要广播,4和5维都需要广播,分别连续合并,合并后的维度为4,广播成功。