Source code for nnabla.experimental.mixed_precision_training

# Copyright (c) 2017 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


[docs]class DynamicLossScalingUpdater(object): '''Dynamic Loss Scaling Updater for the mixed precision training. Args: solver (:obj:`nnabla.solvers.Solver`): Solver object. E.g., Momentum or Adam. loss (:obj:`nnabla.Variable`): Loss variable from which the forward and the backward is called. data_feeder (callable :obj:`object`, function, or lambda): Data feeder scale (:obj:`float`): Loss scale constant. This is dynamically changing during training. scaling_factor (:obj:`float`): Scaling factor for the dynamic loss scaling. N (:obj:`int`): Interval, the number of iterations in training for increasing `loss scale` by `scaling_factor`. clear_buffer (:obj:`bool`): Clears the no longer referenced variables during backpropagation to save memory. accum_grad (:obj:`int`): Number of accumulation of gradients. Update method of the `solver` is called after the `accum_grad` number of the forward and backward is called. weight_decay (:obj:`float`): Decay constant. Default is `None`, not applying the weight decay. comm (:obj:`nnabla.communicators.Communicator`): Communicator when to do distributed training. Default is :obj:`None`. grads (:obj:`list` of :obj:`nnabla._nd_array.NdArray`): The list of gradients to be exchanged when to do distributed training. Default is the empty :obj:`list`. Attributes: solver (:obj:`nnabla.solvers.Solver`): Solver object. E.g., Momentum or Adam. loss (:obj:`nnabla.Variable`): Loss variable from which the forward and the backward is called. data_feeder (callable :obj:`object`, function, lambda): Data feeder scale (:obj:`float`): Loss scale constant. This is dynamically changing during training. scaling_factor (:obj:`float`): Scaling factor for the dynamic loss scaling. N (:obj:`int`): Interval, the number of iterations in training for increasing `loss scale` by `scaling_factor`. clear_buffer (:obj:`bool`): Clears the no longer referenced variables during backpropagation to save memory. accum_grad (:obj:`int`): Number of accumulation of gradients. Update method of the `solver` is called after the `accum_grad` number of the forward and backward is called. weight_decay (:obj:`float`): Decay constant. Default is `None`, not applying the weight decay. comm (:obj:`nnabla.communicators.Communicator`): Communicator when to do distributed training. grads (:obj:`list` of :obj:`nnabla._nd_array.NdArray`): The list of gradients to be exchanged when to do distributed training. Example: .. code-block:: python solver = <Solver> loss = <Loss Variable of Network> data_feeder = <DataFeeder> updater = DynamicLossScalingUpdater(solver, loss, data_feeder) # Training iteration for itr in range(max_iter): # Call solver.zero_grad, data_feeder, loss.forward, loss.backward # and solver.update with the dynamic loss scaling. updater.update() Reference: https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html#scalefactor ''' def __init__(self, solver, loss, data_feeder=lambda x: x, scale=8.0, scaling_factor=2.0, N=2000, clear_buffer=True, accum_grad=1, weight_decay=None, comm=None, grads=[]): self.solver = solver self.loss = loss self.data_feeder = data_feeder self.scale = scale self.scaling_factor = scaling_factor self.N = N self.clear_buffer = clear_buffer self.accum_grad = accum_grad self.weight_decay = weight_decay self.comm = comm self.grads = grads self._counter = 0 self._recursive_count = 0 self._max_recursive_count = 100
[docs] def update(self): """Monolithic update method. This method calls the following methods with the dynamic loss scaling. 1. solver.zerograd 2. feed data 3. loss.forward 4. loss.backward 5. comm.all_reduce (if it is specified) 6. solver.update """ # Initialize gradients. self.solver.zero_grad() # Forward and backward for _ in range(self.accum_grad): # feed data self.data_feeder() # forward self.loss.forward(clear_no_need_grad=self.clear_buffer) # backward with scale self.loss.backward(self.scale, clear_buffer=self.clear_buffer) # AllReduce if self.comm and len(self.grads) != 0: self.comm.all_reduce(self.grads, division=False, inplace=False) # Check Inf/NaN in grads if self.solver.check_inf_or_nan_grad(): self.scale /= self.scaling_factor self._counter = 0 # Recursively call udpate function until no inf nor nan. self._recursive_count += 1 if self._recursive_count > self._max_recursive_count: self._recursive_count = 0 return # skip return self.update() self._recursive_count = 0 # Rescale grads self.solver.scale_grad(1. / self.scale) # Do some gradient clipping, etc. if self.weight_decay is not None: self.solver.weight_decay(self.weight_decay) # Update self.solver.update() if self._counter > self.N: self.scale *= self.scaling_factor self._counter = 0 self._counter += 1