文档
注册
评分
提单
论坛
小AI

TensorTrait

功能说明

GlobalTensorLocalTensor中通过ShapeInfo类型的成员变量来保存shape信息,可以通过SetShapeInfo、GetShapeInfo来进行设置或者获取,通常用于算子实现内部的shape信息保存和传递。在不使用上述ShapeInfo功能的情况下,不需要这些信息。此时可以使用TensorTrait定义不含ShapeInfo的GlobalTensor以及LocalTensor,以降低内存占用,提升运行性能。

定义原型

1
2
3
4
template <typename T>
struct TensorTrait {
    using LiteType = T;
};

参数说明

表1 TensorTrait结构体模板参数说明

参数名

描述

T

只支持如下基础数据类型:int4b_t、uint8_t、int8_t、int16_t、uint16_t、bfloat16_t、int32_t、uint32_t、int64_t、uint64_t、float、half 。

通过TensorTrait可以得到一个使用TensorTrait表达的Tensor数据类型:在TensorTrait结构体内部,使用using关键字定义了一个类型别名LiteType,与模板参数T类型一致

通过TensorTrait定义的LocalTensor/GlobalTensor不包含ShapeInfo信息。

例如:

LocalTensor<float>对应的不含ShapeInfo信息的Tensor为LocalTensor<TensorTrait<float>>。

约束说明

  • 同一接口不支持同时输入TensorTrait类型的GlobalTensor/LocalTensor和非TensorTrait类型的GlobalTensor/LocalTensor。
  • 非TensorTrait类型和TensorTrait类型的GlobalTensor/LocalTensor相互之间不支持拷贝构造和赋值运算符。
  • TensorTrait特性当前仅支持如下接口:
    表2 TensorTrait特性支持的接口列表

    接口分类

    接口名称

    备注

    基础API>内存管理与同步控制>TQue/TQueBind

    AllocTensor、FreeTensor、EnQue、DeQue

    _

    基础API>矢量计算>单目指令

    Exp、Ln、Abs、Reciprocal、Sqrt、Rsqrt、Not、Relu

    -

    基础API>矢量计算>双目指令

    Add、Sub、Mul、Div、Max、Min、And、Or、AddRelu、AddReluCast、AddDeqRelu、SubRelu、SubReluCast、MulAddDst、FusedMulAdd、FusedMulAddRelu、

    -

    基础API>矢量计算>标量双目指令

    Adds、Muls、Maxs、Mins、ShiftLeft、ShiftRight、LeakyRelu

    -

    基础API>数据搬运

    DataCopy、Copy

    切片数据搬运接口需要ShapeInfo信息,不支持输入TensorTrait类型的GlobalTensor/LocalTensor

调用示例

  • 双目指令使用TensorTrait样例
    1
    2
    3
    4
    5
    // 使用系统描述符TensorTrait
    AscendC::LocalTensor<AscendC::TensorTrait<half>> tensor1 = que1.DeQue<AscendC::TensorTrait<half>>();
    AscendC::LocalTensor<AscendC::TensorTrait<half>> tensor2 = que2.DeQue<AscendC::TensorTrait<half>>();
    AscendC::LocalTensor<AscendC::TensorTrait<half>> tensor3 = que3.AllocTensor<AscendC::TensorTrait<half>>();
    Add(tensor3, tensor1, tensor2, tensor3.GetSize());
    
  • 标量双目指令使用TensorTrait样例
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    #include "kernel_operator.h"
    class KernelBinaryScalarTrait {
    public:
        __aicore__ inline KernelBinaryScalarTrait() {}
        __aicore__ inline void Init(__gm__ uint8_t* src, __gm__ uint8_t* dstGm)
        {
            srcGlobal.SetGlobalBuffer((__gm__ int16_t*)src);
            dstGlobal.SetGlobalBuffer((__gm__ int16_t*)dstGm);
            pipe.InitBuffer(inQueueSrc, 1, 512 * sizeof(int16_t));
            pipe.InitBuffer(outQueueDst, 1, 512 * sizeof(int16_t));
        }
        __aicore__ inline void Process()
        {
            CopyIn();
            Compute();
            CopyOut();
        }
    private:
        __aicore__ inline void CopyIn()
        {
            AscendC::LocalTensor<AscendC::TensorTrait<int16_t>> srcLocal = inQueueSrc.AllocTensor<AscendC::TensorTrait<int16_t>>();
            AscendC::DataCopy(srcLocal, srcGlobal, 512);
            inQueueSrc.EnQue(srcLocal);
        }
        __aicore__ inline void Compute()
        {
            AscendC::LocalTensor<AscendC::TensorTrait<int16_t>> srcLocal = inQueueSrc.DeQue<AscendC::TensorTrait<int16_t>>();
            AscendC::LocalTensor<AscendC::TensorTrait<int16_t>> dstLocal = outQueueDst.AllocTensor<AscendC::TensorTrait<int16_t>>();
    
            uint64_t mask = 128;
            int16_t scalar = 2;
            // repeatTimes = 4, 128 elements one repeat, 512 elements total
           // dstBlkStride, srcBlkStride = 1, no gap between blocks in one repeat
           // dstRepStride, srcRepStride =8, no gap between repeats
            AscendC::Adds(dstLocal, srcLocal, scalar, mask, 4, {1, 1, 8, 8});
            
            outQueueDst.EnQue(dstLocal);
            inQueueSrc.FreeTensor(srcLocal);
        }
        __aicore__ inline void CopyOut()
        {
            AscendC::LocalTensor<AscendC::TensorTrait<int16_t>> dstLocal = outQueueDst.DeQue<AscendC::TensorTrait<int16_t>>();
            AscendC::DataCopy(dstGlobal, dstLocal, 512);
            outQueueDst.FreeTensor(dstLocal);
        }
    private:
        AscendC::TPipe pipe;
        AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueSrc;
        AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueueDst;
        AscendC::GlobalTensor<AscendC::TensorTrait<int16_t>> srcGlobal, dstGlobal;
    };
    extern "C" __global__ __aicore__ void binary_scalar_trait_kernel(__gm__ uint8_t* src, __gm__ uint8_t* dstGm)
    {
        KernelBinaryScalarTrait op;
        op.Init(src, dstGm);
        op.Process();
    }
    
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词