NNP の保存やロードのユーティリティ

重要なお知らせ: Neural Network Console で NPP ファイルを扱うには、保存する / 読み込むネットワークが LoopControl 関数である RepeatStart, RepeatEnd, RecurrentInput, RecurrentOutput または Delay を含む場合、 ファイルフォーマットコンバーター を使ってネットワークを拡張する必要があります。

nnabla.utils.save.save(filename, contents, include_params=False, variable_batch_size=True, extension='.nnp', parameters=None, include_solver_state=False, solver_state_format='.h5')[ソース]

ネットワーク定義の保存、推論/学習実行の設定等。

パラメータ:
  • filename (str or file object) -- 情報を保存するためのファイル名。ファイルの拡張子は、保存するファイルフォーマットを決定します。.nnp: (推奨) nntxt (ネットワーク定義等) と h5 (パラメータ) で ZIP圧縮ファイルを生成します。.nntxt: テキスト形式の Protobuf 。 .protobuf: バイナリ形式の Protobuf (下位互換性の点で安全でありません)

  • contents (dict) -- 保存する情報。

  • include_params (bool) -- 単一ファイルにパラメータを含めます。ファイル名の拡張子が nnp の場合、無視されます。

  • variable_batch_size (bool) -- True の場合、すべての変数の最初の次元は batch size、残りは placeholder (より具体的には -1) とみなされます。 placeholder の次元は読み込み中、あるいは読み込み後に埋められます。

  • extension -- if files is file-like object, extension is one of ".nntxt", ".prototxt", ".protobuf", ".h5", ".nnp".

  • include_solver_state (bool) -- Indicate whether to save solver state or not.

  • solver_state_format (str) -- '.h5' or '.protobuf', default '.h5', indicate in which format will solver state be saved, notice that this option only works when save network definition in .nnp format and include_solver_state is True.

サンプル

次の例では、2 つの入力と 2 つの出力 MLP を作成し、ネットワーク構造と初期化されたパラメータを保存します。

import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
from nnabla.utils.save import save

batch_size = 16
x0 = nn.Variable([batch_size, 100])
x1 = nn.Variable([batch_size, 100])
h1_0 = PF.affine(x0, 100, name='affine1_0')
h1_1 = PF.affine(x1, 100, name='affine1_0')
h1 = F.tanh(h1_0 + h1_1)
h2 = F.tanh(PF.affine(h1, 50, name='affine2'))
y0 = PF.affine(h2, 10, name='affiney_0')
y1 = PF.affine(h2, 10, name='affiney_1')

contents = {
    'networks': [
        {'name': 'net1',
         'batch_size': batch_size,
         'outputs': {'y0': y0, 'y1': y1},
         'names': {'x0': x0, 'x1': x1}}],
    'executors': [
        {'name': 'runtime',
         'network': 'net1',
         'data': ['x0', 'x1'],
         'output': ['y0', 'y1']}]}
save('net.nnp', contents)

学習可能なモデルを取得するには、以下のコードを代わりにお使いください。

contents = {
'global_config': {'default_context': ctx},
'training_config':
    {'max_epoch': args.max_epoch,
     'iter_per_epoch': args_added.iter_per_epoch,
     'save_best': True},
'networks': [
    {'name': 'training',
     'batch_size': args.batch_size,
     'outputs': {'loss': loss_t},
     'names': {'x': x, 'y': t, 'loss': loss_t}},
    {'name': 'validation',
     'batch_size': args.batch_size,
     'outputs': {'loss': loss_v},
     'names': {'x': x, 'y': t, 'loss': loss_v}}],
'optimizers': [
    {'name': 'optimizer',
     'solver': solver,
     'network': 'training',
     'dataset': 'mnist_training',
     'weight_decay': 0,
     'lr_decay': 1,
     'lr_decay_interval': 1,
     'update_interval': 1}],
'datasets': [
    {'name': 'mnist_training',
     'uri': 'MNIST_TRAINING',
     'cache_dir': args.cache_dir + '/mnist_training.cache/',
     'variables': {'x': x, 'y': t},
     'shuffle': True,
     'batch_size': args.batch_size,
     'no_image_normalization': True},
    {'name': 'mnist_validation',
     'uri': 'MNIST_VALIDATION',
     'cache_dir': args.cache_dir + '/mnist_test.cache/',
     'variables': {'x': x, 'y': t},
     'shuffle': False,
     'batch_size': args.batch_size,
     'no_image_normalization': True
     }],
'monitors': [
    {'name': 'training_loss',
     'network': 'validation',
     'dataset': 'mnist_training'},
    {'name': 'validation_loss',
     'network': 'validation',
     'dataset': 'mnist_validation'}],
}
class nnabla.utils.nnp_graph.NnpLoader(filepath, scope=None, extension='.nntxt')[ソース]

NNP ファイルローダー。

パラメータ:
  • filepath -- file-like object or filepath.

  • extension -- if filepath is file-like object, extension is one of ".nnp", ".nntxt", ".prototxt".

サンプル

from nnabla.utils.nnp_graph import NnpLoader

# Read a .nnp file.
nnp = NnpLoader('/path/to/nnp.nnp')
# Assume a graph `graph_a` is in the nnp file.
net = nnp.get_network(network_name, batch_size=1)
# `x` is an input of the graph.
x = net.inputs['x']
# 'y' is an outputs of the graph.
y = net.outputs['y']
# Set random data as input and perform forward prop.
x.d = np.random.randn(*x.shape)
y.forward(clear_buffer=True)
print('output:', y.d)
get_network(name, batch_size=None, callback=None)[ソース]

name で指定されたネットワークの variable graph を作成します。

戻り値: NnpNetwork

get_network_names()[ソース]

利用可能なネットワーク名を返します。

class nnabla.utils.nnp_graph.NnpNetwork(proto_network, batch_size, callback)[ソース]

nnp ファイルから読み込んだ graph object。

NnpNetwork のインスタンスは通常 NnpLoader インスタンスにより作成されます。詳しくは、NnpLoader にある使用例をご覧ください。

variables

変数名をキーとして、nnabla.Variable を値として作成した graph にある全ての変数の辞書。

:

dict

inputs

すべての入力変数。

:

dict

outputs

すべての出力変数。

:

dict