Mixed Precision を用いた学習

はじめに

従来、ニューラルネットワークの学習の際には重みと活性値として FP32 ( 32bit 浮動小数点数 ) が用いられてきました。しかし、深層学習の成功とニューラルネットワークのサイズ増加に伴い、近年ニューラルネットワークの学習にかかる計算コストが急速に増えています。これは、製品の発売前に多くの試行錯誤を行うにあたり、巨大なサイズのニューラルネットワークの学習のためにより多くの時間が必要になっていることを示しています。この問題に対処するために、深層学習のためのハードウェアを提供する企業 ( 例 NVIDIA ) は計算を高速化するためのアクセラレーターを導入しました。例えば、NVIDIA の Volta 世代以降の GPU は計算を高速化するために Tensor Cores を備えています。

しかし、 FP16 ( 16bit 浮動小数点数 ) を重み、活性値、勾配に使うにあたり、 FP16 の表現力は FP32 と比較して大きく制限されています。つまり、 勾配の値のオーバーフローやアンダーフローが時々 ( あるいは、頻繁に ) 発生し、これによりニューラルネットワークのパフォーマンスに悪影響を及ぼしたり、学習の失敗を引き起こします。

Mixed Precision を用いた学習は、 FP32 ネットワークで得られるものと同じ結果を維持しながら問題を回避する方法の 1 つです。詳しくは、 混合精度学習のユーザーガイド混合精度学習 に記載しています。

このチュートリアルでは、 NNabla における Mixed Precision を用いた学習方法を段階的に説明します。

段階毎の説明

基本的には、 Mixed Precision を用いた学習は 次の 3 つの段階で構成されています。

  1. 計算のためのアクセラレーターの使用 ( ここでは Tensor Cores を仮定しています )

  2. オーバーフローを防ぐためのロススケーリングの使用

  3. オーバーフロー / アンダーフローを防ぐための動的ロススケーリングの使用

NNabla では、次のように対応できます。

1. Tensor Cores の使用

ctx = get_extension_context("cudnn", type_config="half")

2. アンダーフローを防ぐためのロススケーリングの使用

loss_scale = 8
loss.backward(loss_scale)
solver.scale_grad(1. / loss_scale)  # do some gradient clipping, etc. after this
solver.update()

3. オーバーフロー / アンダーフローを防ぐための動的ロススケーリングの使用

loss_scale = 8
scaling_factor = 2
counter = 0
interval = 2000
...
loss.backward(loss_scale, ...)
...
if solver.check_inf_or_nan_grad():
    loss_scale /= scaling_factor
    counter = 0
else:
    solver.scale_grad(1. / loss_scale) # do some gradient clipping, etc. after this
    solver.update()
    if counter > interval:
        loss_scale *= scaling_factor
        counter = 0
    counter += 1

: 手順の 2 番目 ( アンダーフローを防ぐためのロス・スケーリングの使用 ) と 3 番目 ( オーバーフローを防ぐためのロス・スケーリングの使用 ) は、現在実験段階であり、混合精度学習の高速化に尽力しています。そのため、 API 、特に 3 番目は、将来の使用のため変わる可能性があります。

全ての処理を含んだ説明

前述の段階的な説明では、学習ループ内における 3 番目のステップが非常に長くなっています。代わりにここは次のようなラッパークラスを書くことができます。

