文档
注册

混合精度适配网络样例

网络结构构建

构建一个简单的神经网络。

#引入模块 
import time 
import torch 
import torch.nn as nn 
import torch_npu 
from torch.utils.data import Dataset, DataLoader 
import torchvision 
 
# 指定运行Device,用户请自行定义训练设备号 
device = torch.device('npu:0') 
 
# 定义一个简单的神经网络 
class CNN(nn.Module): 
    def __init__(self): 
        super(CNN, self).__init__() 
        self.net = nn.Sequential( 
            nn.Conv2d(in_channels = 1, out_channels = 16, 
                      kernel_size = (3, 3), 
                      stride = (1, 1), 
                      padding = 1), 
            nn.MaxPool2d(kernel_size = 2), 
            nn.Conv2d(16, 32, 3, 1, 1), 
            nn.MaxPool2d(2), 
            nn.Flatten(), 
            nn.Linear(32*7*7, 16), 
            nn.ReLU(), 
            nn.Linear(16, 10) 
        ) 
    def forward(self, x): 
        return self.net(x)

数据集定义与加载

这一部分我们从torchvision中获取训练数据集,设置训练相关的参数batch_size和数据集加载Dataloader。

# 数据集获取 
train_data = torchvision.datasets.MNIST( 
    root = 'mnist', 
    download = True, 
    train = True, 
    transform = torchvision.transforms.ToTensor() 
) 
# 定义batchsize 
batch_size = 64  
# 定义DataLoader 
train_dataloader = DataLoader(train_data, batch_size = batch_size)  

损失函数与优化器定义

定义损失函数与优化器。

# 把模型放到指定NPU上 
model = CNN().to(device) 
# 定义损失函数  
loss_func = nn.CrossEntropyLoss().to(device) 
# 定义优化器    
optimizer = torch.optim.SGD(model.parameters(), lr = 0.1) 

模型训练过程构建

定义模型训练设置循环。

# 设置循环 
epochs = 10 
for i in range(epochs): 
    for imgs, labels in train_dataloader: 
        start_time = time.time()   # 记录训练开始时间 
        imgs = imgs.to(device)     # 把img数据放到指定NPU上 
        labels = labels.to(device) # 把label数据放到指定NPU上 
        outputs = model(imgs)    # 前向计算 
        loss = loss_func(outputs, labels)    # 损失函数计算 
        optimizer.zero_grad() 
        loss.backward()          # 损失函数反向计算 
        optimizer.step()         # 更新优化器

模型保存构建

定义模型保存路径与方法。

PATH = "state_dict_model.pt"        # 创建保存路径 
torch.save(model.state_dict(), PATH)# 保存模型
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词