cache_compile
函数功能
开启模型缓存编译功能时需要调用该接口实现模型缓存编译。
函数原型
def cache_compile(func, *, config: Optional[CompilerConfig] = None, dynamic: bool = True, cache_dir: Optional[str] = None, global_rank: Optional[int] = None, tp_rank: Optional[int] = None, pp_rank: Optional[int] = None, ge_cache: bool = False, **kwargs) -> Callable
参数说明
参数 |
输入/输出 |
说明 |
是否必选 |
||
---|---|---|---|---|---|
func |
输入 |
缓存编译的函数,只支持torch.nn.module常见method,例如prompt、decode、warmup等。 |
是 |
||
config |
输入 |
图编译配置,CompilerConfig类的实例化,缺省情况下采用TorchAir自动生成的配置。 说明:
本场景下不支持同时配置dynamo导图功能、使能RefData类型转换功能。 |
否 |
||
dynamic |
输入 |
是否按照输入动态trace。数据类型支持bool,不推荐输入非布尔型数据。 该参数继承了PyTorch原有特性,详细介绍请参考LINK。 默认True,进行动态trace。 |
否 |
||
cache_dir |
输入 |
缓存文件落盘的根目录。默认在“.torchair_cache”目录下。 |
否 |
||
global_rank |
输入 |
分布式训练时的rank,INT类型。取值范围为[0, world_size-1],其中world_size是参与分布式训练的总进程数。 一般情况下TorchAir会自动通过torch.distributed.get_rank()获取缺省值。 |
否 |
||
tp_rank |
输入 |
指张量模型并行rank,INT类型,取值是global_rank中划分为TP域的rank id。 |
否 |
||
pp_rank |
输入 |
指流水线并行rank,INT类型,取值是global_rank中划分为PP域的rank id。 |
否 |
||
custom_decompositions |
输入 |
手动指定模型运行时用到的decomposition(将较大算子操作分解为较简单或核心算子)。用户根据自身情况配置该参数,以Add算子为例调用代码如下:
|
否 |
||
ge_cache |
输入 |
是否缓存GE图编译结果。数据类型支持bool,不推荐输入非布尔型数据。
说明:
|
否 |
||
* |
输入 |
预留参数项,用于后续功能拓展。 |
否 |
返回值说明
返回一个Callable对象。
约束说明
- 缓存func函数只能被触发一次dynamo trace,换言之如果func在过程中发生重编译,则会放弃缓存。
- 对于发生多次trace(guard失效)的函数,需要进行一次函数封装来使缓存生效。
- func必须是method,即必须是module实例对象的方法,且该方法未被其他装饰器修饰。
- func必须能成整图,即必须支持full graph。
- 只支持推理模式,不支持带反向计算过程的func缓存。
调用示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 | import torch import torch_npu import torchair class Model(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(2, 1) self.linear2 = torch.nn.Linear(2, 1) for param in self.parameters(): torch.nn.init.ones_(param) # 实现缓存编译,同时开启图单流执行功能 config = torchair.CompilerConfig() config.ge_config.enable_single_stream = True self.cached_decode = torchair.inference.cache_compile(self.decode, config=config) def forward(self, x: InputMeta, kv: List[torch.Tensor]): return self.cached_decode(x, kv) def _forward(self, x, kv): return self.linear2(x.data) + self.linear2(kv[0]) # 重新封装为decode函数 def decode(self, x, y): return self._forward(x, y) |