torch_npu.npu_bert_apply_adam(lr, beta1, beta2, epsilon, grad, max_grad_norm, global_grad_norm, weight_decay, step_size=None, adam_mode=0, *, out=(var,m,v))
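Computes the result of one Adam optimizer step, writing the updated var, m and v into the tensors passed through out.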
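The update itself is not spelled out in this entry. The sketch below shows a BertAdam-style step in plain Python for reference. It assumes, without confirmation from this document, that the kernel follows the BertAdam convention from pytorch-pretrained-bert: no bias correction, and weight decay added directly to the update. The helper name bert_adam_step is hypothetical. The sketch leaves out how max_grad_norm, global_grad_norm, step_size and adam_mode affect the result.

```python
import math
from typing import List, Tuple


def bert_adam_step(
    var: List[float],
    m: List[float],
    v: List[float],
    grad: List[float],
    lr: float,
    beta1: float,
    beta2: float,
    epsilon: float,
    weight_decay: float,
) -> Tuple[List[float], List[float], List[float]]:
    """Hypothetical BertAdam-style reference step (not the NPU kernel itself).

    Assumes no bias correction and decoupled weight decay; gradient clipping
    (max_grad_norm), global_grad_norm scaling, step_size and adam_mode are
    deliberately not modelled.
    """
    new_var, new_m, new_v = [], [], []
    for p, mi, vi, g in zip(var, m, v, grad):
        mi = beta1 * mi + (1.0 - beta1) * g          # first-moment (mean) estimate
        vi = beta2 * vi + (1.0 - beta2) * g * g      # second-moment estimate
        update = mi / (math.sqrt(vi) + epsilon)      # Adam direction, no bias correction
        update += weight_decay * p                   # decoupled weight-decay term
        new_var.append(p - lr * update)              # parameter step scaled by lr
        new_m.append(mi)
        new_v.append(vi)
    return new_var, new_m, new_v


if __name__ == "__main__":
    var, m, v = bert_adam_step([1.0], [0.0], [0.0], [1.0],
                               lr=0.1, beta1=0.9, beta2=0.99,
                               epsilon=1e-6, weight_decay=0.0)
    print(var, m, v)
```

Under this sketch, setting lr = 0. (as the example that follows does) leaves var unchanged and updates only m and v.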
>>> import torch
>>> import torch_npu
>>> var_in = torch.rand(321538).uniform_(-32., 21.).npu()
>>> m_in = torch.zeros(321538).npu()
>>> v_in = torch.zeros(321538).npu()
>>> grad = torch.rand(321538).uniform_(-0.05, 0.03).npu()
>>> max_grad_norm = -1.
>>> beta1 = 0.9
>>> beta2 = 0.99
>>> weight_decay = 0.
>>> lr = 0.
>>> epsilon = 1e-06
>>> global_grad_norm = 0.
>>> var_out, m_out, v_out = torch_npu.npu_bert_apply_adam(
...     lr, beta1, beta2, epsilon, grad, max_grad_norm, global_grad_norm,
...     weight_decay, out=(var_in, m_in, v_in))
>>> var_out
tensor([ 14.7733, -30.1218,  -1.3647,  ..., -16.6840,   7.1518,   8.4872],
       device='npu:0')