class ModulatedDeformConv()

API接口

class ModulatedDeformConv(nn.Module):

功能描述

应用基于NPU的Modulated Deformable 2D卷积操作。

参数说明

约束说明

ModedDeformConv仅实现float32数据类型的操作。conv_offset中权重和偏置必须初始化为0。

示例

调用方式示例:
from torch_npu.contrib.module import ModulatedDeformConv
m = ModulatedDeformConv(32, 32, 1)
使用示例:
   >>> m = ModulatedDeformConv(32, 32, 1)
   >>> input_tensor = torch.randn(2, 32, 5, 5)
   >>> output = m(input_tensor)

   >>> x = torch.randn(2, 32, 7, 7) 
   >>> model = ModulatedDeformConv(32, 32, 3, 2, 1)

   >>> torch.npu.set_device(0)
   >>> x = x.npu()
   >>> model = model.npu()

   >>> o = model(x)
   >>> l = o.sum()
   >>> l.backward()
   >>> print(l)