下载
中文
注册

脚本实现流程与关键文件介绍

训练脚本实现流程

Resnet50原始网络脚本是Estimator模型的API,属于TensorFlow的高阶API,此训练脚本的实现流程为:

表1 训练流程说明

序号

过程

描述

1

数据预处理

创建输入函数input_fn。

2

模型构建

构建模型函数model_fn。

3

运行配置

实例化Estimator,并传入Runconfig类对象作为运行参数。

4

执行训练

在Estimator上调用训练方法Estimator.train(),利用指定输入对模型进行固定步数的训练。

关键文件介绍

关键文件目录结构如下所示(只列出部分需要修改文件,更多文件请查看获取的ResNet原始网络脚本):
├── r1
│   ├── resnet       // resnet主目录
│        ├── imagenet_main.py      // 基于Imagenet数据集训练网络模型
│        ├── imagenet_preprocessing.py     // Imagenet数据集数据预处理模块
│        ├── resnet_model.py    // resnet模型文件
│        ├── resnet_run_loop.py    // 数据输入处理与运行循环(训练、验证、测试)
├── utils
│   ├── flags
│   │   ├── _base.py     //定义模型的通用参数并设置默认值
表2 关键文件作用及功能

文件名称

简介

imagenet_main.py

包含imagenet数据集数据预处理、模型构建定义、模型运行的相关函数接口。其中数据预处理部分包含get_filenames()、parse_record()、input_fn()、get_synth_input_fn(),_parse_example_proto()函数,模型部分包含ImagenetModel类、imagenet_model_fn()、run_cifar()、define_cifar_flags()函数。

imagenet_preprocessing.py

imagenet图像数据预处理接口,训练过程中包括使用提供的边界框对训练图像进行采样、将图像裁剪到采样边界框、随机翻转图像,然后调整到目标输出大小(不保留纵横比)。评估过程中使用图像大小调整(保留纵横比)和中央剪裁。

resnet_model.py

ResNet模型的实现,包括辅助构建ResNet模型的函数以及ResNet block定义函数。

resnet_run_loop.py

模型运行文件,包括输入处理和运行循环两部分,输入处理包括对输入数据进行解码和格式转换,输出image和label,还根据是否是训练过程对数据的随机化、批次、预读取等细节做出了设定;运行循环部分包括构建Estimator,然后进行训练和验证过程。总体来看,是将模型放置在具体的环境中,实现数据与误差在模型中的流动,进而利用梯度下降法更新模型参数。