下载
中文
注册

broadcast关系

广播概念

broadcast(广播)描述了算子在运算期间如何处理具有不同形状的数组。在某些情况下,较小的数组可以“广播至”较大的数组,使两者shape互相兼容。

目前许多算子API支持参数的shape广播,通过该技术可以提高算法效率,但有时会导致内存使用效率降低,从而减缓计算速度。更多关于广播技术的介绍参考NumPy官网。

广播规则

一般进行广播计算时,需要理解以下规则:

  • 规则1:让所有输入数组都向形状最长的数组看齐。形状不足的部分通过在前面(左侧)填充1。

    说明: 形状其实就是指the number of dimensions。比如计算a+b,其中a.shape=(2, 2, 3)、b.shape=(2, 3),那么数组b将被broadcast为b.shape=(1, 2, 3)。

  • 规则2:如果两个数组的形状在任何维度上均不匹配,但是某个数组中某一个维度为1,则该维度中形状为1的数组将被拉伸以匹配另一个数组对应维度形状。

    说明: 本场景下,只需保证能在某一个维度做broadcast即可。比如计算a+b,其中a.shape=(1, 3)、b.shape=(3, 1),那么两个数组会broadcast为a.shape=(3, 3)、b.shape=(3, 3)。

  • 规则3:如果两个数组的形状在任何维度上均不匹配,且均没有等于1的维度,则会报错。

限制

当满足broadcast关系的两个输入a和b的数据类型或者推导后的类型在(COMPLEX64,COMPLEX128,DOUBLE,INT16,UINT16,UINT64)中时。除了要满足上述广播规则之外,还要满足,将连续的需广播轴和连续的不需广播轴合并之后维度小于6,不满足则会广播失败导致报错。

举例:

  • 当a.shape=(5, 1, 5, 1, 5, 1),b.shape=(5, 5, 5, 5, 5, 5),没有需要合并的轴,最后维度为6,广播报错。
  • 当a.shape=(5, 1, 5, 5, 1, 1),b.shape=(5, 5, 5, 5, 5, 5),在第2和3维都不需要广播,4和5维都需要广播,分别连续合并,合并后的维度为4,广播成功。