Megatron重计算

背景与挑战

在大模型的训练流程中,传统的实践要求存储前向传播阶段产生的激活值,以供后续反向传播过程中的梯度计算使用。这一需求导致了激活值保存数量随模型深度线性增长的现象,显著加剧了对硬件内存资源的压力。

解决方案

为应对上述挑战,提出了重计算策略。具体而言,在前向传播与损失函数计算阶段,即时释放不再需要的激活值内存空间,仅在反向传播时根据需要重新计算激活值。此方法通过有效缩短激活值的生命周期,显著减轻了内存负担,提升了整体的资源利用效率。

使用场景

重计算分为以下两种方式:
  • 选择性重计算(推荐):专注于对Transformer架构内的core_attention组件进行重计算。该策略保留了那些占用较小内存空间但重计算成本较高的激活值,同时,对占用较大内存但重计算成本相对较低的激活值执行激活重计算。此方法在保证模型性能的同时,实现了内存使用的高效管理。
  • 完全重计算:适用于内存资源极为受限的极端环境。在这种模式下,除了保存输入数据外,所有激活值均在需要时重新计算,最大限度地减少了对内存的依赖。

使用方法

同时配置--recompute-activations和--recompute-granularity full生效选择性重计算。

使用效果

采用重计算策略后,系统能够实现以下显著改进: