下载
中文
注册

check_shape

函数功能

对tensor的shape进行校验。

需要校验的shape需要为list或者tuple类型,shape的每个dim需要为int类型。Shape的维度、rank以及总大小需要满足校验参数的要求。若校验不通过时,会抛出具体的校验错误异常(RuntimeError)。

函数原型

def check_shape(shape, min_dim=0, max_dim=DIM_LIMIT, min_rank=0, max_rank=RANK_LIMIT, min_size=0, max_size=SHAPE_SIZE_LIMIT, param_name=PARAM_NAME)

参数说明

参数

说明

shape

需要校验的shape,list或者tuple类型。

min_dim

允许的最小dim值,int类型,默认值为0。

max_dim

允许的最大dim值,int类型,默认值为DIM_LIMIT。

DIM_LIMIT = 2 ** 31

min_rank

允许的最小rank值,int类型,默认值为0。

max_rank

允许的最大rank值,int类型,默认值为RANK_LIMIT。

RANK_LIMIT = 8

min_size

允许的最小shape总size,int类型,默认值为0。

max_size

允许的最大shape总size,int类型,默认值为SHAPE_SIZE_LIMIT。

SHAPE_SIZE_LIMIT = 2 ** 31

param_name

参数名称,用于打印信息时辅助提示,默认值为NULL。

返回值说明

无返回值,若检验失败会抛出具体的校验错误异常。

约束说明

无。

调用示例

from tbe.common.utils import para_check
para_check.check_shape([2,3,4]) # 使用默认规则校验。
para_check.check_shape([2,3,4], min_dim=5) # 允许的最小dim为5,其他的限制按照默认值进行校验,该调用将会抛出异常。