refine_shapes_for_broadcast

函数功能

在满足广播规则的前提下,对输入shape1和shape2补维至相同长度,如果补维后的两个shape在某一维度上的值均为1,则舍弃该维度,然后对新的输入shape1和shape2在连续且广播方向相同的轴上执行合并操作,在连续非广播轴上执行合并操作。

说明:输入shape1与shape2的长度可不相同,但经过补维度至相同长度后每一个维度需要满足广播操作的要求,即相同轴的维度值或者相同,或者其中一个值为1。

函数原型

def refine_shapes_for_broadcast(shape1, shape2)

参数说明

参数

说明

shape1

需要优化的shape1

shape2

需要优化的shape2

返回值说明

返回优化后的shape1与shape2。

约束说明

调用示例

from tbe.common.utils import shape_util
shape1, shape2= shape_util.refine_shapes_for_broadcast((1, 2, 3, 4, 1, 5, 6, 7), (2, 1, 1, 2, 1, 6, 7)) 

对shape2高维补1至和shape1长度相同,得到shape2为(1, 2, 1, 1, 2, 1, 6, 7)。

补维后,shape1和shape2的第0维均为1,则舍弃该维度,得到shape1为(2, 3, 4, 1, 5, 6, 7),shape2为(2, 1, 1, 2, 1, 6, 7)。

针对shape1,在连续的同方向广播轴1与2上进行合并,在连续的非广播轴5和6上进行合并,输出shape1 = (2, 12, 1, 5, 42)。

针对shape2,在连续的同方向广播轴1与2上进行合并,在连续的非广播轴5和6上进行合并,输出shape2 = (2, 1, 2, 1, 42)。