昇腾社区首页
中文
注册

KerasDistributeOptimizer构造函数

产品支持情况

产品

是否支持

Atlas A2 训练系列产品

Atlas 800I A2 推理产品

x

Atlas 200I/500 A2 推理产品

x

Atlas 推理系列产品

x

Atlas 训练系列产品

Atlas 200/300/500 推理产品

x

功能说明

KerasDistributeOptimizer类的构造函数,用于包装用户使用tf.Keras构造的脚本中的单机训练优化器,构造NPU分布式训练优化器。

函数原型

1
def __init__(self, optimizer, name="NpuKerasOptimizer", **kwargs)

参数说明

参数名

输入/输出

描述

optimizer

输入

用于梯度计算和更新权重的单机版训练优化器。

name

输入

优化器名称。

返回值

返回KerasDistributeOptimizer类对象。

调用示例

1
2
3
4
5
import tensorflow as tf
from npu_bridge.npu_init import *

model=xxx  
model.compile(loss='mean_squared_error', optimizer=KerasDistributeOptimizer(tf.keras.optimizers.SGD()))