使用

  1. 从APEX库中导入AMP。

    from apex import amp

  2. 初始化AMP,使其能对模型、优化器以及PyTorch内部函数进行必要的改动。

    model, optimizer = amp.initialize(model, optimizer, combine_grad=True)

  3. 标记反向传播.backward()发生的位置,这样AMP就可以进行Loss Scaling并清除每次迭代的状态。

    原始代码如下:
    loss = criterion(…) 
    loss.backward() 
    optimizer.step()
    修改后支持Loss Scaling的代码如下:
    loss = criterion(…) 
    with amp.scale_loss(loss, optimizer) as scaled_loss:     
        scaled_loss.backward() 
    optimizer.step()

更多混合精度模块的使用可参见官方文档