当reduce轴为None时,返回输入shape去除1后的shape列表;当reduce轴不为None时,返回输入去除1后的shape列表及排序后的reduce轴列表,格式为(shape,axis)的元组。
def shape_refine(shape, reduce_axis=None, keep_dims=True):
参数 |
说明 |
---|---|
shape |
输入shape。 |
reduce_axis |
需要进行reduce的轴,可以是list、tuple或者int类型数值。 |
keep_dims |
是否保持维度数,bool型。 True代表保持维度, False表示降维。 默认值为True。 |
reduce轴对应的shape值为1,则不能进行优化。
示例一:
from tbe.common.utils import shape_util shape_util.shape_refine((32, 64, 64, 1), reduce_axis=None)
返回(32, 64, 64)。
示例二:
from tbe.common.utils import shape_util shape_util.shape_refine((32, 64, 64, 1), reduce_axis=[-2, -1], keep_dims=False)
返回([32, 64, 64], [2])。