下载
中文
注册

torch_npu.npu_prefetch

功能描述

提供网络weight预取功能,将需要预取的权重搬到L2 Cache中(当前仅支持权重的预取,暂不支持KV cache的预取)。尤其在做较大Tensor的MatMul计算且需要搬移到L2 Cache的操作时,可通过该接口提前预取权重,适当提高模型性能,具体效果基于用户对并行的处理。

接口原型

torch_npu.npu_prefetch(Tensor self, Tensor? dependency, int max_size) -> ()

参数说明

  • self:Tensor类型,表示需要预取的权重,不做数据处理,与数据类型和数据格式无关;输入不能含有空指针
  • dependency:Tensor类型,表示开始预取的节点,单算子下不生效可为None,图模式下不可为None;不做数据处理,与数据类型和数据格式无关。
  • max_size:int类型,取值需大于0,表示权重预取的最大size,超过预取权重的size时,会设置为权重的最大size。数据类型为int32、int64。

输出说明

无输出。

约束说明

无。

支持的型号

Atlas A2训练系列产品/Atlas 800I A2推理产品

调用示例

  • 单算子多流并发调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    import torch
    import torch_npu
    s_cmo = torch.npu.Stream()
    x = torch.randn(10000, 10000, dtype=torch.float16).npu()
    y = torch.randn(10000, 1, dtype=torch.float16).npu()
    add = torch.add(x, 1)
    with torch.npu.stream(s_cmo):
        torch_npu.npu_prefetch(y, None, 10000000)
    abs = torch.abs(add)
    mul = torch.matmul(abs, abs)
    out = torch.matmul(mul, y)
    
  • 图模式调用(图模式目前仅支持PyTorch 2.1版本)
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    import torch
    import torch_npu
    import torchair as tng
    from torchair.ge_concrete_graph import ge_apis as ge
    from torchair.configs.compiler_config import CompilerConfig
    
    config = CompilerConfig()
    config.debug.graph_dump.type = 'pbtxt'
    npu_backend = tng.get_npu_backend(compiler_config=config)
    x = torch.randn(10000, 10000, dtype=torch.float16).npu()
    y = torch.randn(10000, 1, dtype=torch.float16).npu()
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
        def forward(self,x,y):
            add = torch.add(x, 1)
            torch_npu.npu_prefetch(y, add, 10000000)
            abs = torch.abs(add)
            mul = torch.matmul(abs, abs)
            out = torch.matmul(mul, y)
            return out
    
    npu_model = Model().npu()
    model = torch.compile(npu_model, backend=npu_backend, dynamic=False, fullgraph=True)
    output = model(x,y)