文档
注册
评分
提单
论坛
小AI

样例代码

本场景构建了一个简单的神经网络作为样例代码,执行一个普通的float32训练,用于对比开启AMP训练后的加速效果。

构建神经网络

我们先构建一个简单的神经网络。

#引入模块
import time
import torch
import torch.nn as nn
import torch_npu
from torch.utils.data import Dataset, DataLoader
import torchvision

# 指定运行Device,用户请自行定义训练设备号
device = torch.device('npu:0')

# 定义一个简单的神经网络
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels = 1, out_channels = 16,
                      kernel_size = (3, 3),
                      stride = (1, 1),
                      padding = 1),
            nn.MaxPool2d(kernel_size = 2),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(32*7*7, 16),
            nn.ReLU(),
            nn.Linear(16, 10)
        )
    def forward(self, x):
        return self.net(x)

数据集定义与加载

这一部分我们从torchvision中获取训练数据集,设置训练相关的参数batch_size和数据集加载Dataloader。

# 数据集获取
train_data = torchvision.datasets.MNIST(
    root = 'mnist',
    download = True,
    train = True,
    transform = torchvision.transforms.ToTensor()
)
# 定义batchsize
batch_size = 64 
# 定义DataLoader
train_dataloader = DataLoader(train_data, batch_size = batch_size)  

定义损失函数与优化器

定义损失函数与优化器。

# 把模型放到指定NPU上
model = CNN().to(device)
# 定义损失函数 
loss_func = nn.CrossEntropyLoss().to(device)
# 定义优化器   
optimizer = torch.optim.SGD(model.parameters(), lr = 0.1) 

模型训练

定义模型训练设置循环。
# 设置循环
epochs = 10
for i in range(epochs):
    for imgs, labels in train_dataloader:
        start_time = time.time()   # 记录训练开始时间
        imgs = imgs.to(device)     # 把img数据放到指定NPU上
        labels = labels.to(device) # 把label数据放到指定NPU上
        outputs = model(imgs)    # 前向计算
        loss = loss_func(outputs, labels)    # 损失函数计算
        optimizer.zero_grad()
        loss.backward()          # 损失函数反向计算
        optimizer.step()         # 更新优化器

模型保存

定义模型保存路径与方法。

PATH = "state_dict_model.pt"        # 创建保存路径
torch.save(model.state_dict(), PATH)# 保存模型
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词