Implements the fused computation of the "Transformer Attention Score". The computation formula is as follows:
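A sketch of the computation, consistent with the golden reference in the sample code at the end of this section; scale, pse, atten_mask, and keep_prob are the operator's parameters, and this should be read as an approximation of the semantics rather than the exact kernel definition:

```latex
\mathrm{attention\_out}
  = \mathrm{Dropout}_{keep\_prob}\Big(
      \mathrm{Softmax}\big(
        \mathrm{Mask}\big(\, \mathrm{scale}\cdot(Q K^{T} + \mathrm{pse}),\ \mathrm{atten\_mask} \big)
      \big)
    \Big)\, V
```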
torch_npu.npu_fusion_attention(Tensor query, Tensor key, Tensor value, int head_num, str input_layout, Tensor? pse=None, Tensor? padding_mask=None, Tensor? atten_mask=None, float scale=1., float keep_prob=1., int pre_tockens=2147483647, int next_tockens=2147483647, int inner_precise=0, int[]? prefix=None, int[]? actual_seq_qlen=None, int[]? actual_seq_kvlen=None, int sparse_mode=0, bool gen_mask_parallel=True, bool sync=False) -> (Tensor, Tensor, Tensor, Tensor, int, int, int)
Currently, 0 and 1 are reserved configuration values; 2 enables invalid-row computation, which avoids the precision loss caused by rows that are entirely masked during computation, but this configuration degrades performance.
If the operator can determine that an invalid-row scenario exists, it enables invalid-row computation automatically, for example when sparse_mode is 3 and Sq > Skv.
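A minimal sketch of opting into invalid-row computation via inner_precise; the tensor shapes and scale below are chosen purely for illustration and follow the BNSD layout used in the sample code at the end of this section:

```python
import math
import torch
import torch_npu

# Illustrative shapes only: batch=1, heads=8, seq=256, head_dim=256 (BNSD layout).
query = torch.randn(1, 8, 256, 256, dtype=torch.float16).npu()
key = torch.randn(1, 8, 256, 256, dtype=torch.float16).npu()
value = torch.randn(1, 8, 256, 256, dtype=torch.float16).npu()

# inner_precise=2 enables invalid-row computation (0 and 1 are reserved values).
result = torch_npu.npu_fusion_attention(
    query, key, value, head_num=8, input_layout="BNSD",
    scale=1.0 / math.sqrt(256), inner_precise=2)
attention_out = result[0]
```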
For example, if the actual list of S lengths is 2 2 2 2 2, then actual_seq_qlen should be passed as 2 4 6 8 10.
For example, if the actual list of S lengths is 2 2 2 2 2, then actual_seq_kvlen should be passed as 2 4 6 8 10.
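A small sketch of deriving these cumulative end offsets from the per-sequence lengths (the length list here is purely illustrative):

```python
from itertools import accumulate

seq_lens = [2, 2, 2, 2, 2]             # actual length of each sequence
cum_lens = list(accumulate(seq_lens))  # -> [2, 4, 6, 8, 10]

# Passed as the cumulative end offsets of each sequence:
# actual_seq_qlen=cum_lens, actual_seq_kvlen=cum_lens
```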
| sparse_mode | Meaning | Notes |
|---|---|---|
| 0 | defaultMask mode. | - |
| 1 | allMask mode. | - |
| 2 | leftUpCausal mode. | - |
| 3 | rightDownCausal mode. | - |
| 4 | band mode. | - |
| 5 | prefix non-compressed mode. | Not supported in varlen scenarios. |
| 6 | prefix compressed mode. | - |
| 7 | rightDownCausal mode for the varlen external-split scenario. | Supported only in varlen scenarios. |
| 8 | leftUpCausal mode for the varlen external-split scenario. | Supported only in varlen scenarios. |
atten_mask works as follows: at positions where the mask is True, the corresponding values of the product of query (Q) and the transpose of key (K) are masked out, as illustrated below:
Positions of the QK^T matrix where atten_mask is True are masked, with the following effect:
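A tiny sketch of this masking rule (shapes and values are illustrative; the golden reference in the sample code at the end of this section applies the same idea with a large negative bias):

```python
import torch

scores = torch.randn(4, 4)  # QK^T scores for one head
atten_mask = torch.triu(torch.ones(4, 4), diagonal=1).bool()  # True = masked

# Positions where the mask is True are suppressed before the softmax.
masked_scores = scores.masked_fill(atten_mask, float("-inf"))
probs = torch.softmax(masked_scores, dim=-1)
```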
atten_mask should be passed as a lower-triangular matrix, as illustrated below:
atten_mask should be passed as a band-shaped matrix, as illustrated below:
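A sketch of building these two mask shapes with NumPy, mirroring the get_atten_mask helper in the sample code at the end of this section; S, pre_tokens, and next_tokens are illustrative, and nonzero (True) entries mark masked positions:

```python
import numpy as np
import torch

S = 256

# Causal scenario: mask everything above the diagonal, keeping the lower triangle.
causal_mask = torch.from_numpy(np.triu(np.ones([S, S]), k=1)).bool()

# Band scenario: keep only pre_tokens below and next_tokens above the diagonal.
pre_tokens, next_tokens = 128, 128
band_mask = torch.from_numpy(
    np.triu(np.ones([S, S]), k=next_tokens + 1)
    + np.tril(np.ones([S, S]), k=-pre_tokens - 1)).bool()
```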
Note: When next_tockens is negative, pre_tockens must be greater than or equal to the absolute value of next_tockens, and the absolute value of next_tockens must be less than Skv.
Note: When pre_tockens is negative, next_tockens must be greater than or equal to the absolute value of pre_tockens, and the absolute value of pre_tockens must be less than Sq.
In this scenario, the values of next_tockens and pre_tockens are ignored. The masked QK^T matrix is illustrated as follows:
The atten_mask passed in is the optimized compressed lower-triangular matrix (2048*2048). The compressed lower-triangular matrix is illustrated as follows (the same applies below):
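A sketch of constructing this compressed 2048*2048 mask and passing it together with a causal sparse_mode, mirroring the custom_op_exec helper in the sample code at the end of this section; the tensor shapes and scale are illustrative:

```python
import numpy as np
import torch
import torch_npu

# Compressed causal mask: a fixed 2048x2048 bool matrix, True above the diagonal.
compressed_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).bool().npu()

query = torch.randn(1, 8, 256, 256, dtype=torch.float16).npu()
key = torch.randn(1, 8, 256, 256, dtype=torch.float16).npu()
value = torch.randn(1, 8, 256, 256, dtype=torch.float16).npu()

result = torch_npu.npu_fusion_attention(
    query, key, value, head_num=8, input_layout="BNSD", scale=0.08838,
    sparse_mode=3, atten_mask=compressed_mask,
    pre_tockens=65536, next_tockens=0)
attention_out = result[0]
```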
In this scenario, the values of pre_tockens and next_tockens are ignored, and the atten_mask matrix must be in the BNSS or B1SS data format. The masked QK^T matrix is illustrated as follows:
The matrix that should be passed in as atten_mask is illustrated as follows:
The masked QK^T matrix is illustrated as follows. In the second batch, query is split while key and value are not; the 4x6 mask matrix is split into a 2x6 mask and a 2x6 mask, which are computed on card 1 and card 2 respectively:
The masked QK^T matrix is illustrated as follows. In the second batch, query is split while key and value are not; the 5x4 mask matrix is split into a 2x4 mask and a 3x4 mask, which are computed on card 1 and card 2 respectively:
There are 7 outputs in total:
(Tensor, Tensor, Tensor, Tensor, int, int, int)
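A sketch of unpacking the return value. Only the structure (4 tensors followed by 3 ints) and the fact that the first tensor is the attention output (used as fa_result[0] in the sample code below) are taken from this page; the remaining names are descriptive placeholders, not authoritative output names:

```python
import torch
import torch_npu

query = torch.randn(1, 8, 256, 256, dtype=torch.float16).npu()
key = torch.randn(1, 8, 256, 256, dtype=torch.float16).npu()
value = torch.randn(1, 8, 256, 256, dtype=torch.float16).npu()

# The call returns a 7-element tuple: 4 tensors followed by 3 ints.
result = torch_npu.npu_fusion_attention(
    query, key, value, head_num=8, input_layout="BNSD", scale=0.08838)

attention_out = result[0]          # attention result, same layout as query
softmax_aux = result[1:4]          # auxiliary softmax-related tensors (placeholder name)
mask_gen_ints = result[4:7]        # 3 ints related to dropout-mask generation (placeholder name)
```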
```python
import math
import unittest
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices


class TestNPUFlashAttention(TestCase):
    # Golden reference: plain attention computed with standard torch ops.
    def supported_op_exec(self, query, key, value, atten_mask):
        scale = 0.08838
        qk = torch.matmul(query, key.transpose(2, 3)).mul(scale)
        qk = qk + atten_mask * (-10000.0)
        softmax_res = torch.nn.functional.softmax(qk, dim=-1)
        attention_out = torch.matmul(softmax_res, value)
        return attention_out

    # Operator under test: the fused NPU kernel with the requested sparse mode.
    def custom_op_exec(self, query, key, value, sparse_params):
        scale = 0.08838
        atten_mask = None
        if sparse_params[0] == 0:
            shape = [1, 8, 256, 256]
            atten_mask_u = np.triu(np.ones(shape), k=sparse_params[1] + 1)
            atten_mask_l = np.tril(np.ones(shape), k=-sparse_params[2] - 1)
            atten_masks = atten_mask_u + atten_mask_l
            atten_mask = torch.tensor(atten_masks).to(torch.float16).bool().npu()
        if sparse_params[0] == 2 or sparse_params[0] == 3 or sparse_params[0] == 4:
            atten_masks = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1))
            atten_mask = torch.tensor(atten_masks).to(torch.float16).bool().npu()
        return torch_npu.npu_fusion_attention(
            query, key, value, head_num=8, input_layout="BNSD", scale=scale,
            sparse_mode=sparse_params[0], atten_mask=atten_mask,
            pre_tockens=sparse_params[1], next_tockens=sparse_params[2])

    def get_atten_mask(self, sparse_mode=0, pre_tokens=65536, next_tokens=65536):
        atten_masks = []
        shape = [1, 8, 256, 256]
        if sparse_mode == 0:
            atten_mask_u = np.triu(np.ones(shape), k=next_tokens + 1)
            atten_mask_l = np.tril(np.ones(shape), k=-pre_tokens - 1)
            atten_masks = atten_mask_u + atten_mask_l
        elif sparse_mode == 1:
            atten_masks = np.zeros(shape)
            pre_tokens = 65536
            next_tokens = 65536
        elif sparse_mode == 2:
            atten_masks = np.triu(np.ones(shape), k=1)
        elif sparse_mode == 3:
            atten_masks = np.triu(np.ones(shape), k=1)
        elif sparse_mode == 4:
            atten_mask_u = np.triu(np.ones(shape), k=next_tokens + 1)
            atten_mask_l = np.tril(np.ones(shape), k=-pre_tokens - 1)
            atten_masks = atten_mask_u + atten_mask_l
        atten_mask = torch.tensor(atten_masks).to(torch.float16)
        return atten_mask

    # sparse_params = [sparse_mode, pre_tokens, next_tokens]
    # Prec and prec16 indicate the precision comparison standards for float32 and float16 respectively.
    # In this example, 0.01 is used as the standard. You can change the value as required.
    def check_result(self, query, key, value, sparse_params):
        atten_mask = self.get_atten_mask(sparse_params[0], sparse_params[1], sparse_params[2])
        output = self.supported_op_exec(query.float(), key.float(), value.float(), atten_mask)
        fa_result = self.custom_op_exec(query.npu(), key.npu(), value.npu(), sparse_params)
        self.assertRtolEqual(output.half(), fa_result[0], prec=0.01, prec16=0.01)

    def test_npu_flash_attention(self, device="npu"):
        query = torch.randn(1, 8, 256, 256, dtype=torch.float16)
        key = torch.randn(1, 8, 256, 256, dtype=torch.float16)
        value = torch.randn(1, 8, 256, 256, dtype=torch.float16)

        # sparse_params: [sparse_mode, pre_tokens, next_tokens]
        sparse_params_list = [
            [0, 128, 128],
            [1, 65536, 65536],
            [2, 65536, 0],
            [3, 65536, 0],
            [4, 128, 128]
        ]
        for sparse_params in sparse_params_list:
            self.check_result(query, key, value, sparse_params)


if __name__ == "__main__":
    run_tests()
```