class NpuCachedDropout(torch.nn.Dropout):
在NPU设备上使用FairseqDropout。
>>> from torch_npu.contrib.module import NpuCachedDropout
>>> model = NpuMNIST().to("npu")
>>> x = torch.randn(2, 10, 16, 16).to("npu")
>>> NpuCachedDropout.enable_dropout_ensemble(model)
>>> output = model(x)