降维训练脚本
环境依赖
训练模型
本章节涉及的脚本的默认存放路径为:“tools/train/reduction”。
- 训练模型。
python3 call_train.py --dataset_dir=Dataset_Dir --val_dataset_dir=./valid --generate_val=True --save_path=./modelsDr --dim=512 --npu=0 --ratio=4 --metric=L2 --mode=train --train_size=100000 --epochs=20 --train_batch_size=8192 --infer_batch_size=128 --learning_rate=0.0005 --log_stride=500 --construct_neighbors=100 --queries_validation=1000
参数
说明
dataset_dir
数据集路径,类型为string,必须设置。目前实现默认读取base.npy,query.npy和gt.npy。
若数据集为其他名称,可以自行实现数据集读取,并对该脚本“get_train_data”所在行做对应修改。
例如。原代码为:
# load dataset demo before training, modify here if you want to load your own dataset ##################################################################### learn, base = get_train_data(args.dataset_dir, args.train_size) #####################################################################
可修改为:
# load dataset demo before training, modify here if you want to load your own dataset ##################################################################### # learn, base = get_train_data(args.dataset_dir, args.train_size) learn = np.fromfile(YOUR_LEARN_DATASET_DIR, dtype=np.float32).reshape((-1, YOUR_DATA_DIM)) base = np.fromfile(YOUR_BASE_DATASET_DIR, dtype=np.float32).reshape((-1, YOUR_DATA_DIM)) #####################################################################
val_dataset_dir
“generate_val”为“True”时有效,生成验证集的存放路径,类型为string,默认值为“./validation/”。
generate_val
是否生成验证集。首次训练请设置为“True”。类型为bool,默认为“False”。
save_path
模型存放路径。类型为string,必须设置。
dim
可选,数据集维度。取值范围:[96, 128, 200, 256, 512, 2048]。类型为int,默认值为“512”。
npu
训练所用的DeviceId,即设备号。类型为int。
仅支持单卡训练,默认为CPU训练。
ratio
可选,降维比例。取值范围:[2, 4, 8, 16]。类型为int,默认值为“8”。
metric
训练模型时的距离度量标准,可选L2或IP。类型为string,默认值为“L2”。
mode
可选,范围为[“train”,“infer”,“test”],但当前仅支持“train”,默认为“train”,无需修改。
train_size
训练集大小,取值范围小于整个数据集样本个数。用于读取数据集时随机采样部分数据进行训练。类型为int。
若自行实现数据集读取,请根据train_size进行采样以防止训练速度过慢。
默认值为“100000”,修改时要求该值大于“0”。
epochs
训练迭代轮数。类型为int。迭代次数设置过大,会显著增加训练时长。,默认为“30”,修改时要求该值大于0。
train_batch_size
训练时的batch大小,默认为“8192”,类型为int。修改时要求该值大于“0”。
infer_batch_size
推理时的batch大小,默认为“128”。类型为int。修改时要求该值大于“0”。
learning_rate
学习率大小,默认为“0.0005”。类型为float。修改时要求该值大于“0”。
log_stride
训练日志打印间隔(step),默认为“500”。类型为int。修改时要求该值大于“0”。
construct_neighbors
构造训练集时所取的近邻的范围,用于构造降维所需的特殊训练集结构,默认为“100”。应根据数据集中每个人所对应的人脸数修改。类型为int。修改时要求该值大于“0”。
queries_validation
构造验证集时所需查询向量的数量,类型为int。默认为“1000”,修改时要求该值大于0。
--help | -h
查询帮助信息。
- 生成OM模型。执行训练脚本前,先执行如下命令设置环境变量(根据CANN软件包的实际安装路径修改)。
source /usr/local/Ascend/ascend-toolkit/set_env.sh export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64:/usr/local/Ascend/driver/lib64/common:/usr/local/Ascend/driver/lib64/driver:$LD_LIBRARY_PATH
- 生成精度为32的om模型。
bash atc.sh {save_path} {om_name} {input_shape}
- 生成精度为16的om模型
bash atc_16.sh {save_path} {om_name} {input_shape}
- {save_path}:必选,表示模型存储的路径。
- {om_name}:可选,表示生成OM模型的名字,默认与onnx模型名字相同。
- {input_shape}:可选,默认为onnx模型的输入维度,格式为actual_input_1:infer_batch_size,dim,建议使用默认值,不建议修改。
- bash atc.sh和bash atc_16.sh仅支持Atlas 推理系列产品。
- 生成精度为32的om模型。