文档
注册

beta)torch_npu.npu_bert_apply_adam

接口原型

torch_npu.npu_bert_apply_adam(lr, beta1, beta2, epsilon, grad, max_grad_norm, global_grad_norm, weight_decay, step_size=None, adam_mode=0, *, out=(var,m,v))

功能描述

adam结果计数。

参数说明

  • 参数:
    • var (Tensor) - float16或float32类型张量。
    • m (Tensor) - 数据类型和shape与exp_avg相同。
    • v (Tensor) - 数据类型和shape与exp_avg相同。
    • lr (Scalar) - 数据类型与exp_avg相同。
    • beta1 (Scalar) - 数据类型与exp_avg相同。
    • beta2 (Scalar) - 数据类型与exp_avg相同。
    • epsilon (Scalar) - 数据类型与exp_avg相同。
    • grad (Tensor) - 数据类型和shape与exp_avg相同。
    • max_grad_norm (Scalar) - 数据类型与exp_avg相同。
    • global_grad_norm (Scalar) - 数据类型与exp_avg相同。
    • weight_decay (Scalar) - 数据类型与exp_avg相同。
    • step_size (Tensor,可选,默认值为None) - shape为(1, ),数据类型与exp_avg一致。
    • adam_mode (Int,默认值为0) - 选择adam模式。0表示“adam”,1表示“mbert_adam”。
  • 关键字参数:
    • out (Tensor,可选) - 输出张量。

调用示例

>>> var_in = torch.rand(321538).uniform_(-32., 21.).npu()
>>> m_in = torch.zeros(321538).npu()
>>> v_in = torch.zeros(321538).npu()
>>> grad = torch.rand(321538).uniform_(-0.05, 0.03).npu()
>>> max_grad_norm = -1.
>>> beta1 = 0.9
>>> beta2 = 0.99
>>> weight_decay = 0.
>>> lr = 0.
>>> epsilon = 1e-06
>>> global_grad_norm = 0.
>>> var_out, m_out, v_out = torch_npu.npu_bert_apply_adam(lr, beta1, beta2, epsilon, grad, max_grad_norm, global_grad_norm, weight_decay, out=(var_in, m_in, v_in))
>>> var_out
tensor([ 14.7733, -30.1218,  -1.3647,  ..., -16.6840,   7.1518,   8.4872],
      device='npu:0')
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词