conv3d_backprop_input
功能说明
在给定6HD格式的Data和FracZ格式的Weight的情况下计算float16的3-D反卷积。
接口可以支持bias。
Data tensor 的shape是6HD,即(N, D, C1, H, W, C0);Weight Tensor 的shape是 FracZ,即 (KD*C1*KH*KW, Cout//C0_out, C0_out, C0)。
函数原型
conv3d_backprop_input(filter, out_backprop, filter_size, input_size, para_dict)
参数说明
- filter:3d卷积的Weight,tensor,FracZ格式,目前支持float16类型
- out_backprop:3d卷积的输出反向,目前支持float16类型
- filter_size:3d卷积的Weight
- input_size:3d卷积的FeatureMap
- para_dict:字典格式,包含各种参数,后续参数扩展一般都在para_dict
其中目前在para_dict里要传递的参数:
- strides:list类型,3d卷积在FeatureMap DHW方向的移动步长。
- pads:list类型,3d卷积在FeatureMap DHW方向的补边。
- dilations:list类型,3d卷积Filter DHW方向的膨胀系数。
- res_dtype:输出的数据类型。
- kernel_name:算子名称。
- group_dict:字典格式,包含支持3dx group卷积相关的参数,通过下列计算过程获得。
参数2:cout,Weight的batch维度大小。
参数3:groups,group卷积参数。
参数4:cout0,为tbe_platform.C0_SIZE,默认值为16。
参数5:cin0,为tbe_platform.C0_SIZE,默认值为16。
具体计算公式:
lcm(param1, param2),计算最小公倍数。
mag_factor0 = lcm(fmap_c // groups, cin0) // (fmap_c // groups)
mag_factor1 = lcm(cout // groups, cout0) // (cout // groups)
mag_factor = min(lcm(mag_factor0, mag_factor1), groups)
cin1_g = (mag_factor * fmap_c // groups + cin0 - 1) // cin0
cout_g = (mag_factor * cout // groups + cout0 - 1) // cout0 * cout0
group_dict = {"real_g": (groups + mag_factor - 1) // mag_factor,
"mag_factor": mag_factor,
"cin1_g": cin1_g,
"cout_g": cout_g,
"cin_ori": fmap_c,
"cout_ori": cout}
返回值
res_tensor:表示卷积计算的tensor,即卷积计算的结果输出。
约束说明
此接口暂不支持与其他TBE DSL计算接口混合使用。
支持的型号
Atlas 200/300/500 推理产品
Atlas 训练系列产品
Atlas 推理系列产品
Atlas 200/500 A2推理产品
Atlas A2训练系列产品/Atlas 800I A2推理产品
调用示例
from tbe import tvm from tbe import dsl shape_dedy = (1, 2, 16, 15, 22, 16) out_backprop_dtype = "float16" input_sizes = [1, 4, 30, 44, 128] shape_filter_ncdhw = [256, 128, 2, 2, 2] shape_filter_frac = (64, 16, 16, 16) filter_dtype = "float16" dedy = tvm.placeholder(shape_dedy, name="dedy", dtype=out_backprop_dtype) filters = tvm.placeholder(shape_filter_frac, name="filter", dtype=filter_dtype) strides = [1, 2, 2, 2, 1] pads = [0, 0, 0, 0, 0, 0] dilations = (1, 1, 1, 1, 1) res_dtype = "float16" kernel_name = "conv3d_backprop_input_w_2_2_2_128_256_y_1_2_15_22_256_x_1_4_30_44_128_s_1_2_2_2_1_SAME_d_1_1_g_1" group_dict = {'real_g': 1, 'mag_factor': 1, 'cin1_g': 8, 'cout_g': 256, 'cin_ori': 128, 'cout_ori': 256} para_dict = { "strides": strides, "pads": pads, "dilations": dilations, "res_dtype": res_dtype, "kernel_name": kernel_name, "group_dict": group_dict } dedx = dsl.conv3d_backprop_input( filter=filters, out_backprop=dedy, filter_size=shape_filter_ncdhw, input_size=input_sizes, para_dict=para_dict )