Parametric Functions

In NNabla, trainable models are created by composing functions that have optimizable parameters. These functions are called parametric functions. Parametric functions are provided by nnabla.parametric_functions.

See also:
Python API Tutorial.

Parameter Management API

The parameters registered by the functions in List of Parametric Functions can be managed using the APIs listed in this section.

nnabla.parameter.parameter_scope(name)[source]

Groups parameters registered by the parametric functions listed in nnabla.parametric_functions under a scope named name.

Example:

import nnabla as nn
import nnabla.parametric_functions as PF
import nnabla.functions as F

x = nn.Variable((16, 3, 28, 28))  # example input; must be defined before use
with nn.parameter_scope('conv1'):
    conv_out1 = PF.convolution(x, 32, (5, 5))
    bn_out1 = PF.batch_normalization(conv_out1)
    act_out1 = F.relu(bn_out1)
with nn.parameter_scope('conv2'):
    conv_out2 = PF.convolution(act_out1, 64, (3, 3))
    bn_out2 = PF.batch_normalization(conv_out2)
    act_out2 = F.relu(bn_out2)

nnabla.parameter.get_parameters(params=None, path='', grad_only=True)[source]

Get parameter Variables under the current parameter scope.

Parameters:
  • params (dict) – Internal use. Users should not set this manually.
  • path (str) – Internal use. Users should not set this manually.
  • grad_only (bool) – If True, only parameters with need_grad=True are retrieved; if False, all parameters under the current scope are retrieved.
Returns:

{str : Variable}

Return type:

dict
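
Example (a minimal sketch; the shapes and parameter names such as 'fc1/affine/W' are illustrative of how this version registers parameters):

import nnabla as nn
import nnabla.parametric_functions as PF

x = nn.Variable((8, 100))
with nn.parameter_scope('fc1'):
    h = PF.affine(x, 50)

params = nn.get_parameters()   # {str: Variable}
for key, v in params.items():
    print(key, v.shape)        # e.g. fc1/affine/W (100, 50), fc1/affine/b (50,)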

nnabla.parameter.clear_parameters()[source]

Clear all parameters in the current scope.

nnabla.parameter.save_parameters(path, format='hdf5')[source]

Save all parameters into a file with the specified format.

Currently hdf5 and protobuf formats are supported.

Parameters:
  • path – Path or file object.
  • format (str) – Either 'hdf5' (default) or 'protobuf'.
nnabla.parameter.load_parameters(path, proto=None)[source]

Load parameters from a file.

Parameters:path – Path or file object.
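
Example (a minimal save/load round trip; the file name is an arbitrary assumption):

import nnabla as nn

nn.save_parameters('params.h5')   # save all parameters in the current scope
nn.clear_parameters()
nn.load_parameters('params.h5')   # restore the saved parameters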
nnabla.parameter.get_parameter_or_create(name, shape, initializer=None, need_grad=True)[source]

Returns an existing parameter variable with the provided name. If a variable with the provided name does not exist, a new variable is created, initialized (using initializer if given), and registered under that name before being returned.

Parameters:
  • name (str) – The name under the current scope. If a parameter with this name already exists, it is retrieved from the parameter manager.
  • shape (tuple of int) – Shape of the created parameter. If the parameter already exists, its shape must match this shape.
  • initializer (BaseInitializer) – An initialization function to be applied to the parameter.
  • need_grad (bool) – The value for need_grad. The default is True.
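
Example (a minimal sketch of fetching or creating a custom parameter; the name 'my_w' is an arbitrary assumption):

import nnabla as nn
from nnabla.initializer import ConstantInitializer

w = nn.parameter.get_parameter_or_create(
    'my_w', (10, 5), initializer=ConstantInitializer(0.1))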

List of Parametric Functions

Parametric functions are provided by nnabla.parametric_functions, as listed below. Like the functions listed in List of Functions, they take Variable(s) as their first argument(s), followed by options specific to each parametric function. In addition, they register parameter Variable(s) into the parameter scope.

All parametric functions listed below are decorated with the following decorator.

nnabla.parametric_functions.parametric_function_api(scope_name=None)[source]

Decorator for parametric functions.

The decorated function is always called under a parameter scope scope_name. The decorator also appends an additional argument name (str, default None) to the function. If name is specified, the scope scope_name is nested under the scope name. This feature can reduce the vertical space usage of the source code. Every parametric function should be decorated with this decorator.

Parameters:scope_name (str, optional) – The original function will be called under a parameter scope named by scope_name.
Returns:A decorated parametric function.
Return type:function
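
For example, the following two snippets register (and reuse) the same parameters under the scope fc (a minimal sketch; the input shape is an arbitrary assumption):

import nnabla as nn
import nnabla.parametric_functions as PF

x = nn.Variable((8, 100))
y1 = PF.affine(x, 10, name='fc')

with nn.parameter_scope('fc'):
    y2 = PF.affine(x, 10)   # reuses the parameters created by the call above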

See Parameter Management API for how to query and manipulate registered variables.

Here is the list of parametric functions.

nnabla.parametric_functions.affine(inp, n_outmaps, base_axis=1, w_init=None, b_init=None, fix_parameters=False, rng=None, with_bias=True, name=None)[source]

The affine layer, also known as the fully connected layer. Computes

\[{\mathbf y} = {\mathbf A} {\mathbf x} + {\mathbf b},\]

where \({\mathbf x}, {\mathbf y}\) are the input and output respectively, and \({\mathbf A}, {\mathbf b}\) are the trainable weight and bias parameters.

Parameters:
  • inp (Variable) – Input N-D array with shape (\(M_0 \times \ldots \times M_{B-1} \times D_B \times \ldots \times D_N\)). Dimensions before and after base_axis are flattened as if the input were a matrix.
  • n_outmaps (int or tuple of int) – Number of output neurons per data.
  • base_axis (int) – Dimensions up to base_axis are treated as the sample dimensions.
  • w_init (BaseInitializer) – Initializer for weight.
  • b_init (BaseInitializer) – Initializer for bias.
  • fix_parameters (bool) – When set to True, the weights and biases will not be updated.
  • rng (numpy.random.RandomState) – Random generator for Initializer.
  • with_bias (bool) – Specify whether to include the bias term.
Returns:

\((B + 1)\)-D array. (\(M_0 \times \ldots \times M_{B-1} \times L\))

Return type:

Variable

Note

If the name option is passed, the parameters become wrapped inside the parameter scope with the specified name, yielding the same results as the following code. This can be used to simplify the code.

with nn.parameter_scope(name):
    output = affine(<args>)
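
Example of the affine layer (a minimal sketch; the input shape is an arbitrary assumption):

import nnabla as nn
import nnabla.parametric_functions as PF

x = nn.Variable((32, 3, 28, 28))   # base_axis=1: 32 samples of 3*28*28 inputs
y = PF.affine(x, 10)               # y.shape == (32, 10)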
nnabla.parametric_functions.convolution(inp, outmaps, kernel, pad=None, stride=None, dilation=None, group=1, w_init=None, b_init=None, base_axis=1, fix_parameters=False, rng=None, with_bias=True, name=None)[source]

N-D Convolution with a bias term.

This function also supports dilated convolution (a.k.a. atrous convolution) via the dilation option.

Parameters:
  • inp (Variable) – N-D array.
  • outmaps (int) – Number of convolution kernels (which is equal to the number of output channels). For example, to apply a convolution with 16 filters, specify 16.
  • kernel (tuple of int) – Convolution kernel size. For example, to apply convolution on an image with a 3 (height) by 5 (width) two-dimensional kernel, specify (3,5).
  • pad (tuple of int) – Padding sizes for dimensions.
  • stride (tuple of int) – Stride sizes for dimensions.
  • dilation (tuple of int) – Dilation sizes for dimensions.
  • group (int) – Number of groups of channels. This makes connections across channels sparser by grouping connections along the map direction.
  • w_init (BaseInitializer) – Initializer for weight.
  • b_init (BaseInitializer) – Initializer for bias.
  • base_axis (int) – Dimensions up to base_axis are treated as the sample dimensions.
  • fix_parameters (bool) – When set to True, the weights and biases will not be updated.
  • rng (numpy.random.RandomState) – Random generator for Initializer.
  • with_bias (bool) – Specify whether to include the bias term.
Returns:

N-D array.

Return type:

Variable

Note

If the name option is passed, the parameters become wrapped inside the parameter scope with the specified name, yielding the same results as the following code. This can be used to simplify the code.

with nn.parameter_scope(name):
    output = convolution(<args>)
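
Example (a minimal sketch with assumed shapes; pad=(1, 1) preserves the spatial size for a 3x3 kernel):

import nnabla as nn
import nnabla.parametric_functions as PF

x = nn.Variable((16, 3, 32, 32))
y = PF.convolution(x, 8, (3, 3), pad=(1, 1))   # y.shape == (16, 8, 32, 32)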
nnabla.parametric_functions.deconvolution(inp, outmaps, kernel, pad=None, stride=None, dilation=None, group=1, w_init=None, b_init=None, base_axis=1, fix_parameters=False, rng=None, with_bias=True, name=None)[source]

Deconvolution layer.

Parameters:
  • inp (Variable) – N-D array.
  • outmaps (int) – Number of deconvolution kernels (which is equal to the number of output channels). For example, to apply a deconvolution with 16 filters, specify 16.
  • kernel (tuple of int) – Convolution kernel size. For example, to apply deconvolution on an image with a 3 (height) by 5 (width) two-dimensional kernel, specify (3,5).
  • pad (tuple of int) – Padding sizes for dimensions.
  • stride (tuple of int) – Stride sizes for dimensions.
  • dilation (tuple of int) – Dilation sizes for dimensions.
  • group (int) – Number of groups of channels. This makes connections across channels sparser by grouping connections along map direction.
  • w_init (BaseInitializer) – Initializer for weight.
  • b_init (BaseInitializer) – Initializer for bias.
  • base_axis (int) – Dimensions up to base_axis are treated as the sample dimensions.
  • fix_parameters (bool) – When set to True, the weights and biases will not be updated.
  • rng (numpy.random.RandomState) – Random generator for Initializer.
  • with_bias (bool) – Specify whether to include the bias term.
Returns:

N-D array.

Return type:

Variable

Note

If the name option is passed, the parameters become wrapped inside the parameter scope with the specified name, yielding the same results as the following code. This can be used to simplify the code.

with nn.parameter_scope(name):
    output = deconvolution(<args>)
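
Example (a minimal sketch with assumed shapes; a stride-2 deconvolution with a 2x2 kernel doubles the spatial size):

import nnabla as nn
import nnabla.parametric_functions as PF

x = nn.Variable((16, 8, 16, 16))
y = PF.deconvolution(x, 4, (2, 2), stride=(2, 2))   # y.shape == (16, 4, 32, 32)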
nnabla.parametric_functions.batch_normalization(inp, axes=[1], decay_rate=0.9, eps=1e-05, batch_stat=True, output_stat=False, name=None)[source]

Batch normalization layer.

\[\begin{split}\begin{array}{lcl} \mu &=& \frac{1}{M} \sum x_i\\ \sigma^2 &=& \frac{1}{M} \sum \left(x_i - \mu\right)^2\\ \hat{x}_i &=& \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\ y_i &=& \hat{x}_i \gamma + \beta. \end{array}\end{split}\]

where \(x_i, y_i\) are the inputs and outputs respectively. At inference time, the mean and variance computed by the moving averages maintained during training are used.

Parameters:
  • inp (Variable) – N-D array of input.
  • axes (tuple of int) – Axes along which the mean and variance are computed.
  • decay_rate (float) – Decay rate of the running mean and variance.
  • eps (float) – Small value added to the variance to avoid division by zero.
  • batch_stat (bool) – Use mini-batch statistics rather than running ones.
  • output_stat (bool) – Also output the batch mean and variance.
Returns:

N-D array.

Return type:

Variable

References

S. Ioffe and C. Szegedy. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” (2015).

Note

If the name option is passed, the parameters become wrapped inside the parameter scope with the specified name, yielding the same results as the following code. This can be used to simplify the code.

with nn.parameter_scope(name):
    output = batch_normalization(<args>)
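
Example (a minimal sketch; reusing the same name makes both calls share one set of parameters and running statistics):

import nnabla as nn
import nnabla.parametric_functions as PF

x = nn.Variable((16, 8, 32, 32))
h_train = PF.batch_normalization(x, batch_stat=True, name='bn')    # training
h_test = PF.batch_normalization(x, batch_stat=False, name='bn')    # inference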
nnabla.parametric_functions.embed(inp, n_inputs, n_features, name=None)[source]

Embed.

Embed slices a matrix or tensor with an indexing array or tensor.

Parameters:
  • inp (Variable) – [Integer] Indices with shape \((I_0, ..., I_N)\)
  • n_inputs (int) – Number of possible inputs, e.g. the number of words or the vocabulary size.
  • n_features (int) – Number of embedding features.
Returns:

Output with shape \((I_0, ..., I_N, W_1, ..., W_M)\)

Return type:

Variable

Note

If the name option is passed, the parameters become wrapped inside the parameter scope with the specified name, yielding the same results as the following code. This can be used to simplify the code.

with nn.parameter_scope(name):
    output = embed(<args>)
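
Example (a minimal sketch with an assumed vocabulary size):

import nnabla as nn
import nnabla.parametric_functions as PF

idx = nn.Variable((8, 10))                        # integer indices in [0, 1000)
e = PF.embed(idx, n_inputs=1000, n_features=64)   # e.shape == (8, 10, 64)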
nnabla.parametric_functions.prelu(inp, base_axis=1, shared=True, name=None)[source]

Parametrized Rectified Linear Unit function defined as

\[y_i = \max(0, x_i) + w_i \min(0, x_i),\]

where the negative slope \(w\) is learned and can vary across channels (the axis specified by base_axis).

Parameters:
  • inp (Variable) – N-D array as input.
  • base_axis (int) – Dimensions up to base_axis are treated as the sample dimensions.
  • shared (bool) – Use a single shared slope value for all channels if True; otherwise learn one slope per channel.
Returns:

N-D array.

Return type:

Variable

Note

If the name option is passed, the parameters become wrapped inside the parameter scope with the specified name, yielding the same results as the following code. This can be used to simplify the code.

with nn.parameter_scope(name):
    output = prelu(<args>)
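
Example (a minimal sketch; distinct name values keep the two slope parameters separate):

import nnabla as nn
import nnabla.parametric_functions as PF

x = nn.Variable((16, 8, 32, 32))
y_shared = PF.prelu(x)                               # one slope for all channels
y_per_ch = PF.prelu(x, shared=False, name='prelu2')  # one slope per channel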
nnabla.parametric_functions.binary_connect_affine(inp, n_outmaps, base_axis=1, w_init=None, wb_init=None, b_init=None, fix_parameters=False, rng=None, with_bias=True, name=None)[source]

Binary Connect Affine, multiplier-less inner-product.

Binary Connect Affine is an affine function, except the definition of the inner product is modified. The input-output relation of this function is as follows:

\[y_j = \sum_{i} \mathrm{sign}(w_{j,i}) x_i.\]

Therefore \(\mathrm{sign}(w_{j,i})\) is either \(1\) or \(-1\), and the inner product simplifies to addition.

This function should be used together with Batch Normalization.

References

M. Courbariaux, Y. Bengio, and J.-P. David. “BinaryConnect: Training Deep Neural Networks with binary weights during propagations.” Advances in Neural Information Processing Systems. 2015.

Note

1) If you would like to share weights between some layers, please make sure to share the standard, floating-point weights (weight), not the binarized weights (binary_weight).

2) The weights and the binary weights are in sync only after forward() is called, not after a call to backward(). To access the parameters of the network, remember to call forward() once beforehand; otherwise the float weights and the binary weights will be out of sync.

3) CPU and GPU implementations currently use floating-point values for binary_weight, since this function is intended for simulation purposes.

Parameters:
  • inp (Variable) – Input N-D array with shape (\(M_0 \times \ldots \times M_{B-1} \times D_B \times \ldots \times D_N\)). Dimensions before and after base_axis are flattened as if the input were a matrix.
  • n_outmaps (int or tuple of int) – Number of output neurons per data.
  • base_axis (int) – Dimensions up to base_axis are treated as the sample dimensions.
  • w_init (BaseInitializer) – Initializer for weight.
  • wb_init (BaseInitializer) – Initializer for binary weight.
  • b_init (BaseInitializer) – Initializer for bias.
  • fix_parameters (bool) – When set to True, the weights and biases will not be updated.
  • rng (numpy.random.RandomState) – Random generator for Initializer.
  • with_bias (bool) – Specify whether to include the bias term.
Returns:

Variable

Note

If the name option is passed, the parameters become wrapped inside the parameter scope with the specified name, yielding the same results as the following code. This can be used to simplify the code.

with nn.parameter_scope(name):
    output = binary_connect_affine(<args>)
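
Example (a minimal sketch of the recommended pairing with batch normalization; shapes and scope names are arbitrary assumptions):

import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF

x = nn.Variable((32, 128))
h = PF.binary_connect_affine(x, 256, name='bca1')
h = F.relu(PF.batch_normalization(h, name='bn1'))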
nnabla.parametric_functions.binary_connect_convolution(inp, outmaps, kernel, pad=None, stride=None, dilation=None, group=1, w_init=None, wb_init=None, b_init=None, base_axis=1, fix_parameters=False, rng=None, with_bias=True, name=None)[source]

Binary Connect Convolution, multiplier-less inner-product.

Binary Connect Convolution is the convolution function, except the definition of the inner product is modified. The input-output relation of this function is as follows:

\[y_{n, a, b} = \sum_{m} \sum_{i} \sum_{j} sign(w_{n, m, i, j}) x_{m, a + i, b + j}.\]

Therefore \(sign(w_i)\) is either \(1\) or \(-1\) and the inner product simplifies to addition.

This function should be used together with BatchNormalization.

References

M. Courbariaux, Y. Bengio, and J.-P. David. “BinaryConnect: Training Deep Neural Networks with binary weights during propagations.” Advances in Neural Information Processing Systems. 2015.

Note

1) If you would like to share weights between some layers, please make sure to share the standard, floating-point weights (weight), not the binarized weights (binary_weight).

2) The weights and the binary weights are in sync only after forward() is called, not after a call to backward(). To access the parameters of the network, remember to call forward() once beforehand; otherwise the float weights and the binary weights will be out of sync.

3) CPU and GPU implementations currently use floating-point values for binary_weight, since this function is intended for simulation purposes.

Parameters:
  • inp (Variable) – N-D array.
  • outmaps (int) – Number of convolution kernels (which is equal to the number of output channels). For example, to apply a convolution with 16 filters, specify 16.
  • kernel (tuple of int) – Convolution kernel size. For example, to apply convolution on an image with a 3 (height) by 5 (width) two-dimensional kernel, specify (3,5).
  • pad (tuple of int) – Padding sizes for dimensions.
  • stride (tuple of int) – Stride sizes for dimensions.
  • dilation (tuple of int) – Dilation sizes for dimensions.
  • group (int) – Number of groups of channels. This makes connections across channels sparser by grouping connections along map direction.
  • w_init (BaseInitializer) – Initializer for weight.
  • wb_init (BaseInitializer) – Initializer for binary weight.
  • b_init (BaseInitializer) – Initializer for bias.
  • base_axis (int) – Dimensions up to base_axis are treated as the sample dimensions.
  • fix_parameters (bool) – When set to True, the weights and biases will not be updated.
  • rng (numpy.random.RandomState) – Random generator for Initializer.
  • with_bias (bool) – Specify whether to include the bias term.
Returns:

Variable

Note

If the name option is passed, the parameters become wrapped inside the parameter scope with the specified name, yielding the same results as the following code. This can be used to simplify the code.

with nn.parameter_scope(name):
    output = binary_connect_convolution(<args>)
nnabla.parametric_functions.binary_weight_affine(inp, n_outmaps, base_axis=1, w_init=None, wb_init=None, b_init=None, fix_parameters=False, rng=None, with_bias=True, name=None)[source]

Binary Weight Affine, multiplier-less inner-product with a scale factor.

Binary Weight Affine is the affine function, but the inner product in this function is the following,

\[y_j = \frac{1}{\|\mathbf{w}_j\|_{\ell_1}} \sum_{i} sign(w_{ji}) x_i\]

Therefore \(\mathrm{sign}(w_{ji})\) is either \(1\) or \(-1\), and the inner product simplifies to addition followed by the scaling factor \(\alpha = \frac{1}{\|\mathbf{w}_j\|_{\ell_1}}\). There is one \(\alpha\) per output map of the affine function.

References

Rastegari, Mohammad, et al. “XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks.” arXiv preprint arXiv:1603.05279 (2016).

Note

1) If you would like to share weights between some layers, please make sure to share the standard, floating-point weights (weight), not the binarized weights (binary_weight).

2) The weights and the binary weights are in sync only after forward() is called, not after a call to backward(). To access the parameters of the network, remember to call forward() once beforehand; otherwise the float weights and the binary weights will be out of sync.

3) CPU and GPU implementations currently use floating-point values for binary_weight, since this function is intended for simulation purposes.

Parameters:
  • inp (Variable) – Input N-D array with shape (\(M_0 \times \ldots \times M_{B-1} \times D_B \times \ldots \times D_N\)). Dimensions before and after base_axis are flattened as if the input were a matrix.
  • n_outmaps (int or tuple of int) – Number of output neurons per data.
  • base_axis (int) – Dimensions up to base_axis are treated as the sample dimensions.
  • w_init (BaseInitializer) – Initializer for the weight.
  • wb_init (BaseInitializer) – Initializer for the binary weight.
  • b_init (BaseInitializer) – Initializer for the bias.
  • fix_parameters (bool) – When set to True, the weight and bias will not be updated.
  • rng (numpy.random.RandomState) – Random generator for Initializer.
  • with_bias (bool) – Specify whether to include the bias term.
Returns:

Variable

Note

If the name option is passed, the parameters become wrapped inside the parameter scope with the specified name, yielding the same results as the following code. This can be used to simplify the code.

with nn.parameter_scope(name):
    output = binary_weight_affine(<args>)
nnabla.parametric_functions.binary_weight_convolution(inp, outmaps, kernel, pad=None, stride=None, dilation=None, group=1, w_init=None, wb_init=None, b_init=None, base_axis=1, fix_parameters=False, rng=None, with_bias=True, name=None)[source]

Binary Weight Convolution, multiplier-less inner-product with a scale factor.

Binary Weight Convolution is the convolution function, but the inner product in this function is the following,

\[y_{n, a, b} = \frac{1}{\|\mathbf{w}_n\|_{\ell_1}} \sum_{m} \sum_{i} \sum_{j} sign(w_{n, m, i, j}) x_{m, a + i, b + j}.\]

Therefore \(\mathrm{sign}(w_{n, m, i, j})\) is either \(1\) or \(-1\), and the inner product simplifies to addition followed by the scaling factor \(\alpha = \frac{1}{\|\mathbf{w}_n\|_{\ell_1}}\). There is one \(\alpha\) per output map \(n\) of the convolution function.

References

Rastegari, Mohammad, et al. “XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks.” arXiv preprint arXiv:1603.05279 (2016).

Note

1) If you would like to share weights between some layers, please make sure to share the standard, floating-point weights (weight), not the binarized weights (binary_weight).

2) The weights and the binary weights are in sync only after forward() is called, not after a call to backward(). To access the parameters of the network, remember to call forward() once beforehand; otherwise the float weights and the binary weights will be out of sync.

3) CPU and GPU implementations currently use floating-point values for binary_weight, since this function is intended for simulation purposes.

Parameters:
  • inp (Variable) – N-D array.
  • outmaps (int) – Number of convolution kernels (which is equal to the number of output channels). For example, to apply a convolution with 16 filters, specify 16.
  • kernel (tuple of int) – Convolution kernel size. For example, to apply convolution on an image with a 3 (height) by 5 (width) two-dimensional kernel, specify (3,5).
  • pad (tuple of int) – Padding sizes for dimensions.
  • stride (tuple of int) – Stride sizes for dimensions.
  • dilation (tuple of int) – Dilation sizes for dimensions.
  • group (int) – Number of groups of channels. This makes connections across channels sparser by grouping connections along map direction.
  • w_init (BaseInitializer) – Initializer for weight.
  • wb_init (BaseInitializer) – Initializer for binary weight.
  • b_init (BaseInitializer) – Initializer for bias.
  • base_axis (int) – Dimensions up to base_axis are treated as the sample dimensions.
  • fix_parameters (bool) – When set to True, the weights and biases will not be updated.
  • rng (numpy.random.RandomState) – Random generator for Initializer.
  • with_bias (bool) – Specify whether to include the bias term.
Returns:

Variable

Note

If the name option is passed, the parameters become wrapped inside the parameter scope with the specified name, yielding the same results as the following code. This can be used to simplify the code.

with nn.parameter_scope(name):
    output = binary_weight_convolution(<args>)
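
The binary-weight variants follow the same usage pattern. Example (a minimal sketch, again paired with batch normalization; shapes and scope names are arbitrary assumptions):

import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF

x = nn.Variable((16, 3, 32, 32))
h = PF.binary_weight_convolution(x, 16, (3, 3), pad=(1, 1), name='bwc1')
h = F.relu(PF.batch_normalization(h, name='bn1'))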

Parameter Initializer

Some parametric functions optionally take one of the parameter initializers listed below.

class nnabla.initializer.BaseInitializer[source]

Base class of the parameter initializer.

__call__(shape)[source]

Generates an initialized array of the given shape.

Parameters:shape (tuple of int) – Shape of the numpy.ndarray to be created.
Returns:Array.
Return type:numpy.ndarray

Note

Subclasses of BaseInitializer must override this method.
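
For example, a minimal custom initializer might look like the following (OnesInitializer is a hypothetical illustration):

import numpy as np
from nnabla.initializer import BaseInitializer

class OnesInitializer(BaseInitializer):
    def __call__(self, shape):
        # Return an array of the requested shape filled with ones.
        return np.ones(shape)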

class nnabla.initializer.ConstantInitializer(value=0)[source]

Bases: nnabla.initializer.BaseInitializer

Generates a constant valued array.

Parameters:value (float) – A constant value.
class nnabla.initializer.NormalInitializer(sigma=1.0, rng=None)[source]

Bases: nnabla.initializer.BaseInitializer

Generates a random array from a specified normal distribution.

\[\mathbf x \sim {\cal N} (\mathbf 0, \sigma^2 \mathbf I)\]
Parameters:
  • sigma (float) – Standard deviation \(\sigma\) of the normal distribution.
  • rng (numpy.random.RandomState) – Random generator for Initializer.
class nnabla.initializer.UniformInitializer(lim=(-1, 1), rng=None)[source]

Bases: nnabla.initializer.BaseInitializer

Generates a random array from a specified uniform distribution.

\[\mathbf x \sim {\cal U} (a, b)\]
Parameters:
  • lim (tuple of float) – Lower and upper bounds \((a, b)\) of the uniform distribution.
  • rng (numpy.random.RandomState) – Random generator for Initializer.
nnabla.initializer.calc_normal_std_he_forward(inmaps, outmaps, kernel=(1, 1))[source]

Calculates the standard deviation proposed by He et al.

\[\sigma = \sqrt{\frac{2}{NK}}\]
Parameters:
  • inmaps (int) – Map size of an input Variable, \(N\).
  • outmaps (int) – Map size of an output Variable, \(M\).
  • kernel (tuple of int) – Convolution kernel spatial shape. In the above definition, \(K\) is the product of the kernel dimensions. For Affine, the default value should be used.

References

He, Kaiming, et al. “Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.” (2015).

nnabla.initializer.calc_normal_std_he_backward(inmaps, outmaps, kernel=(1, 1))[source]

Calculates the standard deviation of He et al. (backward case).

\[\sigma = \sqrt{\frac{2}{MK}}\]
Parameters:
  • inmaps (int) – Map size of an input Variable, \(N\).
  • outmaps (int) – Map size of an output Variable, \(M\).
  • kernel (tuple of int) – Convolution kernel spatial shape. In the above definition, \(K\) is the product of the kernel dimensions. For Affine, the default value should be used.

References

He, Kaiming, et al. “Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.” (2015).

nnabla.initializer.calc_normal_std_glorot(inmaps, outmaps, kernel=(1, 1))[source]

Calculates the standard deviation proposed by Glorot et al.

\[\sigma = \sqrt{\frac{2}{NK + M}}\]
Parameters:
  • inmaps (int) – Map size of an input Variable, \(N\).
  • outmaps (int) – Map size of an output Variable, \(M\).
  • kernel (tuple of int) – Convolution kernel spatial shape. In the above definition, \(K\) is the product of the kernel dimensions. For Affine, the default value should be used.

References

Glorot, Xavier, and Yoshua Bengio. “Understanding the Difficulty of Training Deep Feedforward Neural Networks.” AISTATS (2010).

nnabla.initializer.calc_uniform_lim_glorot(inmaps, outmaps, kernel=(1, 1))[source]

Calculates the lower bound and the upper bound of the uniform distribution proposed by Glorot et al.

\[\begin{split}b &= \sqrt{\frac{6}{NK + M}}\\ a &= -b\end{split}\]
Parameters:
  • inmaps (int) – Map size of an input Variable, \(N\).
  • outmaps (int) – Map size of an output Variable, \(M\).
  • kernel (tuple of int) – Convolution kernel spatial shape. In the above definition, \(K\) is the product of the kernel dimensions. For Affine, the default value should be used.

References

Glorot, Xavier, and Yoshua Bengio. “Understanding the Difficulty of Training Deep Feedforward Neural Networks.” AISTATS (2010).
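
The helper functions above compute arguments for the initializer classes. Example (a minimal sketch with assumed map sizes and kernel shape):

import nnabla as nn
import nnabla.parametric_functions as PF
from nnabla.initializer import (
    NormalInitializer, UniformInitializer,
    calc_normal_std_he_forward, calc_uniform_lim_glorot)

x = nn.Variable((16, 3, 32, 32))
inmaps, outmaps, kernel = 3, 16, (3, 3)

# He-normal weights for a convolution:
w_he = NormalInitializer(calc_normal_std_he_forward(inmaps, outmaps, kernel))
# Glorot-uniform weights as an alternative:
w_glorot = UniformInitializer(calc_uniform_lim_glorot(inmaps, outmaps, kernel))

y = PF.convolution(x, outmaps, kernel, w_init=w_he)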