conv3dx

功能说明

在给定6HD格式的Data和FracZ格式的Weight的情况下计算float16的3-D反卷积。

接口可以支持bias。

Data tensor 的shape是6HD,即(N, D, C1, H, W, C0);Weight Tensor 的shape是 FracZ,即 (KD*C1*KH*KW, Cout//C0_out, C0_out, C0)。

您可以在“te/lang/cce/te_compute/conv3d_backprop_input_compute.py”查看接口定义。

约束说明

此接口暂不支持与其他TBE DSL计算接口混合使用。

支持的数据类型:float16。

函数原型

te.lang.cce.conv3d_dx(filter, out_backprop, filter_size, input_size, para_dict)

参数说明

返回值

res_tensor:表示卷积计算的tensor,即卷积计算的结果输出。

支持的芯片型号

Atlas 200/300/500 推理产品

Atlas 训练系列产品

调用示例

import te
from tbe import tvm

shape_dedy = (1, 2, 16, 15, 22, 16)
out_backprop_dtype = "float16"
input_sizes = [1, 4, 30, 44, 128]
shape_filter_ncdhw = [256, 128, 2, 2, 2]
shape_filter_frac = (64, 16, 16, 16)
filter_dtype = "float16"

dedy = tvm.placeholder(shape_dedy, name="dedy",
                       dtype=out_backprop_dtype)

filters = tvm.placeholder(shape_filter_frac,
                          name="filter", dtype=filter_dtype)

strides = [1, 2, 2, 2, 1]
pads = [0, 0, 0, 0, 0, 0]
dilations = (1, 1, 1, 1, 1)
res_dtype = "float16"
kernel_name = "conv3d_backprop_input_w_2_2_2_128_256_y_1_2_15_22_256_x_1_4_30_44_128_s_1_2_2_2_1_SAME_d_1_1_g_1"
group_dict = {'real_g': 1, 'mag_factor': 1, 'cin1_g': 8, 'cout_g': 256, 'cin_ori': 128, 'cout_ori': 256}

para_dict = {
    "strides": strides,
    "pads": pads,
    "dilations": dilations,
    "res_dtype": res_dtype,
    "kernel_name": kernel_name,
    "group_dict": group_dict
}

dedx = te.lang.cce.conv3d_dx(
    filter=filters,
    out_backprop=dedy,
    filter_size=shape_filter_ncdhw,
    input_size=input_sizes,
    para_dict=para_dict
)