TensorFormat
TBE提供了TensorFormat结构体用于定义算子输入输出数据的数据排布格式。
定义如下所示:
class TensorFormat: ND = "ND" NCHW = "NCHW" NHWC = "NHWC" NDHWC = "NDHWC" NCDHW = "NCDHW" CHWN = "CHWN" NC1HWC0 = "NC1HWC0" NC1HWC0_C04 = "NC1HWC0_C04" NDC1HWC0 = "NDC1HWC0" FRACTAL_NZ = "FRACTAL_NZ" HWCN = "HWCN" DHWCN = "DHWCN" FRACTAL_Z = "FRACTAL_Z" FRACTAL_Z_C04 = "FRACTAL_Z_C04" C1HWNCoC0 = "C1HWNCoC0" FRACTAL_Z_3D = "FRACTAL_Z_3D" FRACTAL_ZN_LSTM = "FRACTAL_ZN_LSTM" ND_RNN_BIAS = "ND_RNN_BIAS"
使用示例:
from tbe.common.utils.para_check import OpParamInfoKey, TensorFormat if x.get(OpParamInfoKey.FORMAT) == TensorFormat.NCHW # do something
父主题: 数据结构定义