RMSNorm融合优化

问题分析

在诸如 LLaMA、LLaMA2 和 Baichuan 等大型语言模型(LLM)中,RMSNorm(Root Mean Square Norm)作为一种归一化技术被广泛应用。然而,由于 PyTorch 框架本身未直接提供RMSNorm运算符接口,模型中常以自定义形式实现 RMSNorm,这在一定程度上影响了执行效率。

解决方法

MindSpeed 针对上述情况,对 RMSNorm 操作进行了融合优化,将其集成到单个运算符中,有效减少了数据传输次数和临时存储需求。具体实施细节可参考《PyTorch 训练模型迁移调优指南》的“RmsNorm & RmsNormGrad”章节

使用场景

当模型选用 RMSNorm 作为其归一化方法,并且训练脚本中已包含以下配置时:

--normalization RMSNorm

使用方法

启用RMSNorm融合优化,需在训练脚本中加入以下参数配置:

--use-fused-rmsnorm

使用效果

启用融合算子后,不仅能够有效节省内存资源,还能提升模型训练和性能。