torch_npu.contrib.module.LabelSmoothingCrossEntropy(nn.Module):
Computes label-smoothing cross entropy using the NPU API.
Float - tensors of shape (k, 5) and (k, 1). Labels are zero-based.
>>> import torch
>>> import torch_npu
>>> from torch_npu.contrib.module import LabelSmoothingCrossEntropy
>>> x = torch.randn(2, 10).npu()
>>> y = torch.randint(0, 10, size=(2,)).npu()
>>> x.requires_grad = True
>>> m = LabelSmoothingCrossEntropy(10)
>>> npu_output = m(x, y)
>>> npu_output.backward()
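To make the computation concrete without NPU hardware, the following is a minimal pure-Python sketch of one common label-smoothing formulation. It is not the torch_npu implementation. The function name `label_smoothing_cross_entropy`, the `smooth_factor` parameter, and the choice to spread the smoothed mass evenly over the `num_classes - 1` non-target classes are all assumptions made for illustration; the NPU module may differ in its defaults or in how it distributes the smoothing.

```python
import math


def label_smoothing_cross_entropy(logits, targets, smooth_factor=0.1):
    """Mean label-smoothed cross entropy over a batch (illustrative sketch).

    logits:  list of rows, each a list of num_classes floats.
    targets: list of zero-based class indices, one per row.
    Assumption: the target class gets weight 1 - smooth_factor and every
    other class gets smooth_factor / (num_classes - 1).
    """
    num_classes = len(logits[0])
    on_value = 1.0 - smooth_factor
    off_value = smooth_factor / (num_classes - 1)

    total = 0.0
    for row, target in zip(logits, targets):
        # Numerically stable log-softmax: subtract the row maximum first.
        row_max = max(row)
        log_z = row_max + math.log(sum(math.exp(v - row_max) for v in row))
        log_probs = [v - log_z for v in row]

        # Cross entropy against the smoothed target distribution.
        total -= sum(
            (on_value if j == target else off_value) * lp
            for j, lp in enumerate(log_probs)
        )
    return total / len(logits)


if __name__ == "__main__":
    batch_logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
    batch_targets = [0, 2]  # zero-based labels
    print(label_smoothing_cross_entropy(batch_logits, batch_targets, 0.1))
```

With `smooth_factor=0.0` this reduces to ordinary cross entropy, which gives a simple sanity check when comparing against a framework implementation.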