重计算指在非梯度模式下执行目标操作的前向传播,只保留输入和结构参数,其他中间激活都被删除,反向计算时在梯度模式下重新计算这些激活,重建这部分反向图,进而实现梯度的正常回传。
与之前的ZeRO1、ZeRO2、ZeRO3实在DP组中做显存优化不同,ZeRO-R是在TP组中特别针对activation做显存优化,激活值主要存在于剩余状态中,因为对activation的存储是灵活的,其只是起到加速梯度计算的作用。设计时候也可以采用与上述一样的切割方式,每块计算资源上只是维护部分的activation,需要时再从其他地方聚合过来。
训练脚本参数配置示例,加粗部分为流水并行功能配置参数示例:
python -m torch.distributed.launch $DISTRIBUTED_ARGS \ pretrain_llama.py \ --DDP-impl local \ --num-layers 30 \ --hidden-size 6656 \ --ffn-hidden-size 17920 \ --num-attention-heads 52 \ --micro-batch-size 1 \ --global-batch-size 32 \ --seq-length 2048 \ --deepspeed-activation-checkpointing --partition-activations --contigious-checkpointing --checkpoint-in-cpu --synchronize-each-layer --profile-backward
参数说明:
重计算功能的定义位于“ascendspeed/initialize.py”文件中的“setup_deepspeed_random_and_activation_checkpointing”函数。使能开关位于ascendspeed/arguments.py中的_add_activation_checkpoint_args函数。
num_layers = args.num_layers // args.checkpoint_num_layers num_layers = num_layers if args.num_layers % args.checkpoint_num_layers == 0 else num_layers + 1 if args.split_transformers: num_layers *= 2 deepspeed.checkpointing.configure( mpu, partition_activations=args.partition_activations, contiguous_checkpointing=args.contigious_checkpointing, num_checkpoints=num_layers, checkpoint_in_cpu=args.checkpoint_in_cpu, synchronize=args.synchronize_each_layer, profile=args.profile_backward)
可选DeepSpeed激活检查点功能。允许访问分区激活、连续内存优化和CPU检查点。激活检查点需要跟踪随机状态并为每个MP进程设置随机种子。
# 设置DeepSpeed ZeRO-R,对activation进行优化 if args.deepspeed and args.deepspeed_activation_checkpointing: setup_deepspeed_random_and_activation_checkpointing(args)