# Copyright 2017,2018,2019,2020,2021 Sony Corporation.
# Copyright 2021 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
Save network structure into file.
'''
import os
import re
import types
import nnabla as nn
import numpy
from collections import OrderedDict
from nnabla.logger import logger
from nnabla.parameter import get_parameters
from nnabla.utils import nnabla_pb2
from nnabla.utils.get_file_handle import FileHandlerContext, get_default_file_savers, save_files
# ----------------------------------------------------------------------
# Helper functions
# ----------------------------------------------------------------------
def _create_global_config(ctx):
    """Build a ``GlobalConfig`` proto from a context object.

    Args:
        ctx: Context exposing ``backend``, ``array_class`` and ``device_id``.

    Returns:
        nnabla_pb2.GlobalConfig: Proto with the default context filled in.
    """
    config = nnabla_pb2.GlobalConfig()
    default = config.default_context
    default.backends[:] = ctx.backend
    default.array_class = ctx.array_class
    default.device_id = ctx.device_id
    return config
def _create_training_config(max_epoch, iter_per_epoch, save_best):
    """Build a ``TrainingConfig`` proto.

    Args:
        max_epoch (int): Maximum number of training epochs.
        iter_per_epoch (int): Iterations per epoch.
        save_best (bool): Whether to keep the best model during training.

    Returns:
        nnabla_pb2.TrainingConfig
    """
    config = nnabla_pb2.TrainingConfig()
    config.max_epoch = max_epoch
    config.iter_per_epoch = iter_per_epoch
    config.save_best = save_best
    return config
def _format_opti_config_for_states_checkpoint(ctx, contents):
ctx.optimizers = OrderedDict()
for opti_i in contents['optimizers']:
opt = types.SimpleNamespace(**opti_i)
ctx.optimizers[opt.name] = types.SimpleNamespace(optimizer=opt)
def _create_dataset(name, uri, cache_dir, variables, shuffle, batch_size, no_image_normalization):
    """Build a ``Dataset`` proto from individual dataset settings.

    Args:
        name (str): Dataset name.
        uri (str): Dataset URI.
        cache_dir (str or None): Optional cache directory; the proto field is
            left unset when None.
        variables: Iterable of variable names used by the dataset.
        shuffle (bool): Whether the dataset is shuffled.
        batch_size (int): Batch size.
        no_image_normalization (bool): Disable image normalization when True.

    Returns:
        nnabla_pb2.Dataset
    """
    proto = nnabla_pb2.Dataset()
    proto.name = name
    proto.uri = uri
    # cache_dir is optional; only set the field when a value was given.
    if cache_dir is not None:
        proto.cache_dir = cache_dir
    proto.shuffle = shuffle
    proto.batch_size = batch_size
    proto.variable.extend(variables)
    proto.no_image_normalization = no_image_normalization
    return proto
def _create_network(ctx, net, variable_batch_size):
    """Serialize one network definition into a ``Network`` proto.

    Builds a graph_def graph from the network's output variables, registers
    it in ``ctx.proto_graphs`` under the network name, and returns its proto
    representation with ``batch_size`` filled in.

    Args:
        ctx: Context whose ``proto_graphs`` dict receives the created graph.
        net (dict): Network definition with ``name``, ``outputs``, ``names``
            and ``batch_size`` keys.
        variable_batch_size (bool): When True, the first dimension of each
            variable is stored as a batch-size placeholder.

    Returns:
        Network proto produced by ``as_proto``.
    """
    variable_names = dict(net['names'])
    variable_names.update(net['outputs'])
    graph = nn.graph_def.create_graph_from_variable(
        net['name'], list(net['outputs'].values()), names=variable_names,
        parameter_scope=nn.parameter.get_current_parameter_scope())
    proto = graph.default_graph().as_proto(
        variable_batch_size=variable_batch_size)
    proto.batch_size = net['batch_size']
    ctx.proto_graphs[proto.name] = graph
    return proto
def _create_optimizer(ctx, opti_d, save_solver_in_proto):
    """Build an ``Optimizer`` proto from an optimizer definition dict.

    Args:
        ctx: Context carrying ``datasets``, ``networks`` and ``proto_graphs``.
        opti_d (dict): Optimizer definition with keys ``name``, ``solver``,
            ``network``, ``dataset``, ``weight_decay``, ``lr_decay``,
            ``lr_decay_interval`` and ``update_interval``; optionally
            ``data_variables``, ``loss_variables`` and ``generator_variables``.
        save_solver_in_proto (bool): When True, also serialize the solver
            states into the proto.

    Returns:
        nnabla_pb2.Optimizer

    Raises:
        ValueError: If a dataset name given as a list/tuple entry is not
            registered in ``ctx.datasets``.
    """
    # Map a solver type name to its proto sub-message name and the
    # hyper-parameter keys copied verbatim from ``solver.info``.
    solver_param_table = {
        'Adadelta': ('adadelta_param', ('lr', 'decay', 'eps')),
        'Adagrad': ('adagrad_param', ('lr', 'eps')),
        'AdaBelief': ('adabelief_param',
                      ('alpha', 'beta1', 'beta2', 'eps', 'wd', 'amsgrad',
                       'weight_decouple', 'fixed_decay', 'rectify')),
        'Adam': ('adam_param', ('alpha', 'beta1', 'beta2', 'eps')),
        'Adamax': ('adamax_param', ('alpha', 'beta1', 'beta2', 'eps')),
        'Momentum': ('momentum_param', ('lr', 'momentum')),
        'Nesterov': ('nesterov_param', ('lr', 'momentum')),
        'RMSprop': ('rmsprop_param', ('lr', 'decay', 'eps')),
        'RMSpropGraves': ('rmsprop_graves_param',
                          ('lr', 'decay', 'momentum', 'eps')),
        'Sgd': ('sgd_param', ('lr',)),
    }

    o = nnabla_pb2.Optimizer()
    dataset = None
    datasets = ctx.datasets
    solver = opti_d['solver']
    # ctx.networks might be missing when optimizer is used in transfer learning.
    network = ctx.networks[opti_d['network']] if ctx.networks else None
    dataset_names = opti_d['dataset']

    o.name = opti_d['name']
    # Fixed: was b'None'; protobuf string fields expect str (the bytes literal
    # only worked via protobuf's implicit UTF-8 decoding).
    o.network_name = network.name if network else 'None'
    proto_network = (ctx.proto_graphs[opti_d['network']].default_graph()
                     if network else None)

    # Allow a list, a tuple or a single string for dataset names.
    if isinstance(dataset_names, (list, tuple)):
        for dataset_name in dataset_names:
            if dataset_name in datasets:
                o.dataset_name.append(dataset_name)
                dataset = datasets[dataset_name]
            else:
                raise ValueError(
                    "Invalid dataset_name is found in optimizer: {}".format(dataset_name))
    elif isinstance(dataset_names, str):
        if dataset_names in datasets:
            o.dataset_name.append(dataset_names)
            dataset = datasets[dataset_names]
    # ``dataset`` may legitimately stay None when the optimizer is used in
    # transfer learning, so no error is raised here.

    # Strip an optional 'Cuda' suffix from the solver class name
    # (e.g. 'AdamCuda' -> 'Adam').  Fixed: the previous pattern r'(|Cuda)$'
    # had a dead empty alternative with identical behavior.
    o.solver.type = re.sub(r'Cuda$', '', str(solver.name))
    param_spec = solver_param_table.get(o.solver.type)
    if param_spec is not None:
        field_name, keys = param_spec
        param_msg = getattr(o.solver, field_name)
        for key in keys:
            setattr(param_msg, key, solver.info[key])

    o.solver.weight_decay = opti_d['weight_decay']
    o.solver.lr_decay = opti_d['lr_decay']
    o.solver.lr_decay_interval = opti_d['lr_decay_interval']
    o.update_interval = opti_d['update_interval']

    for var_name, data_name in opti_d.get('data_variables', {}).items():
        d = o.data_variable.add()
        d.variable_name = var_name
        d.data_name = data_name
    if proto_network:
        # Default to all network outputs as loss variables.
        for loss_name in opti_d.get('loss_variables', proto_network.outputs):
            d = o.loss_variable.add()
            d.variable_name = loss_name
    solver_params = solver.get_parameters()
    network_keys = (proto_network.parameters.keys() if proto_network
                    else nn.get_parameters().keys())
    for param in network_keys:
        d = o.parameter_variable.add()
        d.variable_name = param
        # Parameters not managed by the solver are frozen (multiplier 0).
        d.learning_rate_multiplier = 1.0 if param in solver_params else 0.0
    for g_var in opti_d.get('generator_variables', []):
        d = o.generator_variable.add()
        d.variable_name = g_var
        d.type = 'Constant'
        d.multiplier = 0
    if save_solver_in_proto:
        solver.set_states_to_protobuf(o)
    return o
def _create_monitor(ctx, monitor):
    """Build a ``Monitor`` proto from a monitor definition dict.

    Args:
        ctx: Context carrying ``datasets``, ``networks`` and ``proto_graphs``.
        monitor (dict): Monitor definition with ``name``, ``network`` and
            ``dataset`` keys; optionally ``data_variables``,
            ``monitor_variables`` and ``generator_variables``.

    Returns:
        nnabla_pb2.Monitor

    Raises:
        ValueError: If the referenced network or dataset is unknown, or no
            dataset could be resolved.
    """
    datasets = ctx.datasets
    if monitor['network'] not in ctx.networks:
        raise ValueError(
            "{} is not found in networks.".format(monitor['network']))
    proto_network = ctx.proto_graphs[monitor['network']].default_graph()
    m = nnabla_pb2.Monitor()
    m.name = monitor['name']
    m.network_name = monitor['network']
    # Fixed: initialize so the check below raises the intended ValueError
    # instead of UnboundLocalError when monitor['dataset'] is an empty list
    # or an unexpected type.
    dataset = None
    if isinstance(monitor['dataset'], (list, tuple)):
        for dataset_name in monitor['dataset']:
            if dataset_name in datasets:
                m.dataset_name.append(dataset_name)
                dataset = datasets[dataset_name]
            else:
                raise ValueError(
                    "Invalid dataset name is found in monitor definition: {}".format(dataset_name))
    elif isinstance(monitor['dataset'], str):
        dataset_name = monitor['dataset']
        if dataset_name in datasets:
            m.dataset_name.append(dataset_name)
            dataset = datasets[dataset_name]
    if dataset is None:
        raise ValueError("Dataset is not defined in monitor definition.")
    for var_name, data_name in monitor.get('data_variables', {}).items():
        d = m.data_variable.add()
        d.variable_name = var_name
        d.data_name = data_name
    # Default to all network outputs as monitored variables.
    for out in monitor.get('monitor_variables', proto_network.outputs):
        d = m.monitor_variable.add()
        d.type = 'Error'
        d.variable_name = out
    for g_var in monitor.get('generator_variables', []):
        d = m.generator_variable.add()
        d.variable_name = g_var
        d.type = 'Constant'
        d.multiplier = 0
    return m
def _create_executor(ctx, executor):
    """Build an ``Executor`` proto from an executor definition dict.

    Args:
        ctx: Context carrying ``networks`` and ``proto_graphs``.
        executor (dict): Executor definition with ``name`` and ``network``
            keys; optionally ``remap``, ``data``, ``output``,
            ``generator_variables`` and the scalar options handled by
            ``save_argument``.

    Returns:
        nnabla_pb2.Executor

    Raises:
        ValueError: If a data/output variable name is not in the network.
    """
    def save_argument(e, arg_name):
        # Copy an optional scalar setting into the proto only when present.
        arg = executor.get(arg_name, None)
        if arg is not None:
            setattr(e, arg_name, arg)

    name = executor['name']
    network = ctx.networks[executor['network']]
    # Fixed: the key was previously read as 'remp', which looks like a typo
    # of 'remap' and made variable-name remapping unreachable under the
    # intended key.  Accept 'remap' and fall back to the historical 'remp'
    # for backward compatibility.
    remap = executor.get('remap', executor.get('remp', {}))
    generator_variables = executor.get('generator_variables', [])
    proto_network = ctx.proto_graphs[executor['network']].default_graph()
    e = nnabla_pb2.Executor()
    e.name = name
    e.network_name = network.name
    save_argument(e, "no_image_normalization")
    save_argument(e, "num_evaluations")
    save_argument(e, "repeat_evaluation_type")
    save_argument(e, "need_back_propagation")
    # Default to all network inputs as data variables.
    for vname in executor.get('data', proto_network.inputs):
        if vname not in proto_network.variables:
            raise ValueError("{} is not found in networks!".format(vname))
        dv = e.data_variable.add()
        dv.variable_name = vname
        dv.data_name = remap.get(vname, vname)
    # Default to all network outputs as output variables.
    for vname in executor.get('output', proto_network.outputs):
        if vname not in proto_network.variables:
            raise ValueError("{} is not found in networks!".format(vname))
        ov = e.output_variable.add()
        ov.variable_name = vname
        ov.data_name = remap.get(vname, vname)
    for param in proto_network.parameters.keys():
        d = e.parameter_variable.add()
        d.variable_name = param
    for vname in generator_variables:
        d = e.generator_variable.add()
        d.type = 'Constant'
        d.multiplier = 0
        d.variable_name = vname
    return e
# ----------------------------------------------------------------------
# Helper functions (END)
# ----------------------------------------------------------------------
def create_proto(contents, include_params=False, variable_batch_size=True, save_solver_in_proto=False):
    """Convert a ``contents`` dict into an ``NNablaProtoBuf`` message.

    Args:
        contents (dict): Definitions of networks, datasets, optimizers,
            monitors, executors and configs (see :func:`save`).
        include_params (bool): When True, embed the current parameters into
            the proto.
        variable_batch_size (bool): When True, the first dimension of each
            variable is stored as a batch-size placeholder.
        save_solver_in_proto (bool): When True, serialize solver states into
            the optimizer protos.

    Returns:
        nnabla_pb2.NNablaProtoBuf
    """
    class Context:
        # Lightweight attribute bag shared by the helper builders.
        pass

    proto = nnabla_pb2.NNablaProtoBuf()
    if 'global_config' in contents:
        proto.global_config.MergeFrom(
            _create_global_config(contents['global_config']['default_context']))
    if 'training_config' in contents:
        tc = contents['training_config']
        proto.training_config.MergeFrom(
            _create_training_config(tc['max_epoch'], tc['iter_per_epoch'],
                                    tc['save_best']))

    ctx = Context()
    ctx.proto_graphs = {}

    networks = {}
    for net in contents.get('networks', []):
        net_proto = _create_network(ctx, net, variable_batch_size)
        networks[net['name']] = net_proto
        proto.network.extend([net_proto])
    ctx.networks = networks

    datasets = {}
    for d in contents.get('datasets', []):
        ds_proto = _create_dataset(d['name'],
                                   d['uri'],
                                   d.get('cache_dir'),
                                   d['variables'],
                                   d['shuffle'],
                                   d['batch_size'],
                                   d['no_image_normalization'])
        datasets[d['name']] = ds_proto
        proto.dataset.extend([ds_proto])
    ctx.datasets = datasets

    for opti in contents.get('optimizers', []):
        proto.optimizer.extend(
            [_create_optimizer(ctx, opti, save_solver_in_proto)])
    for mon in contents.get('monitors', []):
        proto.monitor.extend([_create_monitor(ctx, mon)])
    for exe in contents.get('executors', []):
        proto.executor.extend([_create_executor(ctx, exe)])

    if include_params:
        for variable_name, variable in get_parameters(grad_only=False).items():
            parameter = proto.parameter.add()
            parameter.variable_name = variable_name
            parameter.shape.dim.extend(variable.shape)
            parameter.data.extend(numpy.array(variable.d).flatten().tolist())
            parameter.need_grad = variable.need_grad
    return proto
def save(filename, contents, include_params=False, variable_batch_size=True, extension=".nnp", parameters=None, include_solver_state=False, solver_state_format='.h5'):
    '''Save network definition, inference/training execution
    configurations etc.

    Args:
        filename (str or file object): Filename to store information. The file
            extension is used to determine the saving file format.
            ``.nnp``: (Recommended) Creating a zip archive with nntxt (network
            definition etc.) and h5 (parameters).
            ``.nntxt``: Protobuf in text format.
            ``.protobuf``: Protobuf in binary format (unsafe in terms of
            backward compatibility).
        contents (dict): Information to store.
        include_params (bool): Includes parameter into single file. This is
            ignored when the extension of filename is nnp.
        variable_batch_size (bool):
            By ``True``, the first dimension of all variables is considered
            as batch size, and left as a placeholder
            (more specifically ``-1``). The placeholder dimension will be
            filled during/after loading.
        extension: if filename is a file-like object, extension is one of ".nntxt", ".prototxt", ".protobuf", ".h5", ".nnp".
        include_solver_state (bool): Indicate whether to save solver state or not.
        solver_state_format (str):
            '.h5' or '.protobuf', default '.h5', indicate in which format will solver state be saved,
            notice that this option only works when save network definition in .nnp format
            and include_solver_state is True.

    Example:
        The following example creates a two inputs and two
        outputs MLP, and save the network structure and the initialized
        parameters.

        .. code-block:: python

            import nnabla as nn
            import nnabla.functions as F
            import nnabla.parametric_functions as PF
            from nnabla.utils.save import save

            batch_size = 16
            x0 = nn.Variable([batch_size, 100])
            x1 = nn.Variable([batch_size, 100])
            h1_0 = PF.affine(x0, 100, name='affine1_0')
            h1_1 = PF.affine(x1, 100, name='affine1_0')
            h1 = F.tanh(h1_0 + h1_1)
            h2 = F.tanh(PF.affine(h1, 50, name='affine2'))
            y0 = PF.affine(h2, 10, name='affiney_0')
            y1 = PF.affine(h2, 10, name='affiney_1')

            contents = {
                'networks': [
                    {'name': 'net1',
                     'batch_size': batch_size,
                     'outputs': {'y0': y0, 'y1': y1},
                     'names': {'x0': x0, 'x1': x1}}],
                'executors': [
                    {'name': 'runtime',
                     'network': 'net1',
                     'data': ['x0', 'x1'],
                     'output': ['y0', 'y1']}]}
            save('net.nnp', contents)

        To get a trainable model, use following code instead.

        .. code-block:: python

            contents = {
                'global_config': {'default_context': ctx},
                'training_config':
                    {'max_epoch': args.max_epoch,
                     'iter_per_epoch': args_added.iter_per_epoch,
                     'save_best': True},
                'networks': [
                    {'name': 'training',
                     'batch_size': args.batch_size,
                     'outputs': {'loss': loss_t},
                     'names': {'x': x, 'y': t, 'loss': loss_t}},
                    {'name': 'validation',
                     'batch_size': args.batch_size,
                     'outputs': {'loss': loss_v},
                     'names': {'x': x, 'y': t, 'loss': loss_v}}],
                'optimizers': [
                    {'name': 'optimizer',
                     'solver': solver,
                     'network': 'training',
                     'dataset': 'mnist_training',
                     'weight_decay': 0,
                     'lr_decay': 1,
                     'lr_decay_interval': 1,
                     'update_interval': 1}],
                'datasets': [
                    {'name': 'mnist_training',
                     'uri': 'MNIST_TRAINING',
                     'cache_dir': args.cache_dir + '/mnist_training.cache/',
                     'variables': {'x': x, 'y': t},
                     'shuffle': True,
                     'batch_size': args.batch_size,
                     'no_image_normalization': True},
                    {'name': 'mnist_validation',
                     'uri': 'MNIST_VALIDATION',
                     'cache_dir': args.cache_dir + '/mnist_test.cache/',
                     'variables': {'x': x, 'y': t},
                     'shuffle': False,
                     'batch_size': args.batch_size,
                     'no_image_normalization': True
                     }],
                'monitors': [
                    {'name': 'training_loss',
                     'network': 'validation',
                     'dataset': 'mnist_training'},
                    {'name': 'validation_loss',
                     'network': 'validation',
                     'dataset': 'mnist_validation'}],
            }
    '''
    ctx = FileHandlerContext()
    ext = extension
    # When a path string is given, its extension (if any) wins over the
    # ``extension`` argument; file-like objects rely on ``extension``.
    if isinstance(filename, str):
        _, ext_c = os.path.splitext(filename)
        ext = ext_c if ext_c else ext
    # For .nnp archives parameters are stored in a separate member file,
    # so never embed them into the proto itself.
    include_params = False if ext == '.nnp' else include_params
    save_solver_in_proto = include_solver_state and ext != '.nnp'
    ctx.proto = create_proto(contents, include_params,
                             variable_batch_size, save_solver_in_proto)
    ctx.parameters = parameters
    if include_solver_state and ext == '.nnp':
        if 'optimizers' not in contents:
            raise KeyError('optimizers should be specified in \
contents when include_solver_state is True')
        _format_opti_config_for_states_checkpoint(ctx, contents)
        file_savers = get_default_file_savers(
            solver_state_format=solver_state_format)
    else:
        file_savers = get_default_file_savers()
    save_files(ctx, file_savers, filename, ext)
    logger.info("Model file is saved as ({}): {}".format(ext, filename))