NNP の保存やロードのユーティリティ
重要なお知らせ: Neural Network Console で NPP ファイルを扱うには、保存する / 読み込むネットワークが LoopControl
関数である RepeatStart, RepeatEnd, RecurrentInput, RecurrentOutput または Delay を含む場合、 ファイルフォーマットコンバーター を使ってネットワークを拡張する必要があります。
- nnabla.utils.save.save(filename, contents, include_params=False, variable_batch_size=True, extension='.nnp', parameters=None, include_solver_state=False, solver_state_format='.h5')[ソース]
ネットワーク定義の保存、推論/学習実行の設定等。
- パラメータ:
filename (str or file object) -- 情報を保存するためのファイル名。ファイルの拡張子は、保存するファイルフォーマットを決定します。
.nnp
: (推奨) nntxt (ネットワーク定義等) と h5 (パラメータ) で ZIP圧縮ファイルを生成します。.nntxt
: テキスト形式の Protobuf 。.protobuf
: バイナリ形式の Protobuf (下位互換性の点で安全でありません)contents (dict) -- 保存する情報。
include_params (bool) -- 単一ファイルにパラメータを含めます。ファイル名の拡張子が nnp の場合、無視されます。
variable_batch_size (bool) --
True
の場合、すべての変数の最初の次元は batch size、残りは placeholder (より具体的には-1
) とみなされます。 placeholder の次元は読み込み中、あるいは読み込み後に埋められます。extension -- if files is file-like object, extension is one of ".nntxt", ".prototxt", ".protobuf", ".h5", ".nnp".
include_solver_state (bool) -- Indicate whether to save solver state or not.
solver_state_format (str) -- '.h5' or '.protobuf', default '.h5', indicate in which format will solver state be saved, notice that this option only works when save network definition in .nnp format and include_solver_state is True.
サンプル
次の例では、2 つの入力と 2 つの出力 MLP を作成し、ネットワーク構造と初期化されたパラメータを保存します。
import nnabla as nn import nnabla.functions as F import nnabla.parametric_functions as PF from nnabla.utils.save import save batch_size = 16 x0 = nn.Variable([batch_size, 100]) x1 = nn.Variable([batch_size, 100]) h1_0 = PF.affine(x0, 100, name='affine1_0') h1_1 = PF.affine(x1, 100, name='affine1_0') h1 = F.tanh(h1_0 + h1_1) h2 = F.tanh(PF.affine(h1, 50, name='affine2')) y0 = PF.affine(h2, 10, name='affiney_0') y1 = PF.affine(h2, 10, name='affiney_1') contents = { 'networks': [ {'name': 'net1', 'batch_size': batch_size, 'outputs': {'y0': y0, 'y1': y1}, 'names': {'x0': x0, 'x1': x1}}], 'executors': [ {'name': 'runtime', 'network': 'net1', 'data': ['x0', 'x1'], 'output': ['y0', 'y1']}]} save('net.nnp', contents)
学習可能なモデルを取得するには、以下のコードを代わりにお使いください。
contents = { 'global_config': {'default_context': ctx}, 'training_config': {'max_epoch': args.max_epoch, 'iter_per_epoch': args_added.iter_per_epoch, 'save_best': True}, 'networks': [ {'name': 'training', 'batch_size': args.batch_size, 'outputs': {'loss': loss_t}, 'names': {'x': x, 'y': t, 'loss': loss_t}}, {'name': 'validation', 'batch_size': args.batch_size, 'outputs': {'loss': loss_v}, 'names': {'x': x, 'y': t, 'loss': loss_v}}], 'optimizers': [ {'name': 'optimizer', 'solver': solver, 'network': 'training', 'dataset': 'mnist_training', 'weight_decay': 0, 'lr_decay': 1, 'lr_decay_interval': 1, 'update_interval': 1}], 'datasets': [ {'name': 'mnist_training', 'uri': 'MNIST_TRAINING', 'cache_dir': args.cache_dir + '/mnist_training.cache/', 'variables': {'x': x, 'y': t}, 'shuffle': True, 'batch_size': args.batch_size, 'no_image_normalization': True}, {'name': 'mnist_validation', 'uri': 'MNIST_VALIDATION', 'cache_dir': args.cache_dir + '/mnist_test.cache/', 'variables': {'x': x, 'y': t}, 'shuffle': False, 'batch_size': args.batch_size, 'no_image_normalization': True }], 'monitors': [ {'name': 'training_loss', 'network': 'validation', 'dataset': 'mnist_training'}, {'name': 'validation_loss', 'network': 'validation', 'dataset': 'mnist_validation'}], }
- class nnabla.utils.nnp_graph.NnpLoader(filepath, scope=None, extension='.nntxt')[ソース]
NNP ファイルローダー。
- パラメータ:
filepath -- file-like object or filepath.
extension -- if filepath is file-like object, extension is one of ".nnp", ".nntxt", ".prototxt".
サンプル
from nnabla.utils.nnp_graph import NnpLoader # Read a .nnp file. nnp = NnpLoader('/path/to/nnp.nnp') # Assume a graph `graph_a` is in the nnp file. net = nnp.get_network(network_name, batch_size=1) # `x` is an input of the graph. x = net.inputs['x'] # 'y' is an outputs of the graph. y = net.outputs['y'] # Set random data as input and perform forward prop. x.d = np.random.randn(*x.shape) y.forward(clear_buffer=True) print('output:', y.d)