from apex import amp

# Wrap the model and optimizer for mixed-precision training
model, optimizer = amp.initialize(model, optimizer, combine_grad=True)
# Original training step:
loss = criterion(…)
loss.backward()
optimizer.step()

# With amp: scale the loss before calling backward()
loss = criterion(…)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
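When dynamic loss scaling is in effect, the idea behind `amp.scale_loss` is roughly this: multiply the loss by a scale factor before `backward()` so that small fp16 gradients do not underflow, divide the gradients by the same factor before the optimizer step, and, if any gradient overflows to inf/NaN, skip that step and shrink the scale. The sketch below illustrates that policy in plain Python. It is not apex's implementation. The class `DynamicLossScaler`, the helper `sgd_step`, and all default values are hypothetical, and gradients are plain floats rather than fp16 tensors.

```python
import math


class DynamicLossScaler:
    """Hypothetical, simplified dynamic loss scaler (illustration only)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps without overflow

    def scale_loss(self, loss):
        # Analogous to the `scaled_loss` used for backward() in amp.
        return loss * self.scale

    def unscale_and_check(self, scaled_grads):
        """Return unscaled grads, or None if any grad is inf/NaN."""
        if any(not math.isfinite(g) for g in scaled_grads):
            return None
        return [g / self.scale for g in scaled_grads]

    def update(self, overflow):
        if overflow:
            # Overflow: shrink the scale and restart the growth counter.
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            # After enough clean steps, try a larger scale again.
            if self._good_steps >= self.growth_interval:
                self.scale *= self.growth_factor
                self._good_steps = 0


def sgd_step(scaler, params, scaled_grads, lr=0.1):
    """Hypothetical plain-SGD step that consumes gradients of the scaled loss."""
    grads = scaler.unscale_and_check(scaled_grads)
    overflow = grads is None
    scaler.update(overflow)
    if overflow:
        return params  # skip the update on overflow
    return [p - lr * g for p, g in zip(params, grads)]


# Toy usage: loss = 0.5 * sum(w**2), so d(scale * loss)/dw = scale * w.
scaler = DynamicLossScaler(init_scale=8.0, growth_interval=2)
params = [1.0, -2.0]
params = sgd_step(scaler, params, [scaler.scale * w for w in params])
```

Skipping the step on overflow, rather than clipping, keeps a single bad batch of inf/NaN gradients from corrupting the weights. Periodically growing the scale again recovers precision once training is stable.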
For more on using the mixed-precision module, see the official documentation.