ZeroOptimizer零冗余优化器

概述

与张量并行相似的张量分片功能,但是整个张量会及时重建以进行前向或反向的计算,优化器通过在数据并行进程间区分模型的状态(优化器、梯度、参数),而不是复制模型的状态,从而消除了内存冗余,使得集群的完整聚合内存可用。

整个过程主要分为三阶段,优化器状态分区(os)、添加梯度分区(os+g)、添加参数分区(os+g+p)。

具体的操作过程如下:

① 从optimizer states开始优化,将optimizer states分为若干部分,每张卡上维护一份:

整体的数据并行流程如下:

(1)每张卡上存一份完整的参数W。将一个batch的数据分成3份,每张卡各吃一份,做完一轮foward和backward后,各得一份梯度;

(2)对梯度做一次AllReduce,得到完整的梯度G;

(3)得到完整梯度G,就可以对W做更新。W的更新由optimizer states和梯度共同决定。由于每块GPU上只保管部分optimizer states,因此只能将相应的W(蓝色部分)进行更新。

(4)此时,每张卡上都有部分W没有完成更新(图中白色部分)。所以需要对W做一次All-Gather,从别的卡上把更新好的部分W取回来。

② 进一步,将梯度也拆分,每张卡上维护一部分梯度:

此时的数据并行的整体流程如下:

(1)每张卡上存一份完整的参数W。将一个batch的数据分成3份,每块GPU各吃一份,做完一轮foward和backward后,算得一份完整的梯度(下图中绿色+白色)。

(2)对梯度做一次Reduce-Scatter,保证每个GPU上所维持的那块梯度是聚合梯度。例如对卡1,它负责维护G1,因此其他的卡只需要把G1对应位置的梯度发给GPU1做加总就可。汇总完毕后,白色块对GPU无用,可以从显存中移除。见下图:

(3)每张卡用自己对应的O和G去更新相应的W。更新完毕后,每块GPU维持了一块更新完毕的W。同理,对W做一次All-Gather,将其他卡算好的W同步到自己这来。

③ 最后将参数也进行拆分,使得每张卡上仅保留一份优化器状态、一份梯度信息、一份参数:

此时整体的数据并行流程如下:

(1)每张卡上只保留部分的参数W,将一个batch的数据分成3份,每张卡各吃一份;

(2)forward时候,对W做一次All-Gather,取回分布在其他卡上的参数信息W,从而得到一份完整的参数,forward做完立即将不属于自己的那部分参数抛弃;

(3)backward时候,对W做一次All-Gather,取回完整的W,做完后立即将不属于自己的那部分参数抛弃;

(4)做完backward之后,需要进行一次Reduce-Scatter,将其他卡上自己对应位置的那部分梯度进行聚合,后将不属于自己维护的那部分梯度G抛弃;

(5)用自己维护的那部分O与G完成自己所持有的那部W的更新,无需再做AllReduce的操作。

功能使用

在脚本中使用时,参数的使能开关如下:
python -m torch.distributed.run $DISTRIBUTED_ARGS \
    pretrain_bloom.py \
    --num-layers $NLAYERS \
    --hidden-size $NHIDDEN \
    --num-attention-heads $NHEADS \
    --seq-length $SEQ_LEN \
    --max-position-embeddings $SEQ_LEN \
    --micro-batch-size $MICRO_BATCH_SIZE \
    --data-impl mmap \
    --deepspeed \
    --deepspeed_config ${config_json} \
    --zero-stage ${ZERO_STAGE} \
    --deepspeed-activation-checkpointing  \
    --distributed-backend nccl \
    --sequence-parallel
在ascendspeed/training.py文件下的training_log函数中,定义zero_stage的使用:
if args.zero_stage > 0:
    # ZeRO partitions optimizer states
    opt_stats = get_accelerator().FloatTensor(opt_stats)
    torch.distributed.all_reduce(opt_stats, group=parallel_state.get_data_parallel_group())
    opt_stats_2 = get_accelerator().FloatTensor(opt_stats_2)
    torch.distributed.all_reduce(opt_stats_2, op=torch.distributed.ReduceOp.MAX,
                                 group=parallel_state.get_data_parallel_group())