beta)torch_npu.contrib.function.roll

接口原型

torch_npu.contrib.function.roll(input1, shifts, dims):

功能描述

使用NPU亲和写法替换swin-transformer中的原生roll。

参数说明

输出说明

Tensor - shifted input。

支持的型号

调用示例

1
2
3
4
>>> from torch_npu.contrib.function import roll
>>> input1 = torch.randn(32, 56, 56, 16).npu()
>>> shift_size = 3
>>> shifted_x_npu = roll(input1, shifts=(-shift_size, -shift_size), dims=(1, 2))