Checkpoint-Activations重计算

概述

重计算指在非梯度模式下执行目标操作的前向传播,只保留输入和结构参数,其他中间激活都被删除,反向计算时在梯度模式下重新计算这些激活,重建这部分反向图,进而实现梯度的正常回传。

与之前的ZeRO1、ZeRO2、ZeRO3实在DP组中做显存优化不同,ZeRO-R是在TP组中特别针对activation做显存优化,激活值主要存在于剩余状态中,因为对activation的存储是灵活的,其只是起到加速梯度计算的作用。设计时候也可以采用与上述一样的切割方式,每块计算资源上只是维护部分的activation,需要时再从其他地方聚合过来。

功能使用

训练脚本参数配置示例,加粗部分为流水并行功能配置参数示例:

python -m torch.distributed.launch $DISTRIBUTED_ARGS \
       pretrain_llama.py \
       --DDP-impl local \
       --num-layers 30 \
       --hidden-size 6656 \
       --ffn-hidden-size 17920 \
       --num-attention-heads 52 \
       --micro-batch-size 1 \
       --global-batch-size 32 \
       --seq-length 2048 \
       --deepspeed-activation-checkpointing
       --partition-activations
       --contigious-checkpointing
       --checkpoint-in-cpu
       --synchronize-each-layer
       --profile-backward

参数说明:

重计算功能的定义位于“ascendspeed/initialize.py”文件中的“setup_deepspeed_random_and_activation_checkpointing”函数。使能开关位于ascendspeed/arguments.py中的_add_activation_checkpoint_args函数。

通过deepspeed库来进行调用,代码示例:
num_layers = args.num_layers // args.checkpoint_num_layers
num_layers = num_layers if args.num_layers % args.checkpoint_num_layers == 0 else num_layers + 1
if args.split_transformers:
        num_layers *= 2

deepspeed.checkpointing.configure(
    mpu,
    partition_activations=args.partition_activations,
    contiguous_checkpointing=args.contigious_checkpointing,
    num_checkpoints=num_layers,
    checkpoint_in_cpu=args.checkpoint_in_cpu,
    synchronize=args.synchronize_each_layer,
    profile=args.profile_backward)

可选DeepSpeed激活检查点功能。允许访问分区激活、连续内存优化和CPU检查点。激活检查点需要跟踪随机状态并为每个MP进程设置随机种子。

在ascendspeed/initialize.py中的_initialize_distributed函数完成调用:
# 设置DeepSpeed ZeRO-R,对activation进行优化
if args.deepspeed and args.deepspeed_activation_checkpointing:
    setup_deepspeed_random_and_activation_checkpointing(args)