class nbla::RandomChoice

template<typename T>
class RandomChoice : public nbla::BaseFunction<const vector<int>&, bool, int>

Generate random samples from population `x` with selection probabilities determined by the relative weights `w`.

The number of samples to draw is the product of the dimensions of `shape`, and the samples are returned with the given `shape`. By default, samples are drawn with replacement, i.e. the selection of a specific population member is determined solely by its associated weight. Sampling without replacement, where each population member may be drawn at most once, is used if `replace` is set to False.
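For illustration, the following is a minimal pure-Python sketch of the two sampling modes for a single population, using the standard library's `random.choices`. The helper name `weighted_sample` is hypothetical and this is not nnabla's implementation. Sampling without replacement is modeled here as successive draws in which each drawn member's weight is set to zero, which is one common way to define weighted sampling without replacement.

```python
import random


def weighted_sample(population, weights, k, replace=True, rng=None):
    """Draw k members of population with probability proportional to weights.

    Hypothetical reference helper for illustration; not part of nnabla.
    """
    if rng is None:
        rng = random.Random()
    if replace:
        # Independent draws: each selection depends only on the member's weight.
        return rng.choices(population, weights=weights, k=k)
    # Without replacement: draw one member at a time, then zero its weight
    # so that it cannot be selected again.
    remaining = list(weights)
    result = []
    for _ in range(k):
        idx = rng.choices(range(len(population)), weights=remaining, k=1)[0]
        result.append(population[idx])
        remaining[idx] = 0
    return result


if __name__ == "__main__":
    rng = random.Random(0)
    print(weighted_sample([11, 22, 33], [10, 20, 70], k=5, rng=rng))
    # Drawing all three members without replacement yields a permutation.
    print(weighted_sample([11, 22, 33], [10, 20, 70], k=3, replace=False, rng=rng))
```

With `replace=False`, asking for more samples than there are non-zero weights would exhaust the population, which is why that case is treated as invalid.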

For both `x` and `w`, the innermost dimension holds the members of the individual populations and their weights, respectively; all outer dimensions index separate populations. Samples are returned with the requested `shape` appended to these outer dimensions of the input.
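The shape bookkeeping can be sketched in plain Python for a 2-D `x`, where each row is one population. This assumes the output layout is the outer dimensions of `x` followed by `shape`, as the `(3, 4)` example below suggests; the names `random_choice_2d` and `reshape` are hypothetical, and only sampling with replacement is shown for brevity.

```python
import math
import random


def reshape(flat, shape):
    """Nest a flat list into lists following a non-empty shape (row-major)."""
    if len(shape) == 1:
        return list(flat)
    step = len(flat) // shape[0]
    return [reshape(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def random_choice_2d(x, w, shape, seed=None):
    """Sample with replacement from each row of x, weighted by the same row of w.

    Hypothetical illustration; the result nests as len(x) x shape,
    i.e. the outer dimension of x followed by the requested shape.
    """
    rng = random.Random(seed)
    n = math.prod(shape)  # number of samples drawn per population
    out = []
    for population, weights in zip(x, w):
        flat = rng.choices(population, weights=weights, k=n)
        out.append(reshape(flat, shape))
    return out


if __name__ == "__main__":
    x = [[11, 22, 33], [110, 220, 330]]
    w = [[10, 20, 70], [70, 20, 10]]
    print(random_choice_2d(x, w, shape=(3, 4), seed=0))
```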

import nnabla as nn
import nnabla.functions as F
import numpy as np
nn.set_auto_forward(True)

# x holds two populations
x = nn.Variable.from_numpy_array(np.array([[11, 22, 33], [110, 220, 330]]))
# w holds the weights for each population
w = nn.Variable.from_numpy_array(np.array([[10, 20, 70], [70, 20, 10]]))

# draw one sample from each population
y = F.random_choice(x, w)  # y.shape => (2, 1)

# draw 12 samples with shape (3, 4) from each population
y = F.random_choice(x, w, shape=(3, 4))  # y.shape => (2, 3, 4)

Note that weights must not be negative and that, for each population, the sum of weights must be greater than zero. Additionally, sampling without replacement requires that the number of non-zero weights is at least the number of samples to be drawn. These conditions are verified in the “cpu” computation context but not when using “cuda” or “cudnn” acceleration, because checking them would require additional device synchronization steps that penalize performance.
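The three conditions above can be expressed as a small check over each population's weights. This is a hypothetical helper, `check_weights`, written only to restate the documented rules; it is not the check nnabla performs internally.

```python
def check_weights(w_rows, num_samples, replace):
    """Raise ValueError if any population's weights violate the documented rules.

    Hypothetical helper; w_rows holds one list of weights per population.
    """
    for row in w_rows:
        if any(wi < 0 for wi in row):
            raise ValueError("weights must not be negative")
        if sum(row) <= 0:
            raise ValueError("sum of weights must be greater than zero")
        # Without replacement, each non-zero weight can supply at most one sample.
        if not replace and sum(1 for wi in row if wi != 0) < num_samples:
            raise ValueError("too few non-zero weights for sampling without replacement")
```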

Random sampling from an implicit array of index values (as in categorical or multinomial sampling) can be realized by constructing the input `x` as an array of indices:

w = nn.Variable.from_numpy_array(np.array([1, 2, 3, 2, 1]))
y = F.random_choice(F.arange(0, 5), w)
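The same idea can be checked in plain Python: when the population is just the indices `0..4`, weighted sampling is categorical sampling, so the empirical index frequencies approach `w / sum(w)`. This sketch uses the standard library's `random.choices` over `range(...)` in place of `F.arange` and only illustrates the semantics, not nnabla's sampler.

```python
import random
from collections import Counter

weights = [1, 2, 3, 2, 1]
rng = random.Random(42)

# The population is the index array 0..4, so each draw is a categorical sample.
draws = rng.choices(range(len(weights)), weights=weights, k=90_000)
counts = Counter(draws)
total = sum(weights)

for idx, wgt in enumerate(weights):
    # Empirical frequency next to the normalized weight.
    print(idx, counts[idx] / len(draws), wgt / total)
```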

Inputs:

  • x: N-D array from which random samples are drawn.

  • w: N-D array of weights associated with the elements of `x`.

Outputs:

  • N-D array of samples drawn from `x`.

Template Parameters:

T – Data type for computation.

Param shape:

Number and shape of generated samples.

Param replace:

Whether sampling is with or without replacement.

Param seed:

Random seed.

Public Functions

inline virtual shared_ptr<Function> copy() const

Create a copy of this Function instance with the same context.

inline virtual int min_inputs()

Get minimum number of inputs.

This is meant to be used in the setup function together with in_types(), which is used to get the maximum number of inputs.

inline virtual int min_outputs()

Get minimum number of outputs.

This is meant to be used in the setup function together with out_types(), which is used to get the maximum number of outputs.

inline virtual vector<dtypes> in_types()

Get input dtypes.

The last in_type is used repeatedly if the size of in_types is smaller than the number of inputs.

inline virtual vector<dtypes> out_types()

Get output dtypes.

The last out_type is used repeatedly if the size of out_types is smaller than the number of outputs.
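The padding rule shared by in_types() and out_types() can be sketched as follows. The helper `expand_types` is hypothetical, plain strings stand in for nnabla's `dtypes` values, and the list is assumed to be non-empty.

```python
def expand_types(types, num_ports):
    """Assign a dtype to each of num_ports inputs (or outputs).

    Hypothetical illustration: entries are used in order, and the last entry
    is repeated for any port beyond the end of the list.
    """
    return [types[min(i, len(types) - 1)] for i in range(num_ports)]


if __name__ == "__main__":
    print(expand_types(["float32", "int32"], 4))
```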

inline virtual vector<string> allowed_array_classes()

Get array classes that are allowed to be specified by Context.

inline virtual string name()

Get the function name as a string.

inline virtual bool need_setup_recompute(int o) const

A flag for checking if setup_recompute() is needed.

Checks whether the o-th output’s data requires setup_recompute().

Note

setup_recompute() will be skipped during forward execution if none of the outputs requires setup_recompute().

Parameters:

o – [in] Output variable index.
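The skip rule in the note above amounts to an "any output" test. In this sketch, `ToyFunction` and `should_run_setup_recompute` are hypothetical stand-ins written in Python for illustration; they are not part of nnabla's C++ or Python API.

```python
class ToyFunction:
    """Hypothetical stand-in for a Function, holding one flag per output."""

    def __init__(self, recompute_flags):
        self.recompute_flags = recompute_flags

    def need_setup_recompute(self, o):
        return self.recompute_flags[o]


def should_run_setup_recompute(fn, num_outputs):
    # setup_recompute() is skipped unless at least one output requires it.
    return any(fn.need_setup_recompute(o) for o in range(num_outputs))
```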

inline virtual bool grad_depends_output_data(int i, int o) const

Dependency flag for checking if in-grad depends on out-data.

Checks whether the gradient computation of the i-th input requires the o-th output’s data.

Note

If any input requires output variable data when computing its gradient, this function must be overridden to return the appropriate boolean value; otherwise, backward computation will be incorrect.

Parameters:
  • i – [in] Input variable index.

  • o – [in] Output variable index.
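As an illustration of why this flag matters, a graph engine could use it to decide which outputs must keep their data until backward. The helper `outputs_needed_for_backward` and the toy classes below are hypothetical; `ExpLike` reflects the mathematical fact that for y = exp(x) the derivative equals y, so its input gradient can reuse the output data, while `AddLike` passes gradients through without needing outputs.

```python
class ExpLike:
    # y = exp(x): dy/dx = y, so input 0's gradient can reuse output 0's data.
    def grad_depends_output_data(self, i, o):
        return i == 0 and o == 0


class AddLike:
    # y = a + b: gradients pass through unchanged; output data is not needed.
    def grad_depends_output_data(self, i, o):
        return False


def outputs_needed_for_backward(fn, num_inputs, num_outputs):
    """Return indices of outputs whose data some input gradient depends on."""
    return [
        o
        for o in range(num_outputs)
        if any(fn.grad_depends_output_data(i, o) for i in range(num_inputs))
    ]
```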