下载
中文
注册

matmul

功能说明

矩阵乘,计算:tensor_c=trans_a(tensor_a) * trans_b(tensor_b) + tensor_bias。

tensor_a与tensor_b的shape后两维(经过对应转置)需要满足矩阵乘(M, K) * (K, N) = (M, N)。

函数原型

matmul(tensor_a, tensor_b, trans_a=False, trans_b=False, format_a="ND", format_b="ND", alpha_num=1.0, beta_num=1.0, dst_dtype="float16", tensor_bias=None, quantize_params=None, format_out=None, compress_index=None, attrs={ }, kernel_name="Matmul")

参数说明

  • tensor_a:A矩阵,tvm.tensor类型。
  • tensor_b:B矩阵,tvm.tensor类型。

    两个tensor支持的数据类型如下:

    Atlas 200/300/500 推理产品 支持的数据类型:float16, float32, int32

    Atlas 训练系列产品 支持的数据类型:float16, float32, int32

    Atlas 推理系列产品 支持的数据类型:float16, float32, int32

    Atlas 200/500 A2推理产品 支持的数据类型:float16, float32, bfloat16

    Atlas A2训练系列产品/Atlas 800I A2推理产品 支持的数据类型:float16, float32, bfloat16

  • trans_a:A矩阵是否转置,bool类型。
  • trans_b:B矩阵是否转置,bool类型。
  • format_a:输入矩阵A的数据排布格式,取值为:“ND”“FRACTAL_NZ”“FRACTAL_Z”,默认值为“ND”。
  • format_b:输入矩阵B的数据排布格式,取值为:“ND”“FRACTAL_NZ”“FRACTAL_Z”,默认值为“ND”。
  • alpha_num:扩展参数,当前暂不使用,默认值1.0。
  • beta_num:扩展参数,当前暂不使用,默认值1.0。
  • dst_dtype:输出数据类型,支持:float16与float32。
  • tensor_bias:默认值为None,如果取值不为空,矩阵A和矩阵B相乘后的计算结果加上tensor_bias。tensor_bias的shape支持broadcast,其数据类型要和dst_dtype保持一致。
  • quantize_params:量化相关参数。此参数在后续版本会过期废弃,新开发算子请不要使用此参数。
  • format_out:输出tensor的format,取值为:“ND”“FRACTAL_NZ”“FRACTAL_Z”
  • compress_index:压缩权重矩阵的索引
  • attrs:扩展参数字典
  • kernel_name:算子在内核中的名称(即生成的二进制文件与算子描述文件的名称)。

返回值

tensor_c:根据关系运算计算后得到的tensor,tvm.tensor类型。

约束说明

此接口暂不支持与其他TBE DSL计算接口混合使用。

支持的型号

Atlas 200/300/500 推理产品

Atlas 训练系列产品

Atlas 推理系列产品

Atlas 200/500 A2推理产品

Atlas A2训练系列产品/Atlas 800I A2推理产品

调用示例

from tbe import tvm
from tbe import dsl
a_shape = (1024, 256)
b_shape = (256, 512)
bias_shape = (512, )
in_dtype = "float16"
dst_dtype = "float32"
tensor_a = tvm.placeholder(a_shape, name='tensor_a', dtype=in_dtype)
tensor_b = tvm.placeholder(b_shape, name='tensor_b', dtype=in_dtype)
tensor_bias = tvm.placeholder(bias_shape, name='tensor_bias', dtype=dst_dtype)
res = dsl.matmul(tensor_a, tensor_b, False, False, dst_dtype=dst_dtype, tensor_bias=tensor_bias)