(beta)torch_npu.npu_bert_apply_adam
接口原型
torch_npu.npu_bert_apply_adam(lr, beta1, beta2, epsilon, grad, max_grad_norm, global_grad_norm, weight_decay, step_size=None, adam_mode=0, *, out=(var,m,v))
功能描述
adam结果计数。
参数说明
- 参数:
- var (Tensor) - float16或float32类型张量。
- m (Tensor) - 数据类型和shape与exp_avg相同。
- v (Tensor) - 数据类型和shape与exp_avg相同。
- lr (Scalar) - 数据类型与exp_avg相同。
- beta1 (Scalar) - 数据类型与exp_avg相同。
- beta2 (Scalar) - 数据类型与exp_avg相同。
- epsilon (Scalar) - 数据类型与exp_avg相同。
- grad (Tensor) - 数据类型和shape与exp_avg相同。
- max_grad_norm (Scalar) - 数据类型与exp_avg相同。
- global_grad_norm (Scalar) - 数据类型与exp_avg相同。
- weight_decay (Scalar) - 数据类型与exp_avg相同。
- step_size (Tensor,可选,默认值为None) - shape为(1, ),数据类型与exp_avg一致。
- adam_mode (Int,默认值为0) - 选择adam模式。0表示“adam”,1表示“mbert_adam”。
- 关键字参数:
- out (Tensor,可选) - 输出张量。
调用示例
>>> var_in = torch.rand(321538).uniform_(-32., 21.).npu() >>> m_in = torch.zeros(321538).npu() >>> v_in = torch.zeros(321538).npu() >>> grad = torch.rand(321538).uniform_(-0.05, 0.03).npu() >>> max_grad_norm = -1. >>> beta1 = 0.9 >>> beta2 = 0.99 >>> weight_decay = 0. >>> lr = 0. >>> epsilon = 1e-06 >>> global_grad_norm = 0. >>> var_out, m_out, v_out = torch_npu.npu_bert_apply_adam(lr, beta1, beta2, epsilon, grad, max_grad_norm, global_grad_norm, weight_decay, out=(var_in, m_in, v_in)) >>> var_out tensor([ 14.7733, -30.1218, -1.3647, ..., -16.6840, 7.1518, 8.4872], device='npu:0')
父主题: torch_npu