Mixed Precision Training


Traditionally, neural networks have been trained with FP32 weights and activations. However, the computational cost of training has risen rapidly over the years as deep learning has succeeded and networks have grown in size. This means that training a large network takes much longer, even though we would like to run many trials before a product launch. To address this problem, companies (e.g., NVIDIA) have introduced accelerators to speed up computation. For example, NVIDIA Volta has Tensor Cores for this purpose.

However, Tensor Cores use FP16 weights, activations, and gradients, and the range of FP16 is very limited compared to that of FP32. As a result, gradient values sometimes (or often) overflow and/or underflow, which degrades the performance of a neural network or makes training collapse.
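To make the range limitation concrete, here is a minimal sketch in plain Python (not NNabla code). It assumes the standard-library struct module's "e" format as a stand-in for FP16 storage; the helper name to_fp16 is hypothetical and exists only for this illustration.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest FP16 value (hypothetical helper).

    Values too large for FP16 are mapped to +/-inf to mimic overflow.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


small_grad = 1e-8      # below the smallest positive FP16 value (about 6e-8)
large_grad = 70000.0   # above the largest finite FP16 value (65504)

print(to_fp16(small_grad))  # underflows to zero
print(to_fp16(large_grad))  # overflows to infinity
print(to_fp16(0.5))         # ordinary values are kept
```

Both failure modes lose the gradient information entirely, which is what the loss scaling steps later in this tutorial are meant to avoid.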

Mixed precision training is one of the algorithms that circumvents this problem while maintaining the same results that we could obtain with FP32 networks. It is well described in The Training with Mixed Precision User Guide and Mixed Precision Training.

This tutorial explains how to do mixed precision training in NNabla step by step.

Step-by-Step Instruction

Basically, mixed precision training is composed of three parts.

  1. Use the accelerator for computation (here we assume Tensor Cores)
  2. Use loss scaling to prevent underflow
  3. Use dynamic loss scaling to prevent overflow/underflow

In NNabla, these three parts correspond to the following.

1. Use Tensor Cores

from nnabla.ext_utils import get_extension_context
ctx = get_extension_context("cudnn", type_config="half")

2. Use loss scaling to prevent underflow

loss_scale = 8
loss.backward(loss_scale)  # backpropagate with the scaled loss gradient
solver.scale_grad(1. / loss_scale)  # do some gradient clipping, etc. after this
solver.update()
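The effect of loss scaling can be sketched numerically in plain Python (not NNabla code). This assumes the struct module's "e" format as a stand-in for FP16 gradient storage; to_fp16 is a hypothetical helper, and the gradient value is chosen only for illustration.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest FP16 value (hypothetical helper)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


loss_scale = 8
true_grad = 2.0 ** -27  # too small to be represented in FP16

# Without loss scaling, the gradient is flushed to zero in FP16.
unscaled = to_fp16(true_grad)

# With loss scaling, backward produces loss_scale * grad, which FP16 can hold.
# Rescaling by 1 / loss_scale afterwards, in higher precision, recovers it.
scaled = to_fp16(true_grad * loss_scale)
recovered = scaled * (1. / loss_scale)

print(unscaled)                # the unscaled gradient is lost
print(recovered == true_grad)  # the scaled-then-rescaled gradient survives
```

The rescaling happens on the solver side, which is why solver.scale_grad(1. / loss_scale) is called before any gradient clipping or the update.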

3. Use dynamic loss scaling to prevent overflow/underflow

loss_scale = 8
scaling_factor = 2
counter = 0
interval = 2000
loss.backward(loss_scale, ...)
if solver.check_inf_or_nan_grad():
    # Overflow detected: shrink the scale and skip this update
    loss_scale /= scaling_factor
    counter = 0
else:
    solver.scale_grad(1. / loss_scale) # do some gradient clipping, etc. after this
    solver.update()
    if counter > interval:
        # No overflow for a while: grow the scale again
        loss_scale *= scaling_factor
        counter = 0
    counter += 1
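The scale-adjustment rule can be isolated from the solver as a small state machine. The following is a minimal sketch in plain Python, not part of NNabla; the class name DynamicLossScaler and its step method are hypothetical, and the caller is assumed to pass in the result of its own Inf/NaN check.

```python
class DynamicLossScaler:
    """Tracks the loss scale according to the rule described above."""

    def __init__(self, scale=8.0, scaling_factor=2.0, interval=2000):
        self.scale = scale
        self.scaling_factor = scaling_factor
        self.interval = interval
        self.counter = 0

    def step(self, found_inf_or_nan):
        """Adjust the scale; return True if the update should be applied."""
        if found_inf_or_nan:
            # Overflow: shrink the scale and skip this update.
            self.scale /= self.scaling_factor
            self.counter = 0
            return False
        if self.counter > self.interval:
            # Enough clean steps in a row: grow the scale again.
            self.scale *= self.scaling_factor
            self.counter = 0
        self.counter += 1
        return True


# A short interval is used here only to keep the demonstration small.
scaler = DynamicLossScaler(scale=8.0, scaling_factor=2.0, interval=2)
applied = [scaler.step(flag) for flag in (True, False, False, False)]
scale_after_overflow = scaler.scale
scaler.step(False)
scale_after_recovery = scaler.scale
print(applied, scale_after_overflow, scale_after_recovery)
```

Shrinking immediately but growing only after a run of clean steps keeps the scale near the largest value that does not overflow.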

Note that the 2nd (Use loss scaling to prevent underflow) and 3rd (Use dynamic loss scaling to prevent overflow/underflow) procedures are currently experimental, and we are still working on speeding up mixed precision training, so the API might change in the future, especially for the 3rd.

All-in-one Instruction

In the previous step-by-step example, the 3rd step is lengthy to write in a training loop, so we can write a wrapper class like the following.

class AutoLossScalingUpdater(object):

    def __init__(self, solver, loss,
                 scale=8, scaling_factor=2, N=2000, clear_buffer=True,
                 accum_grad=1, comm=None, grads=None):
        # The defaults for accum_grad, comm, and grads are assumptions
        self.solver = solver
        self.loss = loss
        self.scale = scale
        self.N = N
        self.clear_buffer = clear_buffer
        self.accum_grad = accum_grad
        self.scaling_factor = scaling_factor
        self.comm = comm
        self.grads = [] if grads is None else grads
        self.counter = 0

    def update(self):
        # Initialize gradients
        self.solver.zero_grad()

        # Forward and backward
        for _ in range(self.accum_grad):
            # forward
            self.loss.forward(clear_no_need_grad=self.clear_buffer)

            # backward with scale
            self.loss.backward(self.scale, clear_buffer=self.clear_buffer)

        # AllReduce
        if self.comm and len(self.grads) != 0:
            self.comm.all_reduce(self.grads, division=False, inplace=False)

        # Check Inf/NaN in grads; on overflow, shrink the scale and skip the update
        if self.solver.check_inf_or_nan_grad():
            self.scale /= self.scaling_factor
            self.counter = 0
            return

        # Rescale grads
        self.solver.scale_grad(1. / self.scale)

        # Do some gradient clipping, etc.

        # Update
        self.solver.update()
        if self.counter > self.N:
            self.scale *= self.scaling_factor
            self.counter = 0
        self.counter += 1

Then, call the update method in a training loop:

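solver = <Solver>
loss = <Loss Variable of Network>
updater = AutoLossScalingUpdater(solver, loss)

# Training iteration (max_iter is a placeholder for the number of iterations)
for itr in range(max_iter):
    # Set data
    x.d = <data>
    y.d = <label>
    updater.update()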


In mixed precision training, the following are premises:

  1. The solver contains FP16 weights and an FP32 copy of the weights. Solvers in NNabla hold FP32 weights and weight gradients, and cast them to FP16 weights in the forward pass and to FP16 weight gradients in the backward pass if one sets type_config="half".
  2. Reductions should be left in FP32, for example, the statistics (mean and variance) computed by batch normalization, Mean, Sum, SoftMax, SoftMaxCrossEntropy, etc. (see The Training with Mixed Precision User Guide). In NNabla, these functions automatically fall back to FP32.
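
Both premises can be motivated with a small numerical sketch in plain Python (not NNabla code). It assumes the struct module's "e" format as a stand-in for FP16, and ordinary Python floats as a stand-in for the higher-precision (FP32) side; to_fp16 and fp16_sum are hypothetical helpers.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest FP16 value (hypothetical helper)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Premise 1: a small update vanishes if the weight itself is kept in FP16,
# but it accumulates in a higher-precision master copy.
weight, update = 1.0, 1e-4
w_fp16 = to_fp16(to_fp16(weight) + to_fp16(update))
w_master = weight + update


# Premise 2: a running sum kept in FP16 stalls once the total is large
# enough that adding one more element no longer changes it.
def fp16_sum(values):
    total = 0.0
    for v in values:
        total = to_fp16(total + to_fp16(v))
    return total


ones = [1.0] * 4096
print(w_fp16, w_master)
print(fp16_sum(ones), sum(ones))
```

In this sketch the FP16 weight never moves and the FP16 sum stops growing partway through, while the higher-precision versions behave as expected.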