升级PyTorch框架至3.0.0版本后类transformer类模型的性能下降
2023/06/06
125
问题信息
问题来源 | 产品大类 | 关键字 |
---|---|---|
官方 | 模型训练 | -- |
问题现象描述
使用3.0.0版本的PyTorch框架训练时,类transformer类模型出现性能下降问题。
解决措施
修改模型训练脚本,在import torch_npu后添加以下代码。
def main(config): option = {} option["MM_BMM_ND_ENABLE"] = "disable" torch.npu.set_option(option)
本页内容