# Source code for nnabla.utils.nnp_graph

# Copyright (c) 2017 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import OrderedDict
import os
import weakref
import numpy as np
import itertools

import nnabla as nn
import nnabla.function as F


class NnpNetwork(object):
    '''Graph of nnabla variables built from one network of an nnp file.

    Instances are normally obtained from ``NnpLoader.get_network`` rather
    than constructed by hand; see :obj:`NnpLoader` for a usage example.

    Attributes:
        variables (dict): Every variable and parameter of the built graph,
            keyed by name, with the corresponding nnabla.Variable as value.
        inputs (dict): The network's input variables, keyed by name.
        outputs (dict): The network's output variables, keyed by name.
    '''

    def __init__(self, proto_network, batch_size, callback):
        # Unroll loop-control constructs first, then promote the network
        # with the callback (presumably an NnpNetworkPass) and build it.
        unrolled = proto_network.expand_loop_control()
        self.proto_network = unrolled.promote(callback)
        self.proto_network(batch_size=batch_size)

        def named_entries():
            # (name, proto variable) pairs: variables first, then parameters.
            return itertools.chain(
                self.proto_network.variables.items(),
                self.proto_network.parameters.items())

        # Give every instantiated variable the name it has in the nnp file.
        for entry_name, proto_var in named_entries():
            proto_var.variable_instance.name = entry_name

        proto_vars = self.proto_network.variables
        self._inputs = dict(
            (in_name, proto_vars[in_name].variable_instance)
            for in_name in self.proto_network.inputs)
        self._outputs = dict(
            (out_name, proto_vars[out_name].variable_instance)
            for out_name in self.proto_network.outputs)
        self._variables = dict(
            (entry_name, proto_var.variable_instance)
            for entry_name, proto_var in named_entries())

        # Publish the network's parameters into the current parameter scope,
        # mirroring the behaviour of the original implementation.
        with nn.parameter_scope('', nn.get_current_parameter_scope()):
            for param_name, proto_param in self.proto_network.parameters.items():
                nn.parameter.set_parameter(
                    param_name, proto_param.variable_instance)

    @property
    def inputs(self):
        '''dict: Input variables of the graph, keyed by name.'''
        return self._inputs

    @property
    def outputs(self):
        '''dict: Output variables of the graph, keyed by name.'''
        return self._outputs

    @property
    def variables(self):
        '''dict: All variables and parameters of the graph, keyed by name.'''
        return self._variables
class NnpLoader(object):
    '''An NNP file loader.

    Args:
        filepath: File-like object or path of the file to load.
        scope (dict, optional): Mapping passed to the loader as the
            parameter scope that receives the loaded parameters. When
            ``None`` (default) a fresh ``OrderedDict`` is used. A mapping
            supplied by the caller, even an empty one, is used as-is.
        extension (str): If ``filepath`` is a file-like object, one of
            ``".nnp"``, ``".nntxt"``, ``".prototxt"``.

    Example:

        .. code-block:: python

            from nnabla.utils.nnp_graph import NnpLoader

            # Read a .nnp file.
            nnp = NnpLoader('/path/to/nnp.nnp')
            # Assume a graph `graph_a` is in the nnp file.
            net = nnp.get_network('graph_a', batch_size=1)
            # `x` is an input of the graph.
            x = net.inputs['x']
            # 'y' is an output of the graph.
            y = net.outputs['y']
            # Set random data as input and perform forward prop.
            x.d = np.random.randn(*x.shape)
            y.forward(clear_buffer=True)
            print('output:', y.d)
    '''

    def __init__(self, filepath, scope=None, extension=".nntxt"):
        # OrderedDict maintains loaded parameters from nnp files.
        # Test against None explicitly: an empty mapping supplied by the
        # caller is falsy but must still be the one that gets filled.
        self._params = scope if scope is not None else OrderedDict()
        self.g = nn.graph_def.load(
            filepath, parameter_scope=self._params,
            # Fixed seed so that any randomness used during loading is
            # reproducible (NOTE(review): presumably initialisation of
            # values missing from the file -- confirm).
            rng=np.random.RandomState(1223),
            extension=extension)
        self.network_dict = {
            name: pn for name, pn in self.g.networks.items()
        }

    def get_network_names(self):
        '''Returns network names available.

        Returns:
            list of str: Names accepted by :meth:`get_network`.
        '''
        return list(self.network_dict.keys())

    def get_network(self, name, batch_size=None, callback=None):
        '''Create a variable graph given network by name.

        Args:
            name (str): Network name, one of :meth:`get_network_names`.
            batch_size (int, optional): Forwarded to NnpNetwork, which
                passes it when building the promoted network.
            callback (optional): Forwarded to NnpNetwork, which passes it
                to the proto network's ``promote``.

        Returns:
            NnpNetwork

        Raises:
            KeyError: If ``name`` is not a network of the loaded file.
        '''
        return NnpNetwork(self.network_dict[name], batch_size,
                          callback=callback)
