下载
中文
注册

混合精度原理与计算过程(AMP)

原理

float16是指采用2字节(16位)进行编码存储的一种数据类型;float32是指采用4字节(32位)存储的数据类型。数据格式如图1所示。

图1 float32与float16数据格式示意图

使用float16代替float32计算有如下优点:

  • 内存占用减少:float16内存占用比float32更小。在模型训练时,可以设置更大的batch_size。
  • 加速计算:float16训练速度可以是float32的2~8倍。

使用float16代替float32也可能带来如下问题:

  • 数值溢出:由于float16的值区间比float32的值区间小很多,所以在计算过程中很容易出现上溢出(Overflow,>65504)和下溢出(Underflow,<6x10^-8),溢出之后就会出现“NaN”的问题。在深度学习中,由于激活函数的梯度往往要比权重梯度小,更易出现下溢出的情况。
  • 舍入误差:舍入误差指的是当梯度过小,小于当前区间内的最小间隔时,该次梯度更新可能会失败。

上述问题使用混合精度训练可能带来的问题,可以通过在训练过程中使用动态损失缩放(Loss Scale)来解决。操作思想如下:

  1. 使用混合精度加速:使用float16进行乘法和存储,只使用float32进行加法操作,避免累加误差。
  2. 使用损失缩放(Loss Scale):反向传播前,将损失变化手动增大2k倍,因此反向传播时得到的中间变量(激活函数梯度)则不会溢出。反向传播后,将权重梯度缩2k倍,恢复正常值。

上述两点操作可分别通过使用amp.autocast和amp.GradScaler来实现:

  1. amp.autocast-混合精度加速:AMP会使用amp.autocast自动为算子选择合适的数值精度,用户不需要手动对模型参数的数值类型(dtype)进行转换。

    该操作可以作为Python上下文管理器(Context Manager)和装饰器来使用,用来指定脚本中某个区域或者某些函数按照AMP运行。混合精度是先将float32的模型的参数拷贝一份并转换成float16。AMP规定的float16算子(例如卷积、Matmul)以float16计算。AMP规定的float32算子(例如BatchNormalize、softmax)在输入和输出时精度为float16,计算时为float32,在反向传播时依然是混合精度计算,得到数值精度为float16的梯度。优化器的操作是利用float16的梯度对float32的参数进行更新。

  2. amp.GradScaler损失缩放(Loss Scale)- AMP使用amp.GradScaler提供梯度缩放(scale)操作,缓解反向传播时float16的梯度数值溢出问题。在优化器更新参数前,AMP会自动对梯度缩放进行还原(unscale),对用于模型优化的超参数不会有任何影响。

混合精度计算过程

AMP的基本计算过程如下:

  1. 拷贝并维护一个float32数值精度模型的副本。
  2. 初始化缩放系数(scale factor)。
  3. 精度转换,每个迭代计算一次。
    1. 拷贝参数并转换为float16模型精度。
    2. 正向传播(float16的模型参数)。
    3. Loss乘缩放系数。
    4. 反向传播(float16的模型参数与参数梯度)。
    5. 参数梯度除以缩放系数。
    6. 利用float16的梯度更新float32的模型参数。

计算过程图如下所示:

图2 AMP计算过程图