基础API通用说明
使用tensor高维切分计算API可充分发挥硬件优势,支持开发者控制指令的迭代执行和操作数的地址间隔,功能更加灵活。
矢量计算通过Vector计算单元完成,矢量计算的源操作数和目的操作数均通过Unified Buffer(UB)来进行存储。Vector计算单元每个迭代会从UB中取出8个datablock(每个datablock数据块内部地址连续,长度32Byte),进行计算,并写入对应的8个datablock中。下图为单次迭代内的8个datablock进行Exp计算的示意图。
- 矢量计算API支持开发者通过Repeat times来配置迭代次数,从而控制指令的多次迭代执行。假设Repeat times设置为2,矢量计算单元会进行2个迭代的计算,可计算出2 * 8(每个迭代8个datablock) * 32Byte(每个datablock32Byte) = 512Byte的结果。如果数据类型为half,则计算了256个元素。下图展示了2次迭代Exp计算的示意图。Repeat times不能超过255。
图2 2次迭代Exp计算
- 针对同一个迭代中的数据,可以通过mask参数进行掩码操作来控制实际参与计算的个数。下图为进行Abs计算时通过mask逐比特模式按位控制哪些元素参与计算的示意图,1表示参与计算,0表示不参与计算。
图3 通过mask参数进行掩码操作示意图
- 矢量计算单元还支持带间隔的向量计算,通过Block stride(单次迭代内不同datablock间地址步长)和Repeat stride(相邻迭代间相同datablock的地址步长)来进行配置。
- Block stride
如果需要控制单次迭代内,数据处理的步长,可以通过设置同一迭代内不同block的地址步长Block stride来实现。下图给出了单次迭代内非连续场景的示意图,示例中源操作数的Block Stride配置为2,表示单次迭代内不同datablock间地址步长(起始地址之间的间隔)为2个datablock。图4 单次迭代内非连续场景的示意图
- Repeat stride
当Repeat times大于1,需要多次迭代完成矢量计算时,您可以根据不同的使用场景合理设置相邻迭代间相同block的地址步长Repeat stride的值。
下图给出了多次迭代间非连续场景的示意图,示例中源操作数和目的操作数的Repeat stride均配置为9,表示相邻迭代间相同datablock起始地址之间的间隔为9个datablock。相同datablock是指datablock在迭代内的位置相同,比如下图中的src1和src9处于相邻迭代,在迭代内都是第一个datablock的位置,其间隔即为Repeat stride的数值。
图5 多次迭代间非连续场景的示意图
- Block stride
下文中给出了Block stride、Repeat stride、mask配置的详细配置说明和示例。
Block stride(同一迭代内不同datablock的地址步长)
- 连续计算,Block stride 设置为1,对同一迭代内的8个datablock数据连续进行处理。
- 非连续计算,Block stride值大于1(如取2),同一迭代内不同datablock之间在读取数据时出现一个block的间隔,如下图所示。
图6 Block stride不同取值举例
Repeat stride(相邻迭代间相同datablock的地址步长)
- 连续计算场景:假设定义一个Tensor供目的操作数和源操作数同时使用(即地址重叠),Repeat stride取值为8。此时,矢量计算单元第一次迭代读取连续8个datablock,第二轮迭代读取下一个连续的8个datablock,通过多次迭代即可完成所有输入数据的计算。
- 非连续计算场景:Repeat stride取值大于8(如取10)时,则相邻迭代间矢量计算单元读取的数据在地址上不连续,出现2个datablock的间隔。
- 反复计算场景:Repeat stride取值为0时,矢量计算单元会对首个连续的8个datablock进行反复读取和计算。
- 部分重复计算:Repeat stride取值大于0且小于8时,相邻迭代间部分数据会被矢量计算单元重复读取和计算,此种情形一般场景不涉及。
mask参数
mask用于控制每次迭代内参与计算的元素。可通过连续模式和逐比特模式两种方式进行设置。
- 连续模式:表示前面连续的多少个元素参与计算。数据类型为uint64_t。取值范围和操作数的数据类型有关,数据类型不同,每次迭代内能够处理的元素个数最大值不同(当前数据类型单次迭代时能处理的元素个数最大值为:256 / sizeof(数据类型))。当操作数的数据类型占比特位16位时(如half,uint16_t),mask∈[1, 128];当操作数为32位时(如float, int32_t),mask∈[1, 64]。
具体样例如下:
// int16_t数据类型单次迭代能处理的元素个数最大值为256/sizeof(int16_t) = 128,mask = 64,mask∈[1, 128],所以是合法输入 // repeatTimes = 1, 共128个元素,单次迭代能处理128个元素,故repeatTimes = 1 // dstBlkStride, src0BlkStride, src1BlkStride = 1, 单次迭代内连续读取和写入数据 // dstRepStride, src0RepStride, src1RepStride = 8, 迭代间的数据连续读取和写入 uint64_t mask = 64; Add(dstLocal, src0Local, src1Local, mask, 1, { 1, 1, 1, 8, 8, 8 }); 结果示例如下: 输入数据(src0Local): [1 2 3 ... 64 ...128] 输入数据(src1Local): [1 2 3 ... 64 ...128] 输出数据(dstLocal): [2 4 6 ... 128 undefined...undefined]
// int32_t数据类型单次迭代能处理的元素个数最大值为256/sizeof(int32_t) = 64,mask = 64,mask∈[1, 64],所以是合法输入 // repeatTimes = 1, 共64个元素,单次迭代能处理64个元素,故repeatTimes = 1 // dstBlkStride, src0BlkStride, src1BlkStride = 1, 单次迭代内连续读取和写入数据 // dstRepStride, src0RepStride, src1RepStride = 8, 迭代间的数据连续读取和写入 uint64_t mask = 64; Add(dstLocal, src0Local, src1Local, mask, 1, { 1, 1, 1, 8, 8, 8 }); 结果示例如下: 输入数据(src0Local): [1 2 3 ... 64] 输入数据(src1Local): [1 2 3 ... 64] 输出数据(dstLocal): [2 4 6 ... 128]
- 逐bit模式:可以按位控制哪些元素参与计算,bit位的值为1表示参与计算,0表示不参与。参数类型为长度为2的uint64_t类型数组。
参数取值范围和操作数的数据类型有关,数据类型不同,每次迭代内能够处理的元素个数最大值不同。当操作数为16位时,mask[0]、mask[1]∈[0, 264-1],且mask[0]和mask[1]不可同时为0;当操作数为32位时,mask[1]为0,mask[0]∈(0, 264-1]。
具体样例如下:
// 数据类型为int16_t uint64_t mask[2] = {6148914691236517205, 6148914691236517205}; // repeatTimes = 1, 共128个元素,单次迭代能处理128个元素,故repeatTimes = 1。 // dstBlkStride, src0BlkStride, src1BlkStride = 1, 单次迭代内连续读取和写入数据。 // dstRepStride, src0RepStride, src1RepStride = 8, 迭代间的数据连续读取和写入。 Add(dstLocal, src0Local, src1Local, mask, 1, { 1, 1, 1, 8, 8, 8 }); 结果示例如下: 输入数据(src0Local): [1 2 3 ... 64 ...127 128] 输入数据(src1Local): [1 2 3 ... 64 ...127 128] 输出数据(dstLocal): [2 un 6 ... un ...254 undefined]
mask过程如下:
mask={6148914691236517205, 6148914691236517205}(注:6,148,914,691,236,517,205表示64位二进制数0b010101....01)
// 数据类型为int32_t uint64_t mask[2] = {6148914691236517205, 0}; // repeatTimes = 1, 共64个元素,单次迭代能处理64个元素,故repeatTimes = 1。 // dstBlkStride, src0BlkStride, src1BlkStride = 1, 单次迭代内连续读取和写入数据。 // dstRepStride, src0RepStride, src1RepStride = 8, 迭代间的数据连续读取和写入。 Add(dstLocal, src0Local, src1Local, mask, 1, { 1, 1, 1, 8, 8, 8 }); 结果示例如下: 输入数据(src0Local): [1 2 3 ... 63 64] 输入数据(src1Local): [1 2 3 ... 63 64] 输出数据(dstLocal): [2 un 6 ... 126 undefined]
mask过程如下:
mask={6148914691236517205, 0}(注:6,148,914,691,236,517,205表示64位二进制数0b010101....01)