torch_npu.npu_prefetch
功能描述
提供网络weight预取功能,将需要预取的权重搬到L2 Cache中(当前仅支持权重的预取,暂不支持KV cache的预取)。尤其在做较大Tensor的MatMul计算且需要搬移到L2 Cache的操作时,可通过该接口提前预取权重,适当提高模型性能,具体效果基于用户对并行的处理。
接口原型
torch_npu.npu_prefetch(Tensor self, Tensor? dependency, int max_size) -> ()
参数说明
- self:Tensor类型,表示需要预取的权重,不做数据处理,与数据类型和数据格式无关;输入不能含有空指针
- dependency:Tensor类型,表示开始预取的节点,单算子下不生效可为None,图模式下不可为None;不做数据处理,与数据类型和数据格式无关。
- max_size:int类型,取值需大于0,表示权重预取的最大size,超过预取权重的size时,会设置为权重的最大size。数据类型为int32、int64。
输出说明
无输出。
约束说明
无。
支持的型号
Atlas A2训练系列产品/Atlas 800I A2推理产品
调用示例
- 单算子多流并发调用
1 2 3 4 5 6 7 8 9 10 11
import torch import torch_npu s_cmo = torch.npu.Stream() x = torch.randn(10000, 10000, dtype=torch.float16).npu() y = torch.randn(10000, 1, dtype=torch.float16).npu() add = torch.add(x, 1) with torch.npu.stream(s_cmo): torch_npu.npu_prefetch(y, None, 10000000) abs = torch.abs(add) mul = torch.matmul(abs, abs) out = torch.matmul(mul, y)
- 图模式调用(图模式目前仅支持PyTorch 2.1版本)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
import torch import torch_npu import torchair as tng from torchair.ge_concrete_graph import ge_apis as ge from torchair.configs.compiler_config import CompilerConfig config = CompilerConfig() config.debug.graph_dump.type = 'pbtxt' npu_backend = tng.get_npu_backend(compiler_config=config) x = torch.randn(10000, 10000, dtype=torch.float16).npu() y = torch.randn(10000, 1, dtype=torch.float16).npu() class Model(torch.nn.Module): def __init__(self): super().__init__() def forward(self,x,y): add = torch.add(x, 1) torch_npu.npu_prefetch(y, add, 10000000) abs = torch.abs(add) mul = torch.matmul(abs, abs) out = torch.matmul(mul, y) return out npu_model = Model().npu() model = torch.compile(npu_model, backend=npu_backend, dynamic=False, fullgraph=True) output = model(x,y)
父主题: torch_npu