下载
中文
注册

GraphEvaluator

功能说明

针对某一个模型,根据模型的bin类型输入数据,提供一个python实例,可对该模型执行校准和推理的评估器。

函数原型

class GraphEvaluator(AutoCalibrationEvaluatorBase):

def __init__(self, data_dir, input_shape, data_types):

参数说明

参数名

输入/返回值

含义

使用限制

data_dir

输入

与模型匹配的bin格式数据集路径。

数据类型:string

参数值格式:"data/input1/;data/input2/"

使用约束:

  • 路径支持大小写字母(a-z,A-Z)、数字(0-9)、下划线(_)、中划线(-)、句点(.)、中文字符。
  • 若模型有多个输入,且每个输入有多个batch数据,则不同的输入数据必须存储在不同的目录中,目录中文件的名称必须按照升序排序。所有的输入数据路径必须放在双引号中,节点中间使用英文分号分隔。
  • 单个bin文件中存储的数组shape需要和input_shape中输入的shape相匹配,例如:单张图片bin存储的数组shape为1x224x224x3,则input_shape中输入的必须为1x224x224x3;如需多个bin做量化,则可通过调整batch_num取值实现。

input_shape

输入

模型输入的shape信息。

数据类型:string

参数值格式:"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2"。

使用约束:指定的节点必须放在双引号中,节点中间使用英文分号分隔。

data_types

输入

输入数据的类型。

数据类型:string

参数值格式:"float32;float64"

使用约束:若模型有多个输入,且数据类型不同,则需要分别指定不同输入的数据类型,指定的输入数据类型必须按照输入节点顺序依次放在双引号中,所有的输入数据类型必须放在双引号中,中间使用英文分号分隔。

返回值说明

一个python实例。

函数输出

无。

调用示例

1
2
3
4
5
6
import amct_tensorflow as amct

evaluator = amct.GraphEvaluator(
    data_dir="./data/input_bin/", 
    input_shape="input:32,3,224,224",
    data_types="float32")