class nbla::SyncBatchNormalization

template<typename T>
class SyncBatchNormalization : public nbla::BatchNormalization<T>

Batch normalization whose batch statistics are synchronized across processes at training time, defined as:

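\[\begin{split} \begin{array}{lcl} \mu &=& \frac{1}{M} \sum x_i\\ \sigma^2 &=& \frac{1}{M} \sum \left(x_i - \mu\right)^2\\ \hat{x}_i &=& \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\ y_i &=& \hat{x}_i \gamma + \beta. \end{array} \end{split}\]

where \(M\) is the total number of elements the statistics are computed over, counted across all participating processes.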
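Synchronization means that \(\mu\) and \(\sigma^2\) describe the whole batch spread over all processes rather than each process's local shard. One common way to obtain them, and the approach described in the note linked under See also, is for each process to compute \(\sum x_i\) and \(\sum x_i^2\) over its shard, all-reduce those sums, and then recover \(\sigma^2 = \frac{1}{M}\sum x_i^2 - \mu^2\). The C++ sketch below shows this for a single channel. `PartialStats`, `local_stats`, `all_reduce_sum`, `global_moments` and `normalize` are hypothetical names, and the all-reduce is simulated by summing over a vector instead of calling a real communicator; it is not nbla's implementation.

```cpp
#include <cmath>
#include <vector>

// Partial sums one process contributes for a single channel.
struct PartialStats {
  double sum = 0.0;     // sum of x_i held by this process
  double sum_sq = 0.0;  // sum of x_i^2 held by this process
  double count = 0.0;   // number of elements held by this process
};

// Step 1: every process reduces its own shard locally.
PartialStats local_stats(const std::vector<double>& shard) {
  PartialStats s;
  for (double v : shard) {
    s.sum += v;
    s.sum_sq += v * v;
  }
  s.count = static_cast<double>(shard.size());
  return s;
}

// Step 2 (hypothetical stand-in for an all-reduce with sum): in a real
// multi-process run, every process would receive this same total.
PartialStats all_reduce_sum(const std::vector<PartialStats>& parts) {
  PartialStats total;
  for (const PartialStats& p : parts) {
    total.sum += p.sum;
    total.sum_sq += p.sum_sq;
    total.count += p.count;
  }
  return total;
}

// Step 3: global mean and (biased) variance from the reduced sums,
// using sigma^2 = E[x^2] - mu^2.
void global_moments(const PartialStats& g, double& mu, double& var) {
  mu = g.sum / g.count;
  var = g.sum_sq / g.count - mu * mu;
}

// Step 4: each process normalizes its own elements with the shared statistics.
double normalize(double x, double mu, double var, double gamma, double beta,
                 double eps) {
  return (x - mu) / std::sqrt(var + eps) * gamma + beta;
}
```

For example, splitting the batch {1, 2, 3, 4, 5} into the shards {1, 2} and {3, 4, 5} yields \(\mu = 3\) and \(\sigma^2 = 2\), the same values as computing both statistics on the unsplit batch. Only three numbers per channel cross process boundaries, which keeps communication cheap; the trade-off is that \(E[x^2] - \mu^2\) can lose precision when the mean is large relative to the spread.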

Inputs:

  • N-D array of input.

  • N-D array of beta which is learned.

  • N-D array of gamma which is learned.

  • N-D array of running mean (modified during forward execution).

  • N-D array of running variance (modified during forward execution).

Outputs (1 or 3):

  • N-D array.

  • (Optional) N-D array of batch mean.

  • (Optional) N-D array of batch variance.

See also

Implementing Synchronized Multi-GPU Batch Normalization https://hangzhang.org/PyTorch-Encoding/notes/syncbn.html

Template Parameters:

T – Data type for computation.

Param comm:

The communicator used to synchronize the batch statistics.

Param group:

The name of the communicator group.

Param axes:

Axes along which the mean and variance are taken.

Param decay_rate:

Decay rate of running mean and variance.
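As a rough illustration of what `decay_rate` controls, the C++ sketch below updates the running statistics with an exponential moving average. It assumes the common convention in which `decay_rate` weights the previous running value and the (synchronized) batch statistics receive `1 - decay_rate`. `update_running_stats` is a hypothetical helper, and the sketch does not claim to match nbla's exact update (for example, whether an unbiased correction is applied to the batch variance).

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: exponential moving average of the running statistics,
// assuming decay_rate weights the previous running value.
void update_running_stats(std::vector<double>& running_mean,
                          std::vector<double>& running_var,
                          const std::vector<double>& batch_mean,
                          const std::vector<double>& batch_var,
                          double decay_rate) {
  for (std::size_t c = 0; c < running_mean.size(); ++c) {
    // Keep decay_rate of the old value, blend in (1 - decay_rate) of the new.
    running_mean[c] = decay_rate * running_mean[c] + (1.0 - decay_rate) * batch_mean[c];
    running_var[c] = decay_rate * running_var[c] + (1.0 - decay_rate) * batch_var[c];
  }
}
```

With `decay_rate = 0.9`, a running mean of 0 and a batch mean of 1 give a new running mean of 0.1. Because every process receives the same synchronized batch statistics, processes that start from identical running statistics remain identical after each update.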

Param eps:

Small value added to the variance to avoid division by zero when dividing by the standard deviation.

Public Functions

inline virtual shared_ptr<Function> copy() const override

Create another instance of this Function with the same context.

inline virtual string name() override

Get the function name as a string.

inline virtual bool grad_depends_output_data(int i, int o) const

Dependency flag for checking if in-grad depends on out-data.

Checks whether computing the gradient of the i-th input requires the data of the o-th output.

Note

If any of the inputs requires an output variable's data when computing its gradient, this function must be overridden to return the appropriate boolean value. Otherwise, the backward computation will be incorrect.

Parameters:
  • i – [in] Input variable index.

  • o – [in] Output variable index.
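To illustrate when the flag must return true, the C++ sketch below uses a hypothetical base class `MiniFunction` (it is not nbla's `Function`) and an element-wise exponential. Since \(y = e^x\) gives \(dy/dx = y\), the backward pass reuses the stored output, so input 0's gradient depends on output 0's data. `MiniFunction` and `MiniExp` are illustrative names only; the sketch does not show how nbla's batch normalization sets this flag.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a function base class; NOT nbla::Function.
class MiniFunction {
 public:
  virtual ~MiniFunction() = default;
  // Default: no input gradient needs any output data.
  virtual bool grad_depends_output_data(int /*i*/, int /*o*/) const {
    return false;
  }
};

// Element-wise exp: y = exp(x), so dy/dx = y.
class MiniExp : public MiniFunction {
 public:
  // Input 0's gradient is computed from output 0's data.
  bool grad_depends_output_data(int i, int o) const override {
    return i == 0 && o == 0;
  }
  std::vector<double> forward(const std::vector<double>& x) const {
    std::vector<double> y;
    for (double v : x) y.push_back(std::exp(v));
    return y;
  }
  // dL/dx = dL/dy * y, using the saved forward output y.
  std::vector<double> backward(const std::vector<double>& y,
                               const std::vector<double>& dy) const {
    std::vector<double> dx(y.size());
    for (std::size_t k = 0; k < y.size(); ++k) dx[k] = dy[k] * y[k];
    return dx;
  }
};
```

If `MiniExp` kept the default `false`, a framework that frees or overwrites output buffers it believes are unneeded in backward could hand `backward` stale data, which is the failure the note above warns about.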