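TIK Vector Computation

Interface Prototype

TIK provides a large number of computation APIs that invoke the Vector compute unit. When using them, set the API parameters correctly and understand the principles behind them. Because these parameters are shared across many APIs, the following discussion uses the vector unary operation API as an example. The function prototype of a vector unary operation is as follows: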

instruction(mask, dst, src, repeat_times, dst_rep_stride, src_rep_stride)
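To make the roles of mask, repeat_times and the two rep_stride parameters concrete, here is a minimal pure-Python model of a unary vector instruction. It is a sketch under stated assumptions, not TIK's implementation. The helper names `unary_indices` and `vector_unary` are hypothetical. The model treats mask as an element count (the first mask elements of each iteration take part) and assumes that the Blocks inside one iteration are contiguous, that a Block is 32 bytes, and that one iteration covers at most 256 bytes.

```python
# Hypothetical model of instruction(mask, dst, src, repeat_times, dst_rep_stride, src_rep_stride).
# Not the TIK API: it only reproduces the addressing rules described in the text.
BLOCK_BYTES = 32     # one Block is 32 bytes
REPEAT_BYTES = 256   # one Vector iteration processes at most 256 bytes (8 Blocks)


def unary_indices(mask, repeat_times, rep_stride, dtype_bytes):
    """Return, per iteration, the element indices touched (mask as a count, Blocks contiguous)."""
    elems_per_block = BLOCK_BYTES // dtype_bytes
    max_mask = REPEAT_BYTES // dtype_bytes
    if not 1 <= mask <= max_mask:
        raise ValueError(f"mask must be in [1, {max_mask}] for a {dtype_bytes}-byte dtype")
    iterations = []
    for r in range(repeat_times):
        start = r * rep_stride * elems_per_block  # rep_stride is measured in Blocks
        iterations.append(list(range(start, start + mask)))
    return iterations


def vector_unary(fn, mask, dst, src, repeat_times, dst_rep_stride, src_rep_stride, dtype_bytes=4):
    """Apply fn element-wise following the mask/repeat/rep_stride addressing rules."""
    dst_iters = unary_indices(mask, repeat_times, dst_rep_stride, dtype_bytes)
    src_iters = unary_indices(mask, repeat_times, src_rep_stride, dtype_bytes)
    for d_idx, s_idx in zip(dst_iters, src_iters):
        values = [fn(src[i]) for i in s_idx]  # read the whole iteration first, so dst may alias src here
        for i, v in zip(d_idx, values):
            dst[i] = v


# Usage: the contiguous vec_abs example, 64 float32 per iteration, 4 iterations, rep_stride 8
data = [float(-i) for i in range(256)]
vector_unary(abs, 64, data, data, 256 // 64, 8, 8)
```

Whether dst may alias src on real hardware depends on the specific API; the model simply reads before writing.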

Computation on Contiguous Addresses

from tbe import tik

tik_instance = tik.Tik()
data_input_gm = tik_instance.Tensor("float32", (256,), name="data_input_gm", scope=tik.scope_gm)
data_input_ub = tik_instance.Tensor("float32", (256,), name="data_input_ub", scope=tik.scope_ubuf)
tik_instance.data_move(data_input_ub, data_input_gm, 0, 1, 32, 0, 0)
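# Apply the unary vec_abs to data_input_ub in place:
# mask=64 (all 64 float32 values of an iteration), 256//64 = 4 repeats, rep_stride = 8 Blocks
tik_instance.vec_abs(64, data_input_ub, data_input_ub, 256//64, 8, 8)
# Subsequent move-out operations (omitted)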

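In the example above, vec_abs takes the absolute value of each element, so the absolute value of all 256 float32 values is computed.

The data computation is illustrated in the following figure:

One Block (32 bytes) holds 8 float32 values, so 256 float32 values occupy exactly 32 Blocks. One iteration processes 256 bytes, so a single Vector operation computes 8 Blocks (thick red box) and the computation is repeated 4 times. As the figure shows, rep_stride is the distance between the starting addresses of two adjacent iterations, measured in Blocks; for contiguous computation like the case above, rep_stride is 8.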
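The Block arithmetic in the paragraph above can be checked with a few lines of plain Python (not TIK code). The constants only restate the figures given in the text: 32-byte Blocks, float32 data, 256 elements, and 256 bytes per iteration.

```python
DTYPE_BYTES = 4                                   # float32
ELEMS_PER_BLOCK = 32 // DTYPE_BYTES               # 8 float32 values per 32-byte Block
TOTAL_ELEMS = 256

total_blocks = TOTAL_ELEMS // ELEMS_PER_BLOCK     # 32 Blocks
blocks_per_repeat = 256 // 32                     # 256 bytes per iteration = 8 Blocks
repeat_times = total_blocks // blocks_per_repeat  # 4 iterations
rep_stride = blocks_per_repeat                    # contiguous: the next iteration starts right after

# Element offset at which each iteration starts
starts = [r * rep_stride * ELEMS_PER_BLOCK for r in range(repeat_times)]
print(total_blocks, repeat_times, rep_stride, starts)  # 32 4 8 [0, 64, 128, 192]
```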

Note: Here dst and src are the same Tensor. Check the documentation of the specific API to confirm whether it allows dst and src to be the same Tensor.

Computation on Non-Contiguous Addresses

from tbe import tik
tik_instance = tik.Tik()
data_input_gm = tik_instance.Tensor("float32", (256,), name="data_input_gm", scope=tik.scope_gm)
data_input_ub = tik_instance.Tensor("float32", (256,), name="data_input_ub", scope=tik.scope_ubuf)
data_output_ub = tik_instance.Tensor("float32", (272,), name="data_output_ub", scope=tik.scope_ubuf)
tik_instance.data_move(data_input_ub, data_input_gm, 0, 1, 32, 0, 0)
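# Apply the unary vec_abs to data_input_ub and write the result to data_output_ub:
# mask=32, 2 repeats, dst_rep_stride = 18 Blocks, src_rep_stride = 16 Blocks
tik_instance.vec_abs(32, data_output_ub, data_input_ub, 2, 18, 16)
# Subsequent move-out operations (omitted)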

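The vec_abs API takes the absolute value of each element. In this call:

mask=32: only the first 32 float32 values (4 Blocks) of each iteration are computed.
repeat_times=2: the computation is iterated twice.
dst_rep_stride=18: the starting addresses of two adjacent iterations in dst are 18 Blocks apart.
src_rep_stride=16: the starting addresses of two adjacent iterations in src are 16 Blocks apart.

Figure 2 Example of computation on non-contiguous addresses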

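Assume that the 8 float32 values in each Block of src are identical (for example, the first Block in src, labeled -1, holds eight -1 values). The red boxes mark the data involved in one repeat, gray marks content that is not changed, and D in dst stands for Default, i.e. the value originally stored in the Tensor.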
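The following pure-Python sketch replays the non-contiguous vec_abs call above on plain lists so that the touched and untouched elements can be inspected. It is an illustration only: apart from the first Block of -1 described above, the src values (every element of Block b set to -(b + 1)) are hypothetical, and the string "D" stands for the Default value in dst.

```python
BLOCK = 8  # float32 elements per 32-byte Block

# Hypothetical input: every element of src Block b equals -(b + 1), so Block 0 holds eight -1 values
src = [float(-(i // BLOCK + 1)) for i in range(256)]
dst = ["D"] * 272  # "D" marks the Default (pre-existing) value in dst

mask, repeat_times, dst_rep_stride, src_rep_stride = 32, 2, 18, 16
for r in range(repeat_times):
    s0 = r * src_rep_stride * BLOCK  # iteration r reads from src Block 16 * r
    d0 = r * dst_rep_stride * BLOCK  # iteration r writes to dst Block 18 * r
    for k in range(mask):            # only the first 32 elements (4 Blocks) are processed
        dst[d0 + k] = abs(src[s0 + k])

changed = [i for i, v in enumerate(dst) if v != "D"]
print(changed[0], changed[31], changed[32], changed[-1])  # 0 31 144 175
print(dst[0], dst[144])                                   # 1.0 17.0
```

So dst Blocks 0 to 3 and 18 to 21 are written, from src Blocks 0 to 3 and 16 to 19; every other dst element keeps its Default value.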

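Note: When the UB address has an offset, it is handled in the same way as for data_move; for details, see Contiguous Address Move with Offset. The main point to check is whether the AI processor in use requires UB addresses to be 32-byte aligned.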
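Whether a given offset is 32-byte aligned is simple arithmetic on the element offset and the data type size. The helper below, `is_32b_aligned`, is hypothetical and not part of TIK; it only computes the byte offset and cannot tell you whether a particular AI processor enforces alignment.

```python
def is_32b_aligned(elem_offset, dtype_bytes):
    """Hypothetical helper: True if a UB slice starting at elem_offset begins on a 32-byte boundary."""
    return (elem_offset * dtype_bytes) % 32 == 0


print(is_32b_aligned(8, 4))   # float32, offset 8 -> 32 bytes: True
print(is_32b_aligned(4, 4))   # float32, offset 4 -> 16 bytes: False
print(is_32b_aligned(16, 2))  # float16, offset 16 -> 32 bytes: True
```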

Exercise

Based on what you have learned so far, work out the expected result of the following TIK operator:

from tbe import tik

tik_instance = tik.Tik()
data_input_gm_1 = tik_instance.Tensor("int32", (256,), name="data_input_gm_1", scope=tik.scope_gm)
data_input_ub_1 = tik_instance.Tensor("int32", (192,), name="data_input_ub_1", scope=tik.scope_ubuf)
data_input_gm_2 = tik_instance.Tensor("int32", (288,), name="data_input_gm_2", scope=tik.scope_gm)
data_input_ub_2 = tik_instance.Tensor("int32", (256,), name="data_input_ub_2", scope=tik.scope_ubuf)
data_output_ub = tik_instance.Tensor("int32", (256,), name="data_output_ub", scope=tik.scope_ubuf)
data_output_gm = tik_instance.Tensor("int32", (192,), name="data_output_gm", scope=tik.scope_gm)

tik_instance.vec_dup(64, data_input_ub_1, 0, 3, 8)
tik_instance.data_move(data_input_ub_1, data_input_gm_1, 0, 4, 4, 4, 2)
tik_instance.data_move(data_input_ub_2, data_input_gm_2[32], 0, 1, 32, 8, 8)
tik_instance.vec_add(64, data_output_ub, data_input_ub_1, data_input_ub_2, 3, 8, 8, 8)
tik_instance.data_move(data_output_gm, data_output_ub, 0, 1, 24, 8, 8)

Input data (int32):

data_input_gm_1 = {1,2,3,...,256}

data_input_gm_2 = {1,2,3,...,288}

1. What are the first 40 values of data_output_gm after it is moved out?

2. What are all the values of data_output_gm?

[Reference Answer]:

# Set all 192 elements to 0
tik_instance.vec_dup(64, data_input_ub_1, 0, 3, 8)
# Move data with the specified layout (4 bursts of 4 Blocks, src gap 4 Blocks, dst gap 2 Blocks);
# the first 40 values are {1,2,3,4,...,31,32,0,0,0,0,0,0,0,0}
tik_instance.data_move(data_input_ub_1, data_input_gm_1, 0, 4, 4, 4, 2)
# Move data with the specified layout (one burst of 32 Blocks starting at element 32);
# the first 40 values are {33,34,35,36,...,70,71,72}
tik_instance.data_move(data_input_ub_2, data_input_gm_2[32], 0, 1, 32, 8, 8)
# Contiguous addition over all elements with a full mask, the most common case
tik_instance.vec_add(64, data_output_ub, data_input_ub_1, data_input_ub_2, 3, 8, 8, 8)
# Move all data out contiguously; the first 40 values are {34,36,38,40,...,92,94,96,65,66,67,68,69,70,71,72}
tik_instance.data_move(data_output_gm, data_output_ub, 0, 1, 24, 8, 8)
# All 192 values are:
# {34,36,38,40,42,44,46,48,50,52,54,56,58,60,62,64,66,68,70,72,74,76,78,80,82,84,86,88,90,92,94,96,
#  65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,
#  146,148,150,152,154,156,158,160,162,164,166,168,170,172,174,176,178,180,182,184,186,188,190,192,194,196,198,200,202,204,206,208,
#  113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,
#  258,260,262,264,266,268,270,272,274,276,278,280,282,284,286,288,290,292,294,296,298,300,302,304,306,308,310,312,314,316,318,320,
#  161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,
#  370,372,374,376,378,380,382,384,386,388,390,392,394,396,398,400,402,404,406,408,410,412,414,416,418,420,422,424,426,428,430,432,
#  209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224}
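To double-check the reference answer, the sketch below models the four instructions with plain Python lists. The function names mirror the TIK calls but are hypothetical pure-Python stand-ins, not the TIK API. The model assumes 8 int32 values per 32-byte Block, that data_move's burst and stride arguments are counted in Blocks with the strides being the gaps between consecutive bursts, and that vec_dup and vec_add use a mask given as an element count.

```python
BLOCK = 8  # int32 elements per 32-byte Block


def data_move(dst, dst_off, src, src_off, nburst, burst, src_stride, dst_stride):
    """Hypothetical model: burst and the strides are in Blocks; a stride is the gap between bursts."""
    n = burst * BLOCK
    for b in range(nburst):
        s = src_off + b * (burst + src_stride) * BLOCK
        d = dst_off + b * (burst + dst_stride) * BLOCK
        dst[d:d + n] = src[s:s + n]


def vec_dup(mask, dst, scalar, repeat_times, dst_rep_stride):
    """Hypothetical model: fill the first mask elements of each iteration with scalar."""
    for r in range(repeat_times):
        base = r * dst_rep_stride * BLOCK
        for k in range(mask):
            dst[base + k] = scalar


def vec_add(mask, dst, src0, src1, repeat_times, dst_rs, src0_rs, src1_rs):
    """Hypothetical model: element-wise add with per-operand rep_strides in Blocks."""
    for r in range(repeat_times):
        for k in range(mask):
            dst[r * dst_rs * BLOCK + k] = src0[r * src0_rs * BLOCK + k] + src1[r * src1_rs * BLOCK + k]


gm_1 = list(range(1, 257))  # data_input_gm_1 = {1, 2, ..., 256}
gm_2 = list(range(1, 289))  # data_input_gm_2 = {1, 2, ..., 288}
ub_1 = [None] * 192         # None marks uninitialized UB memory
ub_2 = [None] * 256
out_ub = [None] * 256
out_gm = [None] * 192

vec_dup(64, ub_1, 0, 3, 8)
data_move(ub_1, 0, gm_1, 0, 4, 4, 4, 2)
data_move(ub_2, 0, gm_2, 32, 1, 32, 8, 8)  # gm_2[32] as the source start
vec_add(64, out_ub, ub_1, ub_2, 3, 8, 8, 8)
data_move(out_gm, 0, out_ub, 0, 1, 24, 8, 8)

print(out_gm[:40])
```

The zero-filled gaps left in ub_1 by the dst_stride of 2 Blocks are why every second run of 16 outputs equals ub_2 alone (i + 33 at index i) instead of a sum of two increasing sequences.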