(beta)torch_npu.utils.FlopsCounter

定义文件

torch_npu\utils\flops_count.py

接口原型

torch_npu.utils.FlopsCounter()

功能描述

Flops统计类,用于统计各个常见cube类算子的浮点计算Flops,采用单例模式。当前支持可统计Flops的算子:MM、BMM、AllgatherMM、ReduceScatterMM、FA。

参数说明

以下参数说明为初始化时此类的参数说明,用户可通过成员函数对参数进行修改。

首次创建(初始化)参数状态和通过成员函数对参数进行修改后,参数的状态如图1所示。
图1 参数状态

成员函数

支持的型号

Atlas A2 训练系列产品

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch_npu
 
def matmul():
    x = torch.randn(3, 4).npu()
    y = torch.randn(4, 3).npu()
    torch.matmul(x,y)
 
FlopsCounter = torch_npu.utils.FlopsCounter()
 
# 1.开启统计后进行统计
FlopsCounter.start()
matmul() # 算子计算
print(f"FlopsCounter.start():{FlopsCounter.get_flops()}") # 打印统计结果,含重计算的Flops和不含重计算的Flops累计
 
# 2. 暂停Flops不含重计算统计后进行统计
FlopsCounter.pause()
matmul() # 这里视作重计算操作
print(f"FlopsCounter.pause():{FlopsCounter.get_flops()}") # 仅含重计算Flops累计
 
# 3. 恢复Flops不含重计算统计后进行统计
FlopsCounter.resume()
matmul()
print(f"FlopsCounter.resume():{FlopsCounter.get_flops()}") # 含重计算Flops和不含重计算Flops均累计
 
# 4.关闭Flops统计
FlopsCounter.stop()
matmul()
print(f"FlopsCounter.stop():{FlopsCounter.get_flops()}") # 含重计算Flops和不含重计算Flops清0且均不累计