check_shape
函数功能
对tensor的shape进行校验。
需要校验的shape需要为list或者tuple类型,shape的每个dim需要为int类型。Shape的维度、rank以及总大小需要满足校验参数的要求。若校验不通过时,会抛出具体的校验错误异常(RuntimeError)。
函数原型
def check_shape(shape, min_dim=0, max_dim=DIM_LIMIT, min_rank=0, max_rank=RANK_LIMIT, min_size=0, max_size=SHAPE_SIZE_LIMIT, param_name=PARAM_NAME)
参数说明
参数 |
说明 |
---|---|
shape |
需要校验的shape,list或者tuple类型。 |
min_dim |
允许的最小dim值,int类型,默认值为0。 |
max_dim |
允许的最大dim值,int类型,默认值为DIM_LIMIT。 DIM_LIMIT = 2 ** 31 |
min_rank |
允许的最小rank值,int类型,默认值为0。 |
max_rank |
允许的最大rank值,int类型,默认值为RANK_LIMIT。 RANK_LIMIT = 8 |
min_size |
允许的最小shape总size,int类型,默认值为0。 |
max_size |
允许的最大shape总size,int类型,默认值为SHAPE_SIZE_LIMIT。 SHAPE_SIZE_LIMIT = 2 ** 31 |
param_name |
参数名称,用于打印信息时辅助提示,默认值为NULL。 |
返回值说明
无返回值,若检验失败会抛出具体的校验错误异常。
约束说明
无。
调用示例
from tbe.common.utils import para_check para_check.check_shape([2,3,4]) # 使用默认规则校验。 para_check.check_shape([2,3,4], min_dim=5) # 允许的最小dim为5,其他的限制按照默认值进行校验,该调用将会抛出异常。
父主题: 算子参数校验