After the migration is complete, enable mixed precision to preserve model performance. Depending on the scenario, you can either use the APEX mixed-precision module (recommended) or the AMP feature built into PyTorch 1.8.1 and later. To install APEX, build and install it as described in its README.
To use APEX:

from apex import amp

model = ...
optimizer = ...
# Initialize APEX after the model and optimizer are defined.
model, optimizer = amp.initialize(model, optimizer, combine_grad=True)

loss = criterion(…)
# Replace loss.backward() with the following form:
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
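For reference, the APEX pattern above can be folded into a complete training loop. The sketch below is only an illustration, assuming the Ascend-ported apex build that accepts the combine_grad argument; build_model, device, num_epochs and train_loader are hypothetical placeholders, not names from the official sample.

import torch
from apex import amp

model = build_model().to(device)          # build_model and device are placeholders
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

# Initialize APEX once, after the model and optimizer are created.
model, optimizer = amp.initialize(model, optimizer, combine_grad=True)

for epoch in range(num_epochs):           # num_epochs is a placeholder
    for inputs, targets in train_loader:  # train_loader is a placeholder DataLoader
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        # Scale the loss so that small float16 gradients are not flushed to zero.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()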
To use PyTorch's built-in AMP:

from torch.cuda.amp import autocast, GradScaler   # on an NPU-migrated script the amp module path may differ

model = ...
optimizer = ...
# Use the AMP feature after the model and optimizer are defined.
scaler = GradScaler()   # create the gradient scaler
for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
        ......
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
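If the training recipe also clips gradients, the scaler has to be taken into account. The following sketch is not part of the original example; it shows the ordering supported by torch.cuda.amp: unscale the gradients before clipping, then step through the scaler. It uses a CUDA device purely for illustration; an NPU deployment would substitute its own device calls.

import torch
from torch.cuda.amp import autocast, GradScaler

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
scaler = GradScaler()

inputs = torch.randn(8, 16, device="cuda")
targets = torch.randn(8, 4, device="cuda")

optimizer.zero_grad()
with autocast():
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
scaler.scale(loss).backward()
# Unscale the gradients in place so that clipping sees true gradient magnitudes.
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)   # the step is skipped automatically if inf/nan gradients are found
scaler.update()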
For distributed training with APEX and PyTorch's DistributedDataParallel, pass combine_ddp=False:

from torch.nn.parallel import DistributedDataParallel as DDP
from apex import amp
...
model = ...
optimizer = ...
# Initialize APEX after the model and optimizer are defined.
model, optimizer = amp.initialize(model, optimizer, combine_ddp=False)
The variant below passes combine_ddp=True instead, in which case PyTorch's DDP is not imported:

from apex import amp
...
model = ...
optimizer = ...
# Initialize APEX after the model and optimizer are defined.
model, optimizer = amp.initialize(model, optimizer, combine_ddp=True)
To use the built-in AMP together with PyTorch's DistributedDataParallel:

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
...
model = ...
optimizer = ...
# Use the AMP feature after the model and optimizer are defined.
scaler = GradScaler()
...
with autocast():
    output = model(input)
    loss = loss_fn(output, target)
...
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
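For completeness, the fragments above can be assembled into a runnable distributed loop. The sketch below is an illustrative assumption rather than the official sample: it uses the standard torch.distributed and torch.cuda.amp APIs with the nccl backend and environment-variable initialization as provided by torchrun; an NPU deployment would substitute its own backend and device calls, and the random tensors stand in for a real data loader.

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler

def main():
    # One process per device; rank and world size come from the launcher (e.g. torchrun).
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(16, 4).cuda(local_rank)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = torch.nn.MSELoss()

    # Wrap the model with DDP, then create the scaler; each rank keeps its own scaler.
    model = DDP(model, device_ids=[local_rank])
    scaler = GradScaler()

    for step in range(10):                       # stand-in for a real data loader
        inputs = torch.randn(8, 16, device=local_rank)
        targets = torch.randn(8, 4, device=local_rank)
        optimizer.zero_grad()
        with autocast():
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
        scaler.scale(loss).backward()            # DDP all-reduces the (scaled) gradients
        scaler.step(optimizer)
        scaler.update()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()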
For more detailed instructions on using mixed precision, see the mixed precision documentation.