RMSNorm融合优化

背景与挑战

在诸如LLaMA、LLaMA2和Baichuan等大型语言模型(LLM)中,RMSNorm(Root Mean Square Norm)作为一种归一化技术被广泛应用。然而,由于PyTorch框架本身未直接提供RMSNorm运算符接口,模型中常以自定义形式实现RMSNorm,这在一定程度上影响了执行效率。

解决方案

MindSpeed 针对上述情况,对RMSNorm操作进行了融合优化,将其集成到单个运算符中,有效减少了数据传输次数和临时存储需求。具体实施细节请参见《PyTorch 训练模型迁移调优指南》的“RmsNorm & RmsNormGrad”章节

使用场景

当模型选用RMSNorm作为其归一化方法,并且训练脚本中已包含以下配置时:

--normalization RMSNorm

使用方法

启用RMSNorm融合优化,需在训练脚本中加入以下参数配置:

--use-fused-rmsnorm

使用效果

启用融合算子后,不仅能够有效节省内存资源,还能提升模型训练和性能。