class DynamicLossScalingUpdater(object):
    '''Dynamic Loss Scaling Updater for the mixed precision training.

    Args:
        solver (:obj:`nnabla.solvers.Solver`): Solver object. E.g., Momentum or Adam.
        loss (:obj:`nnabla.Variable`): Loss variable from which the forward and the backward is called.
        data_feeder (callable :obj:`object`, function, or lambda): Data feeder
        scale (:obj:`float`): Loss scale constant. This is dynamically changing during training.
        scaling_factor (:obj:`float`): Scaling factor for the dynamic loss scaling.
        N (:obj:`int`): Interval, the number of iterations in training for increasing `loss scale` by `scaling_factor`.
        clear_buffer (:obj:`bool`): Clears the no longer referenced variables during backpropagation to save memory.
        accum_grad (:obj:`int`): Number of accumulation of gradients. Update method of the `solver` is called after the `accum_grad` number of the forward and backward is called.
        weight_decay (:obj:`float`): Decay constant. Default is `None`, not applying the weight decay.
        comm (:obj:`nnabla.communicators.Communicator`): Communicator when to do distributed training. Default is :obj:`None`.
        grads (:obj:`list` of :obj:`nnabla._nd_array.NdArray`): The list of gradients to be exchanged when to do distributed training. Default is the empty :obj:`list`.

    Attributes:
        solver (:obj:`nnabla.solvers.Solver`): Solver object. E.g., Momentum or Adam.
        loss (:obj:`nnabla.Variable`): Loss variable from which the forward and the backward is called.
        data_feeder (callable :obj:`object`, function, lambda): Data feeder
        scale (:obj:`float`): Loss scale constant. This is dynamically changing during training.
        scaling_factor (:obj:`float`): Scaling factor for the dynamic loss scaling.
        N (:obj:`int`): Interval, the number of iterations in training for increasing `loss scale` by `scaling_factor`.
        clear_buffer (:obj:`bool`): Clears the no longer referenced variables during backpropagation to save memory.
        accum_grad (:obj:`int`): Number of accumulation of gradients. Update method of the `solver` is called after the `accum_grad` number of the forward and backward is called.
        weight_decay (:obj:`float`): Decay constant. Default is `None`, not applying the weight decay.
        comm (:obj:`nnabla.communicators.Communicator`): Communicator when to do distributed training.
        grads (:obj:`list` of :obj:`nnabla._nd_array.NdArray`): The list of gradients to be exchanged when to do distributed training.

    Example:

        .. code-block:: python
            solver = <Solver>
            loss = <Loss Variable of Network>
            data_feeder = <DataFeeder>

            updater = DynamicLossScalingUpdater(solver, loss, data_feeder)

            # Training iteration
            for itr in range(max_iter):
                # Call solver.zero_grad, data_feeder, loss.forward, loss.backward
                # and solver.update with the dynamic loss scaling.
                updater.update()

    Reference:

        https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html#scalefactor

    '''

    def __init__(self, solver, loss, data_feeder=lambda x: x,
                  scale=8.0, scaling_factor=2.0, N=2000, clear_buffer=True,
                  accum_grad=1, weight_decay=None,
                  comm=None,
                  grads=[]):
        self.solver = solver
        self.loss = loss
        self.data_feeder = data_feeder
        self.scale = scale
        self.scaling_factor = scaling_factor
        self.N = N
        self.clear_buffer = clear_buffer
        self.accum_grad = accum_grad
        self.weight_decay = weight_decay
        self.comm = comm
        self.grads = grads
        self._counter = 0
        self._recursive_count = 0
        self._max_recursive_count = 100

    def update(self):
        """Monolithic update method.

        This method calls the following methods with the dynamic loss scaling.

        1. solver.zerograd
        2. feed data
        3. loss.forward
        4. loss.backward
        5. comm.all_reduce (if it is specified)
        6. solver.update

        """

        # Initialize gradients.
        self.solver.zero_grad()

        # Forward and backward
        for _ in range(self.accum_grad):
            # feed data
            self.data_feeder()

            # forward
            self.loss.forward(clear_no_need_grad=self.clear_buffer)

            # backward with scale
            self.loss.backward(self.scale, clear_buffer=self.clear_buffer)

        # AllReduce
        if self.comm and len(self.grads) != 0:
            self.comm.all_reduce(self.grads, division=False, inplace=False)

        # Check Inf/NaN in grads
        if self.solver.check_inf_or_nan_grad():
            self.scale /= self.scaling_factor
            self._counter = 0

            # Recursively call update function until no inf nor nan.
            self._recursive_count += 1
            if self._recursive_count > self._max_recursive_count:
                self._recursive_count = 0
                return  # skip
            return self.update()
        self._recursive_count = 0

        # Rescale grads
        self.solver.scale_grad(1. / self.scale)

        # Do some gradient clipping, etc.
        if self.weight_decay is not None:
            self.solver.weight_decay(self.weight_decay)

        # Update
        self.solver.update()
        if self._counter > self.N:
            self.scale *= self.scaling_factor
            self._counter = 0
        self._counter += 1

その後、学習・ループ内で更新メソッドを呼びます。

from nnabla.experimental.mixed_precision_training import DynamicLossScalingUpdater

solver = <Solver>
loss = <Loss Variable of Network>
data_feeder = <DataFeeder>

updater = DynamicLossScalingUpdater(solver, loss, data_feeder)

# Training iteration
for itr in range(max_iter):
    # Call solver.zero_grad, data_feeder, loss.forward, loss.backward
    # and solver.update with the dynamic loss scaling.
    updater.update()

注意

Mixed Precision を用いた学習では、次のような動作になります。

  1. Solver は FP16 の重みと FP32 の重みのコピーを保持します。 type_config="half" が指定されると、 NNabla における Solver は FP32 の重みと重みの勾配を保持し、それを順方向パスで FP16 の重みへキャストし、逆方向パスで FP16 の重みの勾配へキャストします。

  2. batch-normalization、 Mean、 Sum、 SoftMax、 SoftMaxCrossEntropy などによって計算される統計量 ( 平均や分散 ) の演算には、 FP32 を使用してください ( 混合精度学習のユーザーガイド を参照してください) 。 NNabla では、 type_config="half" を指定した場合もこれらの関数は 自動的に FP32 を使うようにフォールバックされます。