# Mixed Precision Training

## Introduction

Traditionally, neural networks have been trained with `FP32` weights and
activations. However, the computational cost of training has grown
rapidly over the years as deep learning has succeeded and networks have
become larger. This means that training a large network takes much
longer, even though we would like to run many trials before a product
launch. To address this problem, companies (e.g., NVIDIA) have
introduced accelerators that speed up computation. For example, NVIDIA
Volta has Tensor Cores for this purpose.

However, Tensor Cores use `FP16` weights, activations, and gradients,
and the range of `FP16` is very limited compared to that of `FP32`. As a
result, gradient values sometimes (or often) overflow and/or underflow,
which degrades the performance of a neural network or makes it collapse
during training.
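To make the limited range concrete, the following minimal sketch round-trips a few values through IEEE 754 half precision using only the Python standard library (the `struct` format code `e`). The helper name `to_fp16` is introduced here for illustration and is not part of NNabla; it maps out-of-range values to infinity, as a hardware cast typically would.

```python
import math
import struct


def to_fp16(x):
    """Round a Python float to the nearest FP16 value (hypothetical helper)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses out-of-range values; treat them as +/-inf instead.
        return math.copysign(math.inf, x)


FP16_MAX = 65504.0               # largest finite FP16 value
FP16_MIN_SUBNORMAL = 2.0 ** -24  # smallest positive FP16 value (about 6e-8)

print(to_fp16(FP16_MAX))  # still finite
print(to_fp16(1.0e5))     # overflows to inf, although FP32 represents it easily
print(to_fp16(1.0e-8))    # underflows to 0.0, although FP32 represents it easily
```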

Mixed precision training is one of the algorithms for circumventing this
overflow/underflow problem while maintaining the same results that we
could obtain with `FP32` networks. It is described in detail in The
Training with Mixed Precision User Guide and Mixed Precision Training.

This tutorial explains, step by step, how to do mixed precision training in NNabla.

## Step-by-Step Instruction

Basically, mixed precision training is composed of three parts.

- Use the accelerator for computation (here we assume Tensor Cores)
- Use loss scaling to prevent underflow
- Use dynamic loss scaling to prevent overflow/underflow

In NNabla, these correspond to the following steps.

### 1. Use Tensor Cores

```
from nnabla.ext_utils import get_extension_context

ctx = get_extension_context("cudnn", type_config="half")
```

### 2. Use loss scaling to prevent underflow

```
loss_scale = 8
loss.backward(loss_scale)
solver.scale_grad(1. / loss_scale)  # do some gradient clipping, etc. after this
solver.update()
```
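The effect of the scale factor can be illustrated without a GPU. The sketch below simulates an `FP16` cast with the standard-library `struct` module (the helper `to_fp16` is hypothetical, not an NNabla function) and shows a small gradient that vanishes in `FP16` but survives when it is multiplied by a loss scale first. The scale of 1024 is an assumption chosen only to make the effect obvious; the snippet above uses 8.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest FP16 value (hypothetical helper)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


true_grad = 1.0e-8   # representable in FP32, but below the FP16 range
loss_scale = 1024.0  # illustrative value, not a recommendation

# Without scaling, the gradient is flushed to zero when stored in FP16.
unscaled = to_fp16(true_grad)

# With scaling, backward produces loss_scale * grad, which fits in FP16.
scaled = to_fp16(true_grad * loss_scale)

# Unscaling happens in higher precision, which is the role that
# solver.scale_grad(1. / loss_scale) plays in the snippet above.
recovered = scaled / loss_scale

print(unscaled)   # 0.0
print(recovered)  # nonzero, within about 1% of the true gradient
```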

### 3. Use dynamic loss scaling to prevent overflow/underflow

```
loss_scale = 8
scaling_factor = 2
counter = 0
interval = 2000
...
loss.backward(loss_scale, ...)
...
if solver.check_inf_or_nan_grad():
    # Overflow detected: shrink the scale and skip this update
    loss_scale /= scaling_factor
    counter = 0
else:
    solver.scale_grad(1. / loss_scale)  # do some gradient clipping, etc. after this
    solver.update()
    # Grow the scale after enough consecutive clean updates
    if counter > interval:
        loss_scale *= scaling_factor
        counter = 0
    counter += 1
```
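The scale adjustment logic above can be exercised on its own, without a network. The following is a minimal sketch in plain Python; the class `DynamicLossScaler`, its method `adjust`, and the sequence of overflow flags are hypothetical names and values introduced for illustration, and a short `interval` of 3 is assumed so that the behavior is visible in a few steps.

```python
class DynamicLossScaler:
    """Tracks the loss scale; mirrors the control flow of the loop above."""

    def __init__(self, loss_scale=8.0, scaling_factor=2.0, interval=2000):
        self.loss_scale = loss_scale
        self.scaling_factor = scaling_factor
        self.interval = interval
        self.counter = 0

    def adjust(self, found_inf_or_nan):
        """Adjust the scale at the end of one training iteration."""
        if found_inf_or_nan:
            # Overflow: shrink the scale; the caller skips the update.
            self.loss_scale /= self.scaling_factor
            self.counter = 0
            return
        # Clean step: grow the scale after more than `interval` in a row.
        if self.counter > self.interval:
            self.loss_scale *= self.scaling_factor
            self.counter = 0
        self.counter += 1


scaler = DynamicLossScaler(loss_scale=8.0, scaling_factor=2.0, interval=3)
scales = []
# Hypothetical results of the Inf/NaN check, one per training step.
for overflow in [False, True, False, False, False, False, False]:
    scaler.adjust(overflow)
    scales.append(scaler.loss_scale)

print(scales)  # [8.0, 4.0, 4.0, 4.0, 4.0, 4.0, 8.0]
```

In a real loop, the gradients would be unscaled with the value of `loss_scale` that was passed to `backward`, so the adjustment is applied only after the update for that iteration.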

**Note** that the 2nd (use loss scaling to prevent underflow) and 3rd
(use dynamic loss scaling to prevent overflow/underflow) procedures are
currently experimental. We are still working to speed up mixed precision
training, so the API might change in the future, especially for the 3rd.

## All-in-one Instruction

The 3rd step of the previous step-by-step example is lengthy inside a training loop, so we can write a wrapper class like the following.

```
class AutoLossScalingUpdater(object):

    def __init__(self, solver, loss,
                 scale=8, scaling_factor=2, N=2000, clear_buffer=True,
                 accum_grad=1,
                 comm=None,
                 grads=None):
        self.solver = solver
        self.loss = loss
        self.scale = scale
        self.N = N
        self.clear_buffer = clear_buffer
        self.accum_grad = accum_grad
        self.scaling_factor = scaling_factor
        self.comm = comm
        # Avoid sharing one mutable default list between instances
        self.grads = grads if grads is not None else []
        self.counter = 0

    def update(self):
        # Forward and backward
        for _ in range(self.accum_grad):
            # forward
            self.loss.forward(clear_no_need_grad=self.clear_buffer)
            # backward with scale
            self.loss.backward(self.scale, clear_buffer=self.clear_buffer)

        # AllReduce
        if self.comm and len(self.grads) != 0:
            self.comm.all_reduce(self.grads, division=False, inplace=False)

        # Check Inf/NaN in grads; if found, shrink the scale and skip the update
        if self.solver.check_inf_or_nan_grad():
            self.scale /= self.scaling_factor
            self.counter = 0
            return

        # Rescale grads
        self.solver.scale_grad(1. / self.scale)

        # Do some gradient clipping, etc.

        # Update
        self.solver.update()
        # Grow the scale after more than N consecutive clean updates
        if self.counter > self.N:
            self.scale *= self.scaling_factor
            self.counter = 0
        self.counter += 1
```

Then, call the update method in a training loop:

```
solver = <Solver>
loss = <Loss Variable of Network>
updater = AutoLossScalingUpdater(solver, loss)
...
x.d = <data>
y.d = <label>
...
updater.update()
```

## Notice

In mixed precision training, the following are assumed:

- The solver contains `FP16` weights and an `FP32` copy of the weights. Solvers in NNabla hold `FP32` weights and weight gradients, and cast them to `FP16` weights in the forward pass and to `FP16` weight gradients in the backward pass if one sets `type_config="half"`.
- Reductions should be left in `FP32`, for example, the statistics (mean and variance) computed by batch normalization, Mean, Sum, SoftMax, SoftMaxCrossEntropy, etc. (see The Training with Mixed Precision User Guide). In NNabla, these functions automatically fall back to `FP32`.
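Both premises can be motivated with a small numerical experiment. The sketch below is an illustration only: it simulates `FP16` rounding with the standard-library `struct` module through a hypothetical helper `to_fp16`, and it uses ordinary Python floats (which are wider than `FP32`) as a stand-in for the higher-precision copy. It does not reflect how NNabla implements these steps internally.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest FP16 value (hypothetical helper)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# 1) Weight update: near 1.0 the FP16 spacing is 2**-10 (about 1e-3), so a
#    1e-4 step is rounded away unless the weight is kept in higher precision.
weight, step = 1.0, 1.0e-4
fp16_weight = to_fp16(weight + step)  # the update is lost
master_weight = weight + step         # the higher-precision copy keeps it

# 2) Reduction: summing 4096 ones with an FP16 accumulator stalls at 2048,
#    because 2049 is not representable and rounds back to 2048.
fp16_sum = 0.0
wide_sum = 0.0
for _ in range(4096):
    fp16_sum = to_fp16(fp16_sum + 1.0)
    wide_sum += 1.0

print(fp16_weight, master_weight > 1.0)  # 1.0 True
print(fp16_sum, wide_sum)                # 2048.0 4096.0
```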