接口原型
torch_npu.optim.NpuFusedBertAdam(params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', b1=0.9, b2=0.99, e=1e-6, weight_decay=0.01, max_grad_norm=-1)
功能描述
通过张量融合实现的 BertAdam 算法。
参数说明
- params:模型参数或模型参数组。
- lr:学习率(默认值:1e-3)。
- warmup:t_total的warmup比例(默认值:-1,表示不进行warmup)。
- t_total:学习率调整的步数(默认值:-1,表示固定学习率)。
- schedule:学习率warmup策略(默认值:'warmup_linear')。
- b1:Adams b1(默认值:0.9)。
- b2:Adams b2(默认值:0.99)。
- e:Adams epsilon(默认值:1e-6)。
- weight_decay:权重衰减(默认值:0.01)。
- max_grad_norm:最大梯度正则(默认值:1.0,-1表示不做裁剪)。
调用示例
opt = torch_npu.optim.NpuFusedBertAdam(model.parameters(), lr=0.1, weight_decay=0.01, max_grad_norm=1.0)