The multi-model, multi-loss, multi-optimizer scenario refers to a training setup in which several models, several loss functions, and several optimizers are used at the same time in one neural network workload.
Import the AMP module and define two simple neural networks.
import time
import torch
import torch.nn as nn
import torch_npu
from torch_npu.npu import amp
from torch.utils.data import Dataset, DataLoader
import torchvision

device = torch.device('npu:0')   # set the training device as appropriate

# Define a simple neural network.
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(32*7*7, 16),
            nn.ReLU(),
            nn.Linear(16, 10)
        )

    def forward(self, x):
        return self.net(x)

# Define a second, similar neural network with one extra convolutional layer.
class CNN_2(nn.Module):
    def __init__(self):
        super(CNN_2, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.Flatten(),
            nn.Linear(32*7*7, 16),
            nn.ReLU(),
            nn.Linear(16, 10)
        )

    def forward(self, x):
        return self.net(x)
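As a quick sanity check (not part of the original sample), you can feed both networks a dummy batch shaped like the MNIST images used below and confirm that each returns a (batch, 10) logits tensor. The batch size of 4 and the CPU-only execution here are illustrative assumptions.

# Hypothetical sanity check: both networks should emit (N, 10) logits
# for MNIST-shaped input (1 channel, 28x28). Running on CPU avoids
# needing an NPU just for the shape test.
dummy = torch.randn(4, 1, 28, 28)
print(CNN()(dummy).shape)     # expected: torch.Size([4, 10])
print(CNN_2()(dummy).shape)   # expected: torch.Size([4, 10])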
Fetch the MNIST training dataset from torchvision and set the training parameters batch_size and epochs. Create two optimizers, and define the GradScaler used by AMP after the models and optimizers have been defined.
train_data = torchvision.datasets.MNIST(
    root='mnist',
    download=True,
    train=True,
    transform=torchvision.transforms.ToTensor()
)
batch_size = 64
model0 = CNN().to(device)
model1 = CNN_2().to(device)
train_dataloader = DataLoader(train_data, batch_size=batch_size)    # define the DataLoader
loss_func = nn.CrossEntropyLoss().to(device)                        # define the loss function
optimizer0 = torch.optim.SGD(model0.parameters(), lr=0.1)           # define optimizer 0
optimizer1 = torch.optim.SGD(model1.parameters(), lr=0.1)           # define optimizer 1
scaler = amp.GradScaler()    # define the GradScaler after the models and optimizers
epochs = 10
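If you want to confirm that dynamic loss scaling is active, a minimal sketch is to inspect the scaler state. This assumes torch_npu's amp.GradScaler mirrors the torch.cuda.amp.GradScaler interface, which exposes get_scale() and is_enabled().

# Assumption: amp.GradScaler exposes the same helpers as torch.cuda.amp.GradScaler.
print(scaler.get_scale())     # current loss-scaling factor (65536.0 by default at init)
print(scaler.is_enabled())    # True when loss scaling is active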
Add the AMP-related code to the training loop to enable AMP, and compute the multiple losses and optimizer updates.
for epo in range(epochs):
    for imgs, labels in train_dataloader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        with amp.autocast():
            outputs0 = model0(imgs)                             # forward pass
            outputs1 = model1(imgs)
            loss0 = loss_func(2*outputs0 + 3*outputs1, labels)  # compute the losses
            loss1 = loss_func(3*outputs0 - 5*outputs1, labels)
        optimizer0.zero_grad()
        optimizer1.zero_grad()
        # Scale the losses, run backward, then update the parameters.
        scaler.scale(loss0).backward(retain_graph=True)  # scale the loss and backpropagate
        scaler.scale(loss1).backward()
        scaler.step(optimizer0)   # update parameters (gradients are unscaled automatically)
        scaler.step(optimizer1)
        scaler.update()           # update the loss-scaling factor via dynamic loss scaling
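If gradient clipping is needed, each optimizer's gradients must be unscaled explicitly before clipping. The sketch below follows the standard torch.cuda.amp multi-optimizer pattern and assumes torch_npu's GradScaler supports the same unscale_() call; the max_norm value is an illustrative assumption. It replaces only the update portion of the loop above.

# Variant of the update step with gradient clipping.
scaler.scale(loss0).backward(retain_graph=True)
scaler.scale(loss1).backward()
scaler.unscale_(optimizer0)          # unscale gradients in place before clipping
scaler.unscale_(optimizer1)
torch.nn.utils.clip_grad_norm_(model0.parameters(), max_norm=1.0)
torch.nn.utils.clip_grad_norm_(model1.parameters(), max_norm=1.0)
scaler.step(optimizer0)              # skips unscaling, since it was done manually
scaler.step(optimizer1)
scaler.update()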