# Source code for nnabla.utils.inspection.value_trace

# Copyright 2020,2021 Sony Corporation.
# Copyright 2021 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

from collections import OrderedDict
from contextlib import contextmanager

import nnabla.functions as F

from .base import FunctionHookCallbackBase


def _number_to_order(n):
    """Return the English ordinal string for the integer ``n``.

    Examples: 1 -> "1st", 2 -> "2nd", 3 -> "3rd", 4 -> "4th",
    11 -> "11th", 21 -> "21st", 112 -> "112th".

    Args:
        n (int): The number to convert.

    Returns:
        str: ``n`` followed by its ordinal suffix.
    """
    last_two = abs(n) % 100
    # 11, 12 and 13 (and 111, 212, ...) always take "th" despite their last digit.
    if 11 <= last_two <= 13:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(last_two % 10, "th")
    return "{}{}".format(n, suffix)


def _error_trace(history, exec_name):
    """Print a propagation history, marking its last entry as the failing one.

    Args:
        history (list): Items visited so far, in visiting order. The last
            item is printed with an ``<-- ERROR`` marker.
        exec_name (str): Name of the propagation phase printed in the header,
            e.g. "forward" or "backward".
    """
    print("Error during {} propagation".format(exec_name))
    # Guard: indexing history[-1] on an empty list would raise IndexError.
    if not history:
        return

    for entry in history[:-1]:
        print("\t{}".format(entry))

    print("\t{} <-- ERROR".format(history[-1]))


