下载
中文
注册

RopeOperation

功能

旋转位置编码(Rotary Position Embedding,RoPE),以旋转矩阵的方式在q、k中注入位置信息,使得attention计算时能感受到token的位置关系,在各大模型中,RoPE被广泛应用。RoPE以绝对位置编码的方式实现了相对位置编码,能有效保持位置信息相对关系,并且可以通过编码外推的方式支持超过训练长度的位置编码。

算子上下文

图1 RopeOperation算子上下文

计算公式

对于二维情况

假设空间是偶数维的,把原始空间切分一个个正交的二维子空间,在上面做独立的不同角度的旋转,可以扩展到高维空间。

定义

struct RopeParam {
    int32_t rotaryCoeff = 4;
    int32_t cosFormat = 0;
};

参数列表

成员名称

类型

默认值

描述

rotaryCoeff

int32_t

4

rope,旋转系数,对半旋转是2,支持配置2、4、headDim / 2、headDim。

cosFormat

int32_t

0

训练用参数,支持配置0或1。

rotaryCoeff参数选择与原始计算公式的对应关系如下:

其中m为token的位置,d为query或key的维度。

  • rotaryCoeff = 2的情况:

    如上式,query和key分别取一半,进行旋转操作后拼接,称为对半旋转,对应rotaryCoeff = 2,即1/2。

  • rotaryCoeff = 4的情况:

    以此类推,rotaryCoeff = 4对应于1/4,表示query和key分成前后两半,每一半按rotaryCoeff = 2的情况处理。

  • rotaryCoeff = headDim的情况:

    如上式,query和key中的维度两两一组交错排列,其中图示的cos和sin列表经过了concat操作,表现为连续相同的2个cos、sin值。

    如果不经过concat操作,cos和sin列表的长度为上式中的一半,即headDim/2。算子也支持此种情况,此时需要将rotaryCoeff = headDim/2。

    换而言之,使用上图的计算模式只需要满足rotaryCoeff与cos或sin输入tensor的最后一维大小相等。

输入

参数

维度

数据类型

格式

描述

query

[ntokens, hiddenSizeQ]

float16/bf16

ND

当前step多个token的query。

key

[ntokens, hiddenSizeK]

float16/bf16

ND

当前step多个token的key。

cos

[ntokens, headDim] / [ntokens, headDim / 2]

float16/float/bf16

ND

  • 当cos的第二个维度与参数rotaryCoeff不相等时,其值为headDim。
  • ROPE高精度模式,需要输入cos的数据类型为float时生效。

sin

[ntokens, headDim] / [ntokens, headDim/ 2]

float16/float/bf16

ND

  • 当sin的第二个维度与参数rotaryCoeff不相等时,其值为headDim。
  • ROPE高精度模式,需要输入sin的数据类型为float时生效。

seqlen

[batch]

uint32/int32

ND

-

输出

参数

维度

数据类型

格式

描述

ropeQ

[ntokens, hiddenSizeQ]

float16/bf16

ND

旋转后的query。

ropeK

[ntokens, hiddenSizeK]

float16/bf16

ND

旋转后的key。

约束

  • 输入tensor数据类型需保持一致,高精度模式例外。
  • cos、sin传入数据类型为float时,中间计算结果以float保存。
  • hiddenSizeQ和hiddenSizeK必须是headDim的整数倍,满足hiddenSizeQ = headDim * headNumQ、hiddenSizeK = headDim * headNumK,其中headNumQ可以大于headNumK。
  • ntokens = sum(seqlen[i]),i=0...batch-1。
  • query和key要求两维,有部分模型使用了4维,这种情况下维度是:

    [batch, seqlen, headNum, headDim];对应的ropeQ、ropeK也是四维,维度输入输出对应。

  • Decoder阶段要取cos和sin表中seqlen对应的cos/sin值输入。
  • 多batch场景需要组合使用gather算子。