Example launch-script arguments:
    deepspeed pretrain_llama.py \
        --DDP-impl local \
        --tensor-model-parallel-size 1 \
        --pipeline-model-parallel-size 1 \
        --num-layers 40 \
        --hidden-size 5120 \
        --ffn-hidden-size 13824 \
        --num-attention-heads 40 \
        --micro-batch-size 2 \
        --global-batch-size 16 \
        --seq-length 2048 \
        --zero-stage 2 \
        --lora-target-modules gate_proj up_proj down_proj \
        --fp16 | tee logs/train.log
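The batch-size flags above are tied together by the usual Megatron-style rule: the global batch is split across data-parallel ranks, and each rank runs micro-batches until its share is consumed. The sketch below works through that arithmetic. The function name and the world size of 8 devices are assumptions made for illustration; the command above does not state how many devices are used.

```python
def num_micro_batches(global_batch_size, micro_batch_size, world_size,
                      tensor_parallel=1, pipeline_parallel=1):
    """Return how many micro-batches each data-parallel rank runs per step.

    Hypothetical helper: it mirrors the usual Megatron-style divisibility
    rule and is not a function taken from AscendSpeed.
    """
    model_parallel = tensor_parallel * pipeline_parallel
    if world_size % model_parallel != 0:
        raise ValueError("world_size must be divisible by TP * PP")
    data_parallel = world_size // model_parallel

    # Samples consumed by one micro-batch across all data-parallel ranks.
    samples_per_micro_step = micro_batch_size * data_parallel
    if global_batch_size % samples_per_micro_step != 0:
        raise ValueError(
            "global_batch_size must be divisible by "
            "micro_batch_size * data_parallel_size"
        )
    return global_batch_size // samples_per_micro_step


# Values from the command above; world_size=8 is an assumed device count.
steps = num_micro_batches(global_batch_size=16, micro_batch_size=2, world_size=8)
print(steps)
```

With 8 devices and no tensor or pipeline parallelism, each rank processes one micro-batch of 2 samples per step. With 4 devices the same flags would need two micro-batches per step.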
Parameter notes:
    if isinstance(target, (ColumnParallelLinear, RowParallelLinear)):
        if kwargs["fan_in_fan_out"]:
            warnings.warn(
                "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                "Setting fan_in_fan_out to False."
            )
            kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
        new_module = LoraParallelLinear(adapter_name=adapter_name, parallel_linear=target, **kwargs)
    # LoRA weights must be forced up to 32-bit precision, otherwise they overflow.
    if self.is_paralle_a:
        lora_a = RowParallelLinear(input_size=self.in_features, output_size=r, bias=False,
                                   input_is_parallel=True, skip_bias_add=True,
                                   dtype=torch.float32)
        lora_b = nn.Linear(in_features=r, out_features=self.out_features, bias=False,
                           dtype=torch.float32)
    else:
        lora_a = nn.Linear(in_features=self.in_features, out_features=r, bias=False,
                           dtype=torch.float32)
        lora_b = ColumnParallelLinear(input_size=r, output_size=self.out_features, bias=False,
                                      gather_output=False, dtype=torch.float32)
    self.lora_A.update(nn.ModuleDict({adapter_name: lora_a}))
    self.lora_B.update(nn.ModuleDict({adapter_name: lora_b}))
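Since lora_A maps in_features to r and lora_B maps r to out_features, each adapted layer adds r * (in_features + out_features) trainable weights. The sketch below estimates the total for the three target modules in the launch command, using its hidden size, FFN size, and layer count. The rank r = 16 and the helper name are assumptions for illustration only; the rank is not given in this document.

```python
def lora_param_count(in_features, out_features, r):
    """Weights added by one LoRA adapter: A is (r x in), B is (out x r)."""
    return r * in_features + out_features * r


HIDDEN = 5120   # --hidden-size
FFN = 13824     # --ffn-hidden-size
LAYERS = 40     # --num-layers
RANK = 16       # assumed rank, not stated in the source

# (in_features, out_features) of each target module in one LLaMA MLP block.
targets = {
    "gate_proj": (HIDDEN, FFN),
    "up_proj": (HIDDEN, FFN),
    "down_proj": (FFN, HIDDEN),
}

per_layer = sum(lora_param_count(i, o, RANK) for i, o in targets.values())
total = per_layer * LAYERS
print(f"LoRA parameters per layer: {per_layer:,}")
print(f"LoRA parameters in total:  {total:,}")
```

The count ignores tensor-parallel sharding: with a tensor-parallel size above 1, the parallel half of each adapter (RowParallelLinear or ColumnParallelLinear) would be split across ranks.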
LoRA fine-tuning is implemented by "lora_modules.py" and "lora_utils.py" in the ascendspeed/model folder. When a model is processed with LoRA, the configuration function that gets called is "handle_mode_with_lora()".
    lora_a = self.lora_A[self.active_adapter]
    lora_b = self.lora_B[self.active_adapter]
    lora_dropout = self.lora_dropout[self.active_adapter]
    scaling = self.scaling[self.active_adapter]

    lora_result = lora_a(lora_dropout(x))
    if isinstance(lora_result, tuple):
        lora_result = lora_result[0]
    lora_result = lora_b(lora_result)
    if isinstance(lora_result, tuple):
        lora_result = lora_result[0]
    lora_result = lora_result * scaling

    result += lora_result
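The forward pass above adds a scaled low-rank correction to the base layer output, and it unwraps tuples because Megatron-style parallel linear layers return an (output, bias) pair rather than a bare tensor. The following is a minimal pure-Python sketch of that logic. All class and function names are hypothetical stand-ins, not the AscendSpeed classes. Dropout is omitted, and scaling = lora_alpha / r follows the common LoRA convention, which is assumed here rather than taken from the source.

```python
def matvec(weight, x):
    """Multiply a matrix (list of rows) by a vector (list of numbers)."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


class PlainLinear:
    """Stand-in for nn.Linear: returns the output directly."""

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, x):
        return matvec(self.weight, x)


class ParallelLikeLinear(PlainLinear):
    """Stand-in for a parallel linear layer: returns (output, bias)."""

    def __call__(self, x):
        return matvec(self.weight, x), None


def unwrap(out):
    """Drop the bias slot when a layer returned an (output, bias) tuple."""
    return out[0] if isinstance(out, tuple) else out


def lora_forward(base, lora_a, lora_b, x, lora_alpha, r):
    """Return base(x) + (lora_alpha / r) * lora_b(lora_a(x))."""
    scaling = lora_alpha / r  # assumed convention, see lead-in
    result = unwrap(base(x))
    delta = unwrap(lora_b(unwrap(lora_a(x))))
    return [y + scaling * d for y, d in zip(result, delta)]


# Tiny example: 3 inputs, 2 outputs, rank 1.
base = ParallelLikeLinear([[1, 0, 0], [0, 1, 0]])
lora_a = ParallelLikeLinear([[1, 1, 1]])   # parallel half, returns a tuple
lora_b = PlainLinear([[0.5], [0.0]])       # plain half, returns a list
print(lora_forward(base, lora_a, lora_b, [1, 2, 3], lora_alpha=2, r=1))
```

If lora_b is all zeros, which is the usual LoRA initialisation, the correction vanishes and the function returns the base output unchanged.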
Fine-tuning results:
Measured on LLaMA Alpaca 13B with DeepSpeed and Megatron respectively, the results are:
megatron: 1000 steps, mean deviation 0.19%
deepspeed: 1000 steps, mean deviation 0.25%