混合精度计算分析

尽管低精度计算能够提供非常显著的训练速度提升,但低精度运算相比FP32也会引起数值错误和数值稳定性问题。所以使用低精度运算来加速训练的时候,理解张量运算核心的数值行为也非常重要。

FP16的取值范围为±65504,能表示的绝对值最小的规格数为2-14和-2-14,即为+/-0.000061035,相邻两个浮点数的相对差距约为千分之一。举例来说,比1大的第一个浮点数为1+1/1024,这意味着没法表示1.0001。使用FP16相对FP32会有更大精度误差。下面的代码在GPU上分别用FP32和FP16计算一个的10000维随机张量(中位数0,方差1的正态分布)均值的相对误差。用FP16计算时,平均相对误差在0.1%,最大误差在23%。在大模型训练场景下,因为很多张量都很小,且维度高,所以需留心FP16的精度问题。同样的代码如果使用BF16计算,相对误差均值在1%,最大误差接近4800%。因为累加运算有不确定性,程序运行两次会有不同的结果。

import torch
import argparse
parser = argparse.ArgumentParser(
                    prog='test_sum',
                    description='check the effect to final result under different floating point precision')
parser.add_argument("-t", "--dtype", default="fp16")
parser.add_argument("-f", "--fun", default="sum")
args = parser.parse_args()
device = torch.device("cuda:0")
dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
fn = torch.mean if args.fun == "sum" else torch.mean
def test(fn):
    N = 10000
    batch = 1000
    x = torch.randn((batch, N))
    x_16 = x.to(device, dtype=dtype)
 
    x_mean_16 = fn(x_16, dim=-1, dtype=dtype)
    x_fp32 = x.to(device)
    x_mean_fp32 = fn(x_fp32, dim=-1, dtype=torch.float32)
    rel_diff = abs((x_mean_16.float() - x_mean_fp32) / x_mean_fp32)
    return rel_diff
rel_diff= test(fn=fn)
max_rel_diff = rel_diff.max()
avg_rel_diff = rel_diff.mean()
print("average relative difference is: ", avg_rel_diff)
print("max relative difference is: ", max_rel_diff)

因为FP16浮点表示范围很小,所以很容易发生上溢,运算结果产生Inf/NaN。某些硬件会将Inf转成最大可表示数,一定程度缓解上溢,但仍然会带来计算误差。

BF16通过增加指数位宽,降低尾数位宽获得更大的浮点数可表示范围,同时也牺牲了可表示精度。它的取值范围为±9.2E38,能表示的绝对值最小的规格数是±2-126,约为1.18*10-38,相邻两个浮点数的相对差距在百分之一。举例来说,比1大的第一个浮点数为1+1/128,这意味着没法表示1.001。BF16的精度更差,但不易发生上溢下溢,在精度不敏感的运算场景比FP16更方便。 现有的大型transformer模型实验证明,使用BF16进行混合精度训练不太影响模型的收敛。

分布式通信算子有累加操作时,也会有低精度数的误差累积问题,如下代码所示,8 GPU的FP16 all_reduce均值误差大约在千分之一,最大误差在2倍。而用BF16时,均值相对误差大约在1.5%,最大相对误差有9倍。因为通信算子的确定性计算问题,同样的输入运行两次都会有不同的结果。

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
import argparse
parser = argparse.ArgumentParser(
                    prog='fp precison test',
                    description='check the effect to final result under different floating point precision')
parser.add_argument("-t", "--dtype", default="fp16")
args = parser.parse_args()
dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
local_rank = int(os.environ['LOCAL_RANK'])
device = torch.device('cuda:{}'.format(local_rank))# 'cuda' or 'npu'
torch.npu.set_device(device)
dist.init_process_group(backend='nccl', rank=local_rank, world_size=8) #'nccl' for cuda or "hccl" for ascend
if __name__ == "__main__":
    print("start all reduce!")
    hccl_input = torch.randn(10000)
    print(local_rank, hccl_input.shape)
    hccl_input16 = hccl_input.to(dtype=dtype).npu()
    hccl_input32 = hccl_input.npu()
    torch.distributed.all_reduce(hccl_input16)
    torch.distributed.all_reduce(hccl_input32)
    if local_rank == 0:
        rel_diff = abs((hccl_input16.float() - hccl_input32) / hccl_input32)
        diff_mean = torch.mean(rel_diff)
        diff_max = torch.max(rel_diff)
        print("diff_mean=", diff_mean)
        print("max diff=", diff_max)

总结而言,选择何种精度进行混合精度训练需依据具体任务需求权衡精度损失与训练速度提升,同时考虑分布式环境下的误差累积效应。