下载
中文
注册

PyTorch框架使能图模式

昇腾平台支持PyTorch网络基于图模式执行,图模式执行场景下其软件架构如下图所示:

图1 PyTorch图模式软件架构

其中TorchAir(Torch Ascend Intermediate Representation)是Ascend Extension for PyTorch中支持图模式能力的扩展库,支持PyTorch网络在昇腾设备上进行图模式的训练和推理。TorchAir对接PyTorch的dynamo特性,将PyTorch的FX图转换为Ascend IR图,通过GE进行计算图的编译优化等操作,并下发到昇腾硬件执行。

TorchAir提供的昇腾编译后端能够作为参数传入PyTorch的torch.compile接口中,从而使能图模式执行,示例如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 先导入torch_npu,再导入torchair
import torch
import torch_npu
import torchair

# 定义模型Model
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x, y):
        return torch.add(x, y)

# 实例化模型model
model = Model()

# 从TorchAir获取NPU提供的默认backend
config = torchair.CompilerConfig()
npu_backend = torchair.get_npu_backend(compiler_config=config)

# 使用TorchAir的backend调用compile接口编译模型
model = torch.compile(model, backend=npu_backend)

# 执行编译后的model
x = torch.randn(2, 2)
y = torch.randn(2, 2)
model(x, y)

关于基于TorchAir的PyTorch图模式详细使用方法可参见PyTorch图模式使用指南(TorchAir)