下载
中文
注册

PyTorch框架使能图模式

昇腾平台支持PyTorch网络基于图模式执行,图模式执行场景下其软件架构如下图所示:

图1 PyTorch图模式软件架构

其中TorchAir(Torch Ascend Intermediate Representation)是Ascend Extension for PyTorch中支持图模式能力的扩展库,支持PyTorch网络在昇腾设备上进行图模式的训练和推理。TorchAir对接PyTorch的dynamo特性,将PyTorch的FX图转换为Ascend IR图,通过GE进行计算图的编译优化等操作,并下发到昇腾硬件执行。

TorchAir提供的昇腾编译后端能够作为参数传入PyTorch的torch.compile接口中,从而使能图模式执行,示例如下:

# 必须先导torch_npu再导torchair
import torch
import torch_npu
import torchair 

# 定义模型Model
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x, y):
        return torch.add(x, y)

# 实例化模型model
model = Model()

# 从torchair获取npu提供的默认backend
config = torchair.CompilerConfig()
npu_backend = torchair.get_npu_backend(compiler_config=config)

# 使用torchair的backend去调用compile接口编译模型
model = torch.compile(model, backend=npu_backend)

# 使用编译后的model去执行
x = torch.randn(2, 2)
y = torch.randn(2, 2)
model(x, y)

关于基于TorchAir的PyTorch图模式详细使用方法可参见PyTorch图模式使用指南(TorchAir)