class roll()

API接口

class roll():

功能描述

使用NPU亲和写法替换swin-transformer中的原生roll。

参数说明

输出说明

Tensor - shifted input。

示例

调用方式示例:
from torch_npu.contrib.function import roll
shifted_x_npu = roll(input1, shifts=(-shift_size, -shift_size), dims=(1, 2))
使用示例:
       
        >>> input1 = torch.randn(32, 56, 56, 16).npu()
        >>> shift_size = 3
        >>> shifted_x_npu = roll(input1, shifts=(-shift_size, -shift_size), dims=(1, 2))