class NnpNetworkPass(object):
    '''Collection of hooks consulted while an nnp network is generated.

    Hooks are registered with the decorator factories (``on_*``) or with the
    convenience helpers (``drop_function``, ``remove_and_rewire``,
    ``set_variable``, ...), and are looked up through the ``_apply_*``
    methods.

    Args:
        verbose (int): 0 silences ``self.verbose``/``self.verbose2``;
            a truthy value makes ``self.verbose`` print; a value above 1
            also makes ``self.verbose2`` print.
    '''

    def _no_verbose(self, *a, **kw):
        pass

    def _verbose(self, *a, **kw):
        print(*a, **kw)

    def __init__(self, verbose=0):
        self._variable_callbacks = {}
        self._function_callbacks_by_name = {}
        self._function_callbacks_by_type = {}
        self._passes_by_name = {}
        self._passes_by_type = {}
        self._fix_parameters = False
        self._use_up_to_variables = set()
        self.verbose = self._no_verbose
        self.verbose2 = self._no_verbose
        if verbose:
            self.verbose = self._verbose
            if verbose > 1:
                self.verbose2 = self._verbose

    def on_function_pass_by_name(self, name):
        '''Decorator factory: register ``callback(f, variables, param_scope)``
        as the pass for the function named ``name``.'''
        def _on_function_pass_by_name(callback):
            def _callback(f, variables, param_scope):
                return callback(f, variables, param_scope)
            self._passes_by_name[name] = _callback
            return _callback
        return _on_function_pass_by_name

    def on_function_pass_by_type(self, name):
        '''Decorator factory: register ``callback(f, variables, param_scope)``
        as the pass for functions whose proto type is ``name``.'''
        def _on_function_pass_by_type(callback):
            def _callback(f, variables, param_scope):
                return callback(f, variables, param_scope)
            # Must go into the by-type table: _apply_function_pass_by_type
            # looks passes up here, keyed by f.proto.type.
            self._passes_by_type[name] = _callback
            return _callback
        return _on_function_pass_by_type

    def on_generate_variable(self, name):
        '''Decorator factory: register ``callback(v)`` to be applied when the
        variable named ``name`` is generated; its return value replaces v.'''
        def _on_generate_variable(callback):
            def _callback(v):
                return callback(v)
            self._variable_callbacks[name] = _callback
            return _callback
        return _on_generate_variable

    def on_generate_function_by_name(self, name):
        '''Decorator factory: register ``callback(f)`` to be applied when the
        function named ``name`` is generated; its return value replaces f.'''
        def _on_generate_function_by_name(callback):
            def _callback(f):
                return callback(f)
            self._function_callbacks_by_name[name] = _callback
            return _callback
        return _on_generate_function_by_name

    def on_generate_function_by_type(self, name):
        '''Decorator factory: register ``callback(f)`` to be applied when a
        function of proto type ``name`` is generated.'''
        def _on_generate_function_by_type(callback):
            def _callback(f):
                return callback(f)
            self._function_callbacks_by_type[name] = _callback
            return _callback
        return _on_generate_function_by_type

    def drop_function(self, *names):
        '''Register a by-name pass that calls ``f.disable()`` on each of the
        given functions.'''
        def callback(f, variables, param_scope):
            self.verbose('Pass: Deleting {}.'.format(f.name))
            f.disable()
        for name in names:
            self.on_function_pass_by_name(name)(callback)

    def fix_parameters(self):
        '''Make ``_apply_generate_variable`` set ``need_grad = False`` on
        every variable it processes.'''
        self._fix_parameters = True

    def use_up_to(self, *names):
        '''Mark the given variable names; ``_apply_use_up_to`` sets
        ``stop = True`` on variables with these names.'''
        self._use_up_to_variables.update(set(names))

    def remove_and_rewire(self, name, i=0, o=0):
        '''Register a by-name pass for function ``name`` that rewires its
        ``o``-th output onto its ``i``-th input and gives the output the
        input's name (presumably bypassing the function -- see rewire_on).'''
        @self.on_function_pass_by_name(name)
        def on_dr(f, variables, param_scope):
            fi = f.inputs[i]
            fo = f.outputs[o]
            self.verbose('Removing {} and rewire input={} and output={}.'.format(
                f.name, fi.name, fo.name))
            fo.rewire_on(fi)
            # Use input name
            fo.proto.name = fi.name

    def set_variable(self, name, input_var):
        '''When variable ``name`` is generated, replace it by ``input_var``:
        copy its shape into the proto, attach it, and give it v's name.'''
        @self.on_generate_variable(name)
        def on_input_x(v):
            self.verbose('Replace {} by {}.'.format(name, input_var))
            v.proto.shape.dim[:] = input_var.shape
            v.variable = input_var
            input_var.name = v.name
            return v

    def force_average_pooling_global(self, name, by_type=False):
        '''Rewrite the average-pooling kernel and stride of function ``name``
        (or of every function of type ``name`` if ``by_type``) to the input
        shape dims from index 2 on (presumably the spatial dims).'''
        dec = self.on_generate_function_by_name
        if by_type:
            dec = self.on_generate_function_by_type

        @dec(name)
        def on_avgpool(f):
            pool_shape = f.inputs[0].proto.shape.dim[2:]
            self.verbose('Change strides of {} to {}.'.format(
                f.name, pool_shape))
            p = f.proto.average_pooling_param
            p.kernel.dim[:] = pool_shape
            p.stride.dim[:] = pool_shape
            return f

    def check_average_pooling_global(self, name, by_type=False):
        '''Like ``force_average_pooling_global`` but only validates.

        Raises:
            ValueError: At generation time, if kernel or stride differ from
                the input shape dims from index 2 on.
        '''
        dec = self.on_generate_function_by_name
        if by_type:
            dec = self.on_generate_function_by_type

        @dec(name)
        def on_avgpool_check(f):
            pool_shape = f.inputs[0].proto.shape.dim[2:]
            p = f.proto.average_pooling_param
            if p.kernel.dim[:] != pool_shape or p.stride.dim[:] != pool_shape:
                raise ValueError(
                    'Stride configuration of average pooling is not for global pooling.'
                    ' Given Image shape is {}, whereas pooling window size is {} and its stride is {}.'
                    ' Consider using force_global_pooling=True'.format(
                        pool_shape, p.kernel.dim[:], p.stride.dim[:]))
            return f

    def set_batch_normalization_batch_stat_all(self, batch_stat):
        '''Set ``batch_stat`` on every generated BatchNormalization function.'''
        @self.on_generate_function_by_type('BatchNormalization')
        def on_bn(f):
            self.verbose('Setting batch_stat={} at {}.'.format(
                batch_stat, f.name))
            p = f.proto.batch_normalization_param
            p.batch_stat = batch_stat
            return f

    def _apply_function_pass_by_name(self, f, variables, param_scope):
        '''Run the pass registered for ``f.name``; return f if none.'''
        if f.name not in self._passes_by_name:
            return f
        return self._passes_by_name[f.name](f, variables, param_scope)

    def _apply_function_pass_by_type(self, f, variables, param_scope):
        '''Run the pass registered for ``f.proto.type``; return f if none.'''
        if f.proto.type not in self._passes_by_type:
            return f
        return self._passes_by_type[f.proto.type](f, variables, param_scope)

    def _apply_generate_variable(self, v):
        '''Apply the variable callback for ``v.name`` (if any), then clear
        ``need_grad`` when ``fix_parameters`` was requested.'''
        if v.name in self._variable_callbacks:
            v = self._variable_callbacks[v.name](v)
        if self._fix_parameters:
            v.need_grad = False
        return v

    def _apply_generate_function_by_name(self, f):
        '''Apply the generate callback for ``f.name``; return f if none.'''
        if f.name not in self._function_callbacks_by_name:
            return f
        return self._function_callbacks_by_name[f.name](f)

    def _apply_generate_function_by_type(self, f):
        '''Apply the generate callback for ``f.proto.type``; return f if none.'''
        if f.proto.type not in self._function_callbacks_by_type:
            return f
        return self._function_callbacks_by_type[f.proto.type](f)

    def _apply_use_up_to(self, variables):
        '''Set ``stop = True`` on each variable registered via use_up_to.'''
        for v in variables:
            if v.name in self._use_up_to_variables:
                self.verbose('Stopping at {}.'.format(v.name))
                v.stop = True