文档
注册

Torch对接DeepSpeed框架

  1. 使用业务用户登录到计算节点。

    业务用户非{MindIO-install-user}、HwHiAiUser、hwMindX用户,由用户根据实际情况决定。

  2. 进入DeepSpeed安装目录。
    cd deepspeed安装目录/runtime
  3. 修改engine.py文件。
    1. 执行以下命令,打开engine.py文件。
      vim engine.py
    2. 按“i”进入编辑模式,修改如下内容。
      • 在文件首行加入以下内容。
        import torch_mindio
      • 将torch.load函数替换为torch_mindio.load函数。
        替换前:
        optim_checkpoint = torch.load(optim_load_path,
                                      map_location=torch.device('cpu'))

        替换后:

        optim_checkpoint = torch_mindio.load(optim_load_path)
      • 将普通的torch.save函数替换torch_mindio.save函数。
        替换前:
        torch.save(state, save_path)

        替换后:

        torch_mindio.save(state, save_path)
      • 将包含torch.save函数的with open语句替换为torch_mindio.save函数。
        替换前:
        with open(self._get_optimizer_ckpt_name(save_dir, tag, expp_rank), 'wb') as fd:
            torch.save(optimizer_state, fd)
            fd.flush()

        替换后:

        torch_mindio.save(optimizer_state, self._get_optimizer_ckpt_name(save_dir, tag, expp_rank))
      • 替换DeepSpeedEngine._get_expert_ckpt_name函数。

        替换前:

                        expert_state_dict = torch.load(DeepSpeedEngine._get_expert_ckpt_name(
                            checkpoint_path,
                            -1, # -1 means ignore layer_id
                            global_expert_id,
                            tag,
                            mpu),
                            map_location=torch.device('cpu'))

        替换后:

                        expert_state_dict = torch_mindio.load(DeepSpeedEngine._get_expert_ckpt_name(
                            checkpoint_path,
                            -1, # -1 means ignore layer_id
                            global_expert_id,
                            tag,
                            mpu))
    3. 按“Esc”键,输入:wq!,按“Enter”保存并退出编辑。
  4. 修改module.py文件。
    1. 执行以下命令打开module.py文件。
      vim pipe/module.py
    2. 替换torch.save和torch.load,替换方式参见3.b3.c
  5. 修改state_dict_factory.py文件。
    1. 执行以下命令打开state_dict_factory.py文件。
      vim state_dict_factory.py
    2. 替换torch.save和torch.load,替换方式参见3.b3.c
  6. 完成35修改.py文件,DeepSpeed即可使用MindIO服务。
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词