Implements the fused computation of the Transformer attention score. The computation formula is as follows:
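A sketch of the fused computation, reconstructed from the parameters in the signature below; the exact placement of the optional `pse` bias relative to the scale factor is an assumption:

$$
\text{attention\_out} = \text{Dropout}\Big(\text{Softmax}\big(\text{Mask}\big(\text{scale}\cdot Q K^{T} + pse,\ \text{atten\_mask}\big)\big),\ \text{keep\_prob}\Big)\, V
$$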
```
torch_npu.npu_fusion_attention(
    Tensor query, Tensor key, Tensor value, int head_num, str input_layout,
    Tensor? pse=None, Tensor? padding_mask=None, Tensor? atten_mask=None,
    float scale=1., float keep_prob=1., int pre_tockens=2147483647,
    int next_tockens=2147483647, int inner_precise=1, int[]? prefix=None,
    int sparse_mode=0, bool gen_mask_parallel=True, bool sync=False
) -> (Tensor, Tensor, Tensor, Tensor, int, int, int)
```
| sparse_mode | Meaning | Remarks |
|---|---|---|
| 0 | defaultMask mode | - |
| 1 | allMask mode | - |
| 2 | leftUpCausal mode | - |
| 3 | rightDownCausal mode | - |
| 4 | band mode | - |
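As a minimal usage sketch (not taken from the original document; the shapes, head count, and scale below are illustrative assumptions), sparse_mode is passed directly in the call, here using the default defaultMask mode together with an explicit boolean mask covering the full sequence length:

```python
import torch
import torch_npu

B, N, S, D = 1, 8, 1024, 128                               # illustrative batch/heads/seq/head-dim
q = torch.randn(B, S, N * D, dtype=torch.float16).npu()    # BSH layout: (B, S, N*D)
k = torch.randn(B, S, N * D, dtype=torch.float16).npu()
v = torch.randn(B, S, N * D, dtype=torch.float16).npu()

# Full S x S boolean mask: True marks positions of QK^T to be masked out
causal_mask = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1).npu()

out = torch_npu.npu_fusion_attention(
    q, k, v, head_num=N, input_layout="BSH",
    atten_mask=causal_mask,
    scale=1.0 / (D ** 0.5),
    sparse_mode=0)                                          # 0: defaultMask mode
attention_out = out[0]
```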
atten_mask works as follows: at positions where the mask is True, the corresponding values of the product of query (Q) and the transpose of key (K) are masked out, as illustrated below:

Positions of the QK^T matrix where atten_mask is True are masked; the effect is as follows:

atten_mask should be passed as a lower-triangular matrix, as illustrated below:

atten_mask should be passed as a band-shaped matrix, as illustrated below:

In this scenario, the values of next_tockens and pre_tockens are ignored; the masked QK^T matrix is illustrated below:

The atten_mask passed in is the optimized compressed lower-triangular matrix (2048*2048); the compressed lower-triangular matrix is illustrated below (the same applies hereafter):
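For the compressed-mask case described above, the following is a sketch of how such a fixed 2048*2048 causal mask is typically constructed and passed together with a causal sparse_mode; the choice of sparse_mode=3 (rightDownCausal) and the tensor shapes are assumptions for illustration:

```python
import torch
import torch_npu

B, N, S, D = 1, 8, 4096, 128                               # illustrative shapes; S may exceed 2048
q = torch.randn(B, S, N * D, dtype=torch.float16).npu()
k = torch.randn(B, S, N * D, dtype=torch.float16).npu()
v = torch.randn(B, S, N * D, dtype=torch.float16).npu()

# Compressed causal mask: fixed 2048 x 2048, True above the diagonal (masked positions)
compress_mask = torch.triu(torch.ones(2048, 2048, dtype=torch.bool), diagonal=1).npu()

out = torch_npu.npu_fusion_attention(
    q, k, v, head_num=N, input_layout="BSH",
    atten_mask=compress_mask,
    scale=1.0 / (D ** 0.5),
    sparse_mode=3)                                          # 3: rightDownCausal mode with compressed mask
attention_out = out[0]
```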
There are 7 outputs in total:
(Tensor, Tensor, Tensor, Tensor, int, int, int)
As unpacked in the sample code below, they correspond to (attention_score, softmax_max, softmax_sum, softmax_out, seed, offset, numels).
Atlas A2 Training Series products
```python
import math
import unittest
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import get_npu_device

DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10]


class TestNPUFlashAttention(TestCase):
    def supported_op_exec(self, query, key, value):
        # Reference implementation: plain scaled dot-product attention,
        # scale 0.08838 ~= 1/sqrt(128), reshaped from BNSD to BSH layout.
        qk = torch.matmul(query, key.transpose(2, 3)).mul(0.08838)
        softmax_res = torch.nn.functional.softmax(qk, dim=-1)
        output = torch.matmul(softmax_res, value)
        output = output.transpose(1, 2)
        output = output.reshape(output.shape[0], output.shape[1], -1)
        return output

    def custom_op_exec(self, query, key, value):
        # Fused NPU operator under test.
        scale = 0.08838
        return torch_npu.npu_fusion_attention(
            query, key, value, head_num=32, input_layout="BSH", scale=scale)

    def trans_BNSD2BSH(self, tensor: torch.Tensor):
        # (B, N, S, D) -> (B, S, N*D)
        tensor = torch.transpose(tensor, 1, 2)
        tensor = torch.reshape(tensor, (tensor.shape[0], tensor.shape[1], -1))
        return tensor

    def test_npu_flash_attention(self, device="npu"):
        query = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        key = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        value = torch.randn(1, 32, 128, 128, dtype=torch.float16)

        q_FA = self.trans_BNSD2BSH(query).npu()
        k_FA = self.trans_BNSD2BSH(key).npu()
        v_FA = self.trans_BNSD2BSH(value).npu()

        output = self.supported_op_exec(query.float(), key.float(), value.float())
        attention_score, softmax_max, softmax_sum, softmax_out, seed, offset, numels = \
            self.custom_op_exec(q_FA, k_FA, v_FA)

        self.assertRtolEqual(output.half(), attention_score, prec=0.005, prec16=0.005)


if __name__ == "__main__":
    run_tests()
```