torch_npu\utils\flops_count.py
torch_npu.utils.FlopsCounter()
Flops统计类,用于统计各个常见cube类算子的浮点计算Flops,采用单例模式。当前支持可统计Flops的算子:MM、BMM、AllgatherMM、ReduceScatterMM、FA。
以下参数说明为初始化时此类的参数说明,用户可通过成员函数对参数进行修改。
开启Flops统计。FlopsCounter.start()设置开关标志位(isEnabled_)为true,进行Flops计算,统计含重计算的Flops。
关闭Flops统计。FlopsCounter.stop()设置开关标志位(isEnabled_)和暂停标志位(isPaused _)为false,不进行Flops计算,含重计算的Flops(traversedCount)和不含重计算的Flops(recordedCount)均不统计。且重置含重计算的Flops和不含重计算的Flops为0。
暂停Flops不含重计算的统计。FlopsCounter.pause()设置暂停标志位(isPaused _)为true,不含重计算的Flops(recordedCount)将不会被统计。
恢复Flops不含重计算的统计。设置暂停标志位(isPaused _)为false。暂停标志位(isPaused _)为false且开关标志位(isEnabled_)为true时,不含重计算的Flops(recordedCount)将会被统计。
获取统计结果。返回列表,包括不含重计算的Flops(recordedCount)和含重计算的Flops(traversedCount),例如[100, 200],100为不含重计算的Flops(recordedCount),200为含重计算的Flops(traversedCount)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 | import torch import torch_npu def matmul(): x = torch.randn(3, 4).npu() y = torch.randn(4, 3).npu() torch.matmul(x,y) FlopsCounter = torch_npu.utils.FlopsCounter() # 1.开启统计后进行统计 FlopsCounter.start() matmul() # 算子计算 print(f"FlopsCounter.start():{FlopsCounter.get_flops()}") # 打印统计结果,含重计算的Flops和不含重计算的Flops累计 # 2. 暂停Flops不含重计算统计后进行统计 FlopsCounter.pause() matmul() # 这里视作重计算操作 print(f"FlopsCounter.pause():{FlopsCounter.get_flops()}") # 仅含重计算Flops累计 # 3. 恢复Flops不含重计算统计后进行统计 FlopsCounter.resume() matmul() print(f"FlopsCounter.resume():{FlopsCounter.get_flops()}") # 含重计算Flops和不含重计算Flops均累计 # 4.关闭Flops统计 FlopsCounter.stop() matmul() print(f"FlopsCounter.stop():{FlopsCounter.get_flops()}") # 含重计算Flops和不含重计算Flops清0且均不累计 |