class NanInfTracer(FunctionHookCallbackBase):
    """
    A utility API to create function_hook callbacks that check whether the
    outputs of all layers contain NaN or inf values.

    When passed as function hooks during forward and backward execution,
    the callbacks record, for every executed function, the number of NaN/inf
    values in each checked tensor. :meth:`check` (run automatically at the
    end of :meth:`trace`) raises ValueError if at least one checked tensor
    contains NaN or inf. The callbacks only read the tensors, so all tensors
    are passed on to the next layer or function as is.

    Args:
        trace_nan (bool): If True, count NaN values.
        trace_inf (bool): If True, count inf values.
        need_details (bool): If True, also record each function's type,
            arguments and input/output shapes and include them in the error
            message.

    Example:

    .. code-block:: python

        pred = model(...)

        from nnabla.utils.inspection import NanInfTracer
        nit = NanInfTracer(trace_inf=True, trace_nan=True, need_details=True)

        with nit.trace():
            pred.forward(function_post_hook=nit.forward_post_hook)
            pred.backward(function_post_hook=nit.backward_post_hook)
    """

    def __init__(self, trace_nan=True, trace_inf=True, need_details=True):
        super(NanInfTracer, self).__init__()
        self.trace_nan = trace_nan
        self.trace_inf = trace_inf
        self.need_details = need_details

        # Per-function statistics keyed by the function object. OrderedDict
        # keeps insertion order so the error trace replays the call history.
        self.key_to_stat_fwd = OrderedDict()
        self.key_to_stat_bwd = OrderedDict()

        # The first two placeholders describe the offending tensor: its
        # ordinal position and its kind ("output" or "input gradient").
        _msg_general = ("The {} {} of the function '{}' (rank: {}) "
                        "has nan or inf as its values.")
        _msg_detail = (
            "\n"
            "Function details:\n"
            "    function type: {}\n"
            "    shapes of inputs: {}\n"
            "    shapes of outputs: {}\n"
            "    function args: {}\n"
        ) if self.need_details else ""
        self._msg = _msg_general + _msg_detail
        # Stat keys, in the order their values fill the remaining
        # placeholders of self._msg. str.format ignores surplus positional
        # arguments, so the detail keys are harmless when need_details=False.
        self._msg_keys = ["name", "rank", "function_type",
                          "input_shapes", "output_shapes", "function_args"]

    def _add_key(self, f, key_to_stat):
        """Register function ``f`` in ``key_to_stat`` the first time it is seen."""
        if f in key_to_stat:
            return

        stat = {
            "name": f.name,
            "function_type": None,
            "function_args": None,
            "input_shapes": None,
            "output_shapes": None,
        }
        if self.need_details:
            stat.update({
                "function_type": f.info.type_name,
                "function_args": str(f.info.args),
                "input_shapes": [str(x.shape) for x in f.inputs],
                "output_shapes": [str(x.shape) for x in f.outputs],
            })
        key_to_stat[f] = stat

    def _record(self, f, key_to_stat, arrays):
        """Store per-array NaN/inf counts of ``arrays`` for function ``f``.

        Args:
            f: The function whose tensors are being inspected.
            key_to_stat (OrderedDict): Forward or backward statistics table.
            arrays (list): The arrays to inspect (outputs' data in forward,
                inputs' grads in backward).
        """
        self._add_key(f, key_to_stat)

        # F.isnan/F.isinf followed by F.sum gives, per array, the number of
        # offending elements; the counts are inspected later in check().
        nan = [F.sum(F.isnan(a)) for a in arrays] if self.trace_nan else []
        inf = [F.sum(F.isinf(a)) for a in arrays] if self.trace_inf else []

        key_to_stat[f].update({
            "inf": inf,
            "nan": nan,
            # rank might be changed between each iteration.
            "rank": f.rank,
        })

    @property
    def pre_hook(self):
        """NanInfTracer performs nothing before a function runs; always None."""
        return None

    @property
    def post_hook(self):
        """Not supported: use forward_post_hook / backward_post_hook instead.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("NanInfTracer has no member named post_hook. "
                                  "Please use forward_post_hook and backward_post_hook "
                                  "for forward and backward respectively instead.")

    @property
    def forward_post_hook(self):
        """
        Create a callback that can be used as the function_post_hook argument
        of forward(). It inspects the data of every output of each function.

        Returns None when neither trace_nan nor trace_inf is enabled.
        """
        if not (self.trace_nan or self.trace_inf):
            return None

        def callback(f):
            self._record(f, self.key_to_stat_fwd,
                         [o.data for o in f.outputs])

        return callback

    @property
    def backward_post_hook(self):
        """
        Create a callback that can be used as the function_post_hook argument
        of backward(). It inspects the gradient of every input of each
        function.

        Returns None when neither trace_nan nor trace_inf is enabled.
        """
        if not (self.trace_nan or self.trace_inf):
            return None

        def callback(f):
            # NOTE(review): grads of inputs that do not require gradients may
            # not hold meaningful values at this point -- confirm against
            # nnabla's backward semantics.
            self._record(f, self.key_to_stat_bwd,
                         [i.grad for i in f.inputs])

        return callback

    @contextmanager
    def trace(self):
        """
        Create a context manager to check nan/inf existence with a ``with``
        statement.

        check() is called automatically when the ``with`` body finishes
        normally. If the body raises, the exception propagates and check() is
        not run. Without this context manager, be sure to call .check()
        explicitly.

        Example:

        .. code-block:: python

            nit = NanInfTracer()
            with nit.trace():
                pred.forward(function_post_hook=nit.forward_post_hook)
                pred.backward(function_post_hook=nit.backward_post_hook)
        """
        # No need to release any resources at the end of this context manager.
        yield self

        self.check()

    @staticmethod
    def _first_positive(counts):
        """Return the index of the first count whose value is > 0, or None."""
        for index, count in enumerate(counts):
            if count.data > 0:
                return index
        return None

    def _check_impl(self, key_to_stat, exec_name, tensor_kind="output"):
        """Raise ValueError for the first recorded function with NaN/inf values.

        Functions are visited in recording order. For each one, NaN counts are
        checked before inf counts.

        Args:
            key_to_stat (OrderedDict): Statistics recorded by a hook.
            exec_name (str): Phase name used in the printed trace.
            tensor_kind (str): Word(s) describing the checked tensors in the
                error message.

        Raises:
            ValueError: If any recorded count is positive. The execution
                history up to the offending function is printed first.
        """
        history = []
        for f, stat in key_to_stat.items():
            history.append(f)

            index = self._first_positive(stat["nan"]) if self.trace_nan else None
            if index is None and self.trace_inf:
                index = self._first_positive(stat["inf"])
            if index is None:
                continue

            _error_trace(history, exec_name)
            # enumerate() indices are 0-based while ordinals are 1-based.
            raise ValueError(self._msg.format(
                _number_to_order(index + 1), tensor_kind,
                *[stat[x] for x in self._msg_keys]))

    def check(self):
        """
        Check nan/inf existence in all recorded tensors of all layers and
        raise ValueError if any is found.

        Forward statistics are checked before backward statistics.
        """
        # check forward: the hook inspected each function's outputs.
        self._check_impl(self.key_to_stat_fwd, "forward", "output")
        # check backward: the hook inspected each function's input gradients.
        self._check_impl(self.key_to_stat_bwd, "backward", "input gradient")