# Copyright (c) 2017 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from six import iteritems
from contextlib import contextmanager
from collections import OrderedDict
import google.protobuf.text_format as text_format
import numpy
import os
import shutil
import tempfile
import zipfile
import nnabla as nn
from nnabla.logger import logger
import nnabla.utils.nnabla_pb2 as nnabla_pb2
from nnabla.utils.get_file_handle import get_parameter_file_loader, load_files, FileHandlerContext
from nnabla.utils.get_file_handle import get_file_handle_save, get_parameter_file_savers, save_files
# TODO temporary work around to suppress FutureWarning message.
# NOTE: this installs an 'ignore' filter for every FutureWarning as a
# side effect of importing this module.
import warnings
warnings.simplefilter('ignore', category=FutureWarning)
# Scope dictionary that parameter lookups and registrations in this module
# operate on. `parameter_scope` rebinds this name while its `with` block is
# active and restores the previous binding on exit.
current_scope = OrderedDict()
# Second reference to the scope dictionary created above at import time.
root_scope = current_scope
def get_current_parameter_scope():
    """Returns current parameter scope.

    Returns:
        dict: The dictionary currently bound to the module-level
        ``current_scope``. This is the live object, not a copy, so
        mutating it changes what is registered in that scope.
    """
    # Reading a module global does not require a `global` declaration.
    return current_scope
@contextmanager
def parameter_scope(name, scope=None):
    """
    Grouping parameters registered by parametric functions
    listed in :mod:`nnabla.parametric_functions`.

    Args:
        name (str): Parameter scope name. Leading and trailing ``/`` are
            stripped and empty path components are ignored, so an empty
            name selects ``scope`` (or the current scope) itself.
        scope (OrderedDict, optional):
            Specify current parameter scope as a local dictionary.
            The default value is ``None``. In this case,
            the current parameter scope maintained in global is used.

    Yields:
        dict: The scope dictionary that is current inside the ``with``
        block. Missing intermediate scopes are created on entry.

    Raises:
        ValueError: If ``scope`` is given but is not a dictionary.
        AssertionError: If a path component is already bound to something
            that is not a dictionary.

    Example:

    .. code-block:: python

        import nnabla as nn
        import nnabla.parametric_functions as PF
        import nnabla.functions as F

        with nn.parameter_scope('conv1'):
            conv_out1 = PF.convolution(x, 32, (5, 5))
            bn_out1 = PF.batch_normalization(conv_out1)
            act_out1 = F.relu(bn_out1)

        with nn.parameter_scope('conv2'):
            conv_out2 = PF.convolution(act_out1, 64, (3, 3))
            bn_out2 = PF.batch_normalization(conv_out2)
            act_out2 = F.relu(bn_out2)

    Nesting `with` blocks allows you to nest parameter scopes.
    This can also be done by using "/" inside the parameter names.

    Example:

    .. code-block:: python

        with nn.parameter_scope('network1'):
            with nn.parameter_scope('conv1'):
                conv_out1 = PF.convolution(x, 32, (5, 5))
                bn_out1 = PF.batch_normalization(conv_out1)
                act_out1 = F.relu(bn_out1)

            with nn.parameter_scope('conv2'):
                conv_out2 = PF.convolution(act_out1, 64, (3, 3))
                bn_out2 = PF.batch_normalization(conv_out2)
                act_out2 = F.relu(bn_out2)

    is equivalent to

    .. code-block:: python

        with nn.parameter_scope('network1/conv1'):
            conv_out1 = PF.convolution(x, 32, (5, 5))
            bn_out1 = PF.batch_normalization(conv_out1)
            act_out1 = F.relu(bn_out1)

        with nn.parameter_scope('network1/conv2'):
            conv_out2 = PF.convolution(act_out1, 64, (3, 3))
            bn_out2 = PF.batch_normalization(conv_out2)
            act_out2 = F.relu(bn_out2)

    """
    global current_scope
    # `str.split` always returns at least one element (possibly ''), so no
    # emptiness check is needed here; empty components are skipped below.
    scope_names = name.strip('/').split('/')
    prev_scope = current_scope
    if scope is None:
        scope = current_scope
    elif not isinstance(scope, dict):
        raise ValueError(
            'Scope must be a dictionary. {} is given.'.format(type(scope)))
    for scope_name in scope_names:
        # When the component is empty, the scope reached so far is kept.
        if not scope_name:
            continue
        # Descend into the child scope, creating it if it doesn't exist.
        # `setdefault` only inserts when `scope_name` is absent, so an
        # existing entry is never overwritten.
        scope = scope.setdefault(scope_name, OrderedDict())
        assert isinstance(scope, dict), (
            'Cannot enter scope "{}": the name is already bound to an '
            'object that is not a scope dictionary.'.format(scope_name))
    current_scope = scope
    try:
        yield current_scope
    finally:
        # Restore the caller's scope even if the `with` body raised.
        current_scope = prev_scope
def get_parameter(key):
    """Get a parameter by key from the current parameter scope.

    Args:
        key(str): Key of parameter. ``/`` separates nested scope names
            from the parameter name.

    Returns: ~nnabla.Variable
        Parameter if key found, otherwise None.

    Note:
        Scopes named in ``key`` are entered with :func:`parameter_scope`,
        which creates them if they do not exist yet.
    """
    # Resolve one scope level per call; the remainder is handled recursively.
    head, delimiter, rest = key.partition('/')
    if delimiter:
        with parameter_scope(head):
            return get_parameter(rest)
    param = current_scope.get(key)
    if param is not None:
        assert isinstance(param, nn.Variable)
    return param
def pop_parameter(key):
    """Remove and get parameter by key.

    Args:
        key(str): Key of parameter. ``/`` separates nested scope names
            from the parameter name.

    Returns: ~nnabla.Variable
        Parameter if key found, otherwise None.

    Note:
        Scopes named in ``key`` are entered with :func:`parameter_scope`,
        which creates them if they do not exist yet.
    """
    # Resolve one scope level per call; the remainder is handled recursively.
    head, delimiter, rest = key.partition('/')
    if delimiter:
        with parameter_scope(head):
            return pop_parameter(rest)
    param = current_scope.get(key)
    if param is not None:
        del current_scope[key]
    return param
def set_parameter(key, param):
    """Register ``param`` under ``key`` in the current parameter scope.

    An entry already stored under the same key is replaced.

    Args:
        key(str): Key of parameter. ``/`` separates nested scope names
            from the parameter name; missing scopes are created by
            :func:`parameter_scope`.
        param: Object to store under the final name of ``key``.
    """
    # Resolve one scope level per call; the remainder is handled recursively.
    head, delimiter, rest = key.partition('/')
    if delimiter:
        with parameter_scope(head):
            return set_parameter(rest, param)
    current_scope[key] = param
def _create_parameter_by_initializer(initializer, shape, need_grad):
# If initializer is not set, just returns a new variable with zeros.
if initializer is None:
assert shape is not None
param = nn.Variable(shape, need_grad=need_grad)
param.data.zero() # Initialize with zero.
return param
# Initialize by a numpy array.
if isinstance(initializer, numpy.ndarray): # numpy init
assert (shape is None) or (tuple(shape) == initializer.shape)
return nn.Variable.from_numpy_array(
initializer, need_grad=need_grad)
# Initialize by Initializer or callable object which takes shape as an argument.
if callable(initializer):
assert shape is not None
return nn.Variable.from_numpy_array(
initializer(shape=list(map(int, shape))), need_grad=need_grad)
# Invalid initialzier argument.
raise ValueError(
"`initializer` must be either the :obj:`numpy.ndarray`"
" or an instance inherited from `nnabla.initializer.BaseInitializer`.")
def get_parameter_or_create(name, shape=None, initializer=None, need_grad=True,
                            as_need_grad=None):
    """
    Returns an existing parameter variable in current parameter scope
    with the provided name.

    If a variable with the provided name does not exist,
    a new variable is created and registered to the current parameter scope
    with the name, then returned.

    Args:

        name(str):
            The name under the current scope. If it already exists, the name
            is queried from the parameter manager.
        shape (:obj:`tuple` of :obj:`int`):
            Shape of created parameter. The shape of the specified
            parameter must match with this shape. The default is None which is
            only valid if initializer is given as an :obj:`numpy.ndarray`.
        initializer (:obj:`nnabla.initializer.BaseInitializer` or :obj:`numpy.ndarray`):
            An initialization function to be applied to the parameter.
            :obj:`numpy.ndarray` can also be given to initialize parameters
            from numpy array data.
        need_grad (bool):
            Register the parameter with the specified ``need_grad`` flag.
            The default is True. If the flag is different from the previously
            specified one, the flag will be overwritten, but the values will be
            kept.
        as_need_grad (bool):
            Get a parameter variable with the specified ``need_grad`` flag.
            Note that this doesn't overwrite the flag of the registered parameter
            variable with the provided name. Instead, if the given flag
            mismatches with the previously registered ``need_grad`` flag, it
            returns a new variable referring to the same array contents but with
            ``need_grad=as_need_grad``.

    Raises:
        ValueError: If a parameter with this name exists and its shape
            differs from the requested one.

    Note:
        It returns a `Variable` which is unlinked from the
        registered one in the current parameter scope
        (using :py:meth:`nnabla.Variable.get_unlinked_variable`).
        That means changing a `need_grad` attribute doesn't affect
        the variable existing in the current parameter scope.

    """
    # Resolve delimiter '/' in parameter name, one scope level per call.
    head, delimiter, rest = name.partition('/')
    if delimiter:
        with parameter_scope(head):
            return get_parameter_or_create(
                rest, shape, initializer, need_grad, as_need_grad)

    # Set need_grad if as_need_grad is not specified.
    if as_need_grad is None:
        as_need_grad = need_grad

    # Try to find a existing parameter.
    param = get_parameter(name)

    # If found, verify shape and flags, and returns it.
    if param is not None:
        # `shape=None` is allowed together with an ndarray initializer, so
        # the expected shape is taken from the array in that case. Without
        # any shape information there is nothing to verify.
        expected_shape = shape
        if expected_shape is None and isinstance(initializer, numpy.ndarray):
            expected_shape = initializer.shape
        if expected_shape is not None and \
                param.shape != tuple(expected_shape):
            raise ValueError(
                'The size of existing parameter "{}" {} is different from the '
                'size of new parameter {}.\n'
                'To clear all parameters, call nn.clear_parameters().'.format(
                    name, param.shape, tuple(expected_shape)))

        if need_grad != param.need_grad:
            param.need_grad = need_grad
            set_parameter(name, param)
        return param.get_unlinked_variable(need_grad=as_need_grad)

    # Keep the initializer alongside the new parameter as `param.info`.
    class VariableInfo:
        pass
    info = VariableInfo()
    info.initializer = initializer

    # Create a new parameter using specified configuration,
    # and write it to current scope.
    param = _create_parameter_by_initializer(initializer, shape, need_grad)
    param.info = info
    set_parameter(name, param)
    return param.get_unlinked_variable(need_grad=as_need_grad)
def get_parameters(params=None, path='', grad_only=True):
    """Get parameter Variables under the current parameter scope.

    Nested scopes are traversed recursively and their entries are reported
    under ``/``-joined names.

    Args:
        params (dict): Internal use. User doesn't set it manually.
        path (str): Internal use. User doesn't set it manually.
        grad_only (bool): Retrieve all parameters under the current scope if
            False, while only parameters with need_grad=True are retrieved
            if True.

    Returns:
        dict: {:obj:`str` : :obj:`~nnabla.Variable`}

    """
    if params is None:
        params = OrderedDict()
    for k, v in current_scope.items():
        # Name of this entry relative to the scope the traversal started in.
        full_key = '/'.join([path, k]) if path else k
        if isinstance(v, dict):
            # Descend into the child scope; results accumulate in `params`.
            with parameter_scope(k):
                params = get_parameters(
                    params, full_key, grad_only=grad_only)
        else:
            assert isinstance(v, nn.Variable)
            if not grad_only or v.need_grad:
                params[full_key] = v
    return params
def clear_parameters():
    """Clear all parameters in the current scope."""
    # Empty the dictionary in place so that every other reference to the
    # same scope object observes the removal as well.
    current_scope.clear()
def set_parameter_from_proto(proto):
    """Register the parameters listed in ``proto`` and fill in their values.

    For every entry of ``proto.parameter``, the variable named
    ``variable_name`` is fetched or created with shape ``shape.dim`` and
    the entry's ``need_grad`` flag, and its data is then overwritten with
    the entry's ``data`` reshaped to that shape.

    Args:
        proto: Object whose ``parameter`` attribute iterates over entries
            providing ``variable_name``, ``shape.dim``, ``need_grad`` and
            ``data``.
    """
    for parameter in proto.parameter:
        dims = parameter.shape.dim
        var = get_parameter_or_create(
            parameter.variable_name, dims, need_grad=parameter.need_grad)
        # `parameter.data` is reshaped to `dims` before being assigned.
        var.d = numpy.reshape(parameter.data, dims)
def load_parameters(path, proto=None, needs_proto=False, extension=".nntxt"):
    """Load parameters from a file with the specified format.

    Args:
        path : path or file object
        proto: Protobuf message handed to the file loaders. A new
            ``NNablaProtoBuf`` is created when ``None`` is given.
        needs_proto (bool): Stored on the loader context as
            ``needs_proto``; how it is interpreted is up to the loaders.
        extension (str): Format extension used when ``path`` is not a
            :obj:`str` (e.g. a file object). For a :obj:`str` path the
            extension is taken from the path itself.

    Returns:
        The protobuf message held by the loader context after loading.
    """
    if isinstance(path, str):
        _, ext = os.path.splitext(path)
    else:
        ext = extension

    ctx = FileHandlerContext()
    ctx.proto = nnabla_pb2.NNablaProtoBuf() if proto is None else proto
    ctx.needs_proto = needs_proto

    # Get parameter file loaders
    file_loaders = get_parameter_file_loader()
    load_files(ctx, file_loaders, path, ext)
    return ctx.proto
def save_parameters(path, params=None, extension=None):
    """Save all parameters into a file with the specified format.

    Currently hdf5 and protobuf formats are supported.

    Args:
        path : path or file object
        params (dict, optional): Parameters to be saved. Dictionary is of a parameter name (:obj:`str`) to :obj:`~nnabla.Variable`.
            When omitted, every parameter under the current scope is
            saved, regardless of its ``need_grad`` flag.
        extension (str, optional): Format extension used when ``path`` is
            not a :obj:`str` (e.g. a file object). For a :obj:`str` path
            the extension is taken from the path itself.

    Raises:
        AssertionError: If no file saver supports the requested format.
    """
    if isinstance(path, str):
        _, ext = os.path.splitext(path)
    else:
        ext = extension

    ctx = FileHandlerContext()
    if params is None:
        ctx.parameters = get_parameters(grad_only=False)
    else:
        ctx.parameters = params

    file_savers = get_parameter_file_savers()
    supported = save_files(ctx, file_savers, path, ext)
    if not supported:
        # Raised explicitly instead of via `assert` so that the check is
        # not stripped under `python -O`; the exception type is kept for
        # backward compatibility.
        raise AssertionError('Only supported {}.'.format(
            ','.join(list(file_savers.keys()))))
    logger.info("Parameter save ({}): {}".format(ext, path))