class DropoutWithByteMask()

API接口

class DropoutWithByteMask(Module):

功能描述

应用NPU兼容的DropoutWithByteMask操作。

参数说明

输出说明

Output (Tensor) - 输出张量与输入张量的shape相同。

约束说明

示例

调用方式示例:
from torch_npu.contrib.module.npu_modules import DropoutWithByteMask
m = DropoutWithByteMask(p=0.5)
使用示例:
  >>> m = DropoutWithByteMask(p=0.5)
  >>> input = torch.randn(16, 16)
  >>> output = m(input)