片上内存显存侧动态扩容模式
TensorFlow对Embedding的支持是通过变量实现的,用户需要预估每个表的大小,再通过create_table接口创建变量。Embedding表的大小在一开始就确认,后期无法扩大或者减小,这可能会导致显存的浪费或者空间不足。在推荐场景下,多个稀疏表的大小无法预估,为更好的适配用户场景及需求,增加片上内存稀疏表自动扩容功能,即显存随着模型训练增长。
片上内存显存侧动态扩容模式下,不支持特征淘汰。
适配模型
关键步骤操作参考如下。
- 初始化框架。
调用init接口,设置参数use_dynamic_expansion = True表示启用动态扩容功能,该参数默认为“False”。
- 稀疏优化器导入。
调用mx_rec.optimizers包中对应优化器的create_hash_optimizer_by_address接口来创建稀疏表sparse_optimizer。具体可用优化器参考如下。
- 获取嵌入表示结果(emb)和映射地址(addr)。
使用tf.get_collection("ASCEND_SPARSE_LOOKUP_LOCAL_EMB")接口获取训练用的嵌入表示结果,使用tf.get_collection("ASCEND_SPARSE_LOOKUP_ID_OFFSET")接口获取训练用的映射地址。
- 反向梯度计算。
使用tf.gradients(loss, emb)接口对3获取的嵌入表示结果求导,得到梯度(grad)。
- 反向稀疏表更新。
使用2.sparse优化器导入。创建的sparse_optimizer.apply_gradients([grad, addr])接口对映射地址地址对应位置的稀疏表进行更新。
示例代码
- 初始化框架。
use_dynamic_expansion = bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0))) init(use_mpi, train_steps=args.train_steps, eval_steps=args.eval_steps, use_dynamic_expansion=use_dynamic_expansion)
- 稀疏优化器导入。
def get_dense_and_sparse_optimizer(cfg): dense_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=cfg.learning_rate) use_dynamic_expansion = get_use_dynamic_expansion() if use_dynamic_expansion: sparse_optimizer = create_hash_optimizer_by_address(learning_rate=cfg.learning_rate) logging.info("optimizer lazy_adam_by_addr") else: sparse_optimizer = create_hash_optimizer(learning_rate=cfg.learning_rate) logging.info("optimizer lazy_adam") return dense_optimizer, sparse_optimizer
- 获取嵌入表示结果和映射地。
train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB) train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET)
- 反向梯度计算。
local_grads = tf.gradients(loss, train_emb_list) # local_embedding
- 反向稀疏表更新。
grads_and_vars = [(grad, address) for grad, address in zip(local_grads, train_address_list)] train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars))
- 调用sparse_optimizer.apply_gradients(grads_and_vars)更新梯度时,若使用的vars(如address)是tensor而非variable,需要保证vars的维度和grads的第一个维度相等。
- train_address_list地址列表需要是有效合法的,需通过3. 获取映射地址获取。若使用非法地址,运行时会抛出AICore Error等错误。
父主题: 训练功能特性流程