class nbla::ChannelFirstAdaptor


This class transforms the memory format of a variable to channel-first.

A typical use case is transforming a channel-last memory format into a channel-first memory format in a convolutional neural network.

For example, consider an input variable of shape (32, 128, 128, 3) with batch_axis == [0] and channel_axis == 3 (channel-last). This adaptor converts the shape and memory format to (32, 3, 128, 128) by applying a transpose function.
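The shape change in this example can be reproduced with a transpose permutation. The following C++ sketch assumes that the pre transpose places the batch axes first, then the channel axis, then the remaining axes in their original order, and that the post transpose is its inverse. This ordering is consistent with the example but is an assumption about the internals, and the helpers `channel_first_axes`, `inverse_axes` and `transpose_shape` are hypothetical, not part of the nnabla API.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

using Shape = std::vector<long long>;

// Hypothetical: axis order that moves the batch axes first, then the channel
// axis, then every remaining axis in its original order.
std::vector<int> channel_first_axes(int ndim, const std::vector<int> &batch_axis,
                                    int channel_axis) {
  std::vector<int> axes(batch_axis.begin(), batch_axis.end());
  axes.push_back(channel_axis);
  for (int i = 0; i < ndim; ++i) {
    const bool is_batch =
        std::find(batch_axis.begin(), batch_axis.end(), i) != batch_axis.end();
    if (!is_batch && i != channel_axis) axes.push_back(i);
  }
  return axes;
}

// Inverse permutation: applying it after `axes` restores the original order.
std::vector<int> inverse_axes(const std::vector<int> &axes) {
  std::vector<int> inv(axes.size());
  for (std::size_t i = 0; i < axes.size(); ++i)
    inv[axes[i]] = static_cast<int>(i);
  return inv;
}

// Output shape of a transpose: out[i] = in[axes[i]].
Shape transpose_shape(const Shape &shape, const std::vector<int> &axes) {
  Shape out(axes.size());
  for (std::size_t i = 0; i < axes.size(); ++i) out[i] = shape[axes[i]];
  return out;
}
```

For the example above, `channel_first_axes(4, {0}, 3)` yields `{0, 3, 1, 2}`, which maps (32, 128, 128, 3) to (32, 3, 128, 128); its inverse `{0, 2, 3, 1}` maps the result back to (32, 128, 128, 3).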

Another possible use case is forcing the use of a channel-first implementation (layer functions) in a channel-last network by sandwiching the channel-first implementation with this adaptor. The conceptual workflow of forward propagation is as follows (variables are omitted; backward propagation works the same way):

  channel-last layer -> forward_pre -> channel-first layer -> forward_post -> channel-last layer
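The sandwich workflow can be simulated without nnabla to see why it preserves the channel-last layout. The following self-contained C++ sketch stores a tensor as a row-major `std::vector<float>`, uses a hypothetical `transpose` helper in place of the `forward_pre`/`forward_post` transposes, and uses a per-channel bias add on axis 1 as a stand-in for a channel-first layer. `Tensor`, `transpose`, `add_channel_bias` and `sandwich_nhwc` are illustrative names, not nnabla APIs.

```cpp
#include <cstddef>
#include <vector>

struct Tensor {
  std::vector<std::size_t> shape;
  std::vector<float> data;  // row-major
};

// Row-major strides of a shape.
std::vector<std::size_t> strides_of(const std::vector<std::size_t> &shape) {
  std::vector<std::size_t> s(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i)
    s[i] = s[i + 1] * shape[i + 1];
  return s;
}

// Hypothetical transpose: output axis i is input axis axes[i].
Tensor transpose(const Tensor &x, const std::vector<int> &axes) {
  Tensor y;
  for (int a : axes) y.shape.push_back(x.shape[a]);
  y.data.resize(x.data.size());
  const auto xs = strides_of(x.shape);
  const auto ys = strides_of(y.shape);
  for (std::size_t out = 0; out < y.data.size(); ++out) {
    // Split the flat output index into coordinates and gather from the input.
    std::size_t rem = out, in = 0;
    for (std::size_t i = 0; i < axes.size(); ++i) {
      in += (rem / ys[i]) * xs[axes[i]];
      rem %= ys[i];
    }
    y.data[out] = x.data[in];
  }
  return y;
}

// Stand-in for a channel-first layer: adds bias[c] along axis 1.
Tensor add_channel_bias(Tensor x, const std::vector<float> &bias) {
  const auto s = strides_of(x.shape);
  for (std::size_t i = 0; i < x.data.size(); ++i)
    x.data[i] += bias[(i / s[1]) % x.shape[1]];
  return x;
}

// forward_pre -> channel-first layer -> forward_post, for an NHWC input.
Tensor sandwich_nhwc(const Tensor &x_nhwc, const std::vector<float> &bias) {
  const std::vector<int> pre = {0, 3, 1, 2};   // NHWC -> NCHW
  const std::vector<int> post = {0, 2, 3, 1};  // NCHW -> NHWC (inverse of pre)
  return transpose(add_channel_bias(transpose(x_nhwc, pre), bias), post);
}
```

Because the post permutation `{0, 2, 3, 1}` is the inverse of `{0, 3, 1, 2}`, the result has the original NHWC shape, and each element receives the bias of its own channel, exactly as a channel-last layer would produce.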

Public Functions

NBLA_API void setup (Variable *input_pre, Variable *output_pre, Variable *input_post, Variable *output_post, const Shape_t &shape, const vector< int > &batch_axis, const int channel_axis, const Context &ctx)

Set up the adaptor.

This method must be called before using the `forward_(pre|post)` or `backward_(pre|post)` methods. The transpose functions (`(pre|post)_transpose_`) are set up internally.

Parameters:
  • input_pre -- Variable pointer for `pre_transpose_` input.

  • output_pre -- Variable pointer for `pre_transpose_` output.

  • input_post -- Variable pointer for `post_transpose_` input.

  • output_post -- Variable pointer for `post_transpose_` output.

  • shape -- Shape of the original input. This should be the same as the shape of `input_pre`.

  • batch_axis -- List of integers corresponding to the batch or outer axes. Each axis must be in the range [0, ndim).

  • channel_axis -- An integer corresponding to the channel axis. This axis must be in the range [0, ndim).

  • ctx -- A compute backend descriptor.
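The constraints on `batch_axis` and `channel_axis` can be checked before calling `setup`. The following C++ sketch is a hypothetical helper (`valid_adaptor_args` is not an nnabla function): it enforces the documented [0, ndim) range and the uniqueness of batch axes stated for `need_adaptor`, and additionally assumes that the channel axis may not also appear in `batch_axis`.

```cpp
#include <set>
#include <vector>

// Hypothetical pre-check mirroring the documented argument constraints.
bool valid_adaptor_args(int ndim, const std::vector<int> &batch_axis,
                        int channel_axis) {
  auto in_range = [ndim](int a) { return a >= 0 && a < ndim; };
  if (!in_range(channel_axis)) return false;
  std::set<int> seen;
  for (int b : batch_axis) {
    if (!in_range(b)) return false;
    if (!seen.insert(b).second) return false;  // duplicate batch axis
  }
  // Assumption: the channel axis must not also be listed as a batch axis.
  return seen.count(channel_axis) == 0;
}
```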

NBLA_API void convert_to_channel_first (Variable *input, Variable *output)

Transform the memory format of a variable to channel-first.

Parameters:
  • input -- `pre_transpose_` input.

  • output -- `pre_transpose_` output.

NBLA_API void convert_from_channel_first (Variable *input, Variable *output)

Transform the memory format of a variable from channel-first back to the original format.

Parameters:
  • input -- `post_transpose_` input.

  • output -- `post_transpose_` output.

NBLA_API void convert_to_channel_first_backward (Variable *input, Variable *output, const bool propagate_down, const bool accum)

Backward execution for `forward_pre`.

Parameters:
  • input -- `pre_transpose_` input.

  • output -- `pre_transpose_` output.

  • propagate_down -- Flag whether or not to perform backward propagation.

  • accum -- Flag whether or not to accumulate grad.

NBLA_API void convert_from_channel_first_backward (Variable *input, Variable *output, const bool propagate_down, const bool accum)

Backward execution for `forward_post`.

Parameters:
  • input -- `post_transpose_` input.

  • output -- `post_transpose_` output.

  • propagate_down -- Flag whether or not to perform backward propagation.

  • accum -- Flag whether or not to accumulate grad.
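Both backward methods undo a transpose, so their effect on gradients can be sketched as scattering each output gradient back to the input position it was gathered from. The following C++ sketch is illustrative only: `transpose_backward` is a hypothetical helper, and it assumes the usual meaning of the flags, namely that `propagate_down == false` leaves the input gradient untouched and that `accum` chooses between adding to and overwriting the existing gradient.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Row-major strides of a shape.
std::vector<std::size_t> strides_of(const std::vector<std::size_t> &shape) {
  std::vector<std::size_t> s(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i)
    s[i] = s[i + 1] * shape[i + 1];
  return s;
}

// Hypothetical backward of y = transpose(x, axes): scatters dy into dx.
void transpose_backward(const std::vector<std::size_t> &x_shape,
                        const std::vector<int> &axes,
                        const std::vector<float> &dy, std::vector<float> &dx,
                        bool propagate_down, bool accum) {
  if (!propagate_down) return;  // gradient is not requested: leave dx as-is
  std::vector<std::size_t> y_shape;
  for (int a : axes) y_shape.push_back(x_shape[a]);
  const auto xs = strides_of(x_shape);
  const auto ys = strides_of(y_shape);
  if (!accum) std::fill(dx.begin(), dx.end(), 0.0f);  // overwrite mode
  for (std::size_t out = 0; out < dy.size(); ++out) {
    // Same index mapping as the forward gather, used in reverse.
    std::size_t rem = out, in = 0;
    for (std::size_t i = 0; i < axes.size(); ++i) {
      in += (rem / ys[i]) * xs[axes[i]];
      rem %= ys[i];
    }
    dx[in] += dy[out];
  }
}
```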

Public Static Functions

static NBLA_API bool need_adaptor (const Shape_t &shape, const vector< int > &batch_axis, const int channel_axis)

Check whether this adaptor is needed for the input.

Returns `false` when the input is already in channel-first format, and `true` otherwise. A memory format is channel-first when all of the following conditions are met:

  • The axes in `batch_axis` must be contiguous from the beginning of the shape.

  • `channel_axis` must come right after the last axis of `batch_axis`.

Examples of channel-first:

  • ndim: 4, batch_axis: [0], channel_axis: 1

  • ndim: 2, batch_axis: [0], channel_axis: 1

  • ndim: 2, batch_axis: [], channel_axis: 0

  • ndim: 1, batch_axis: [], channel_axis: 0

  • ndim: 6, batch_axis: [0, 1, 2], channel_axis: 3

  • ndim: 6, batch_axis: [1, 0, 2], channel_axis: 3

Examples of non-channel-first:

  • ndim: 4, batch_axis: [0], channel_axis: 3

  • ndim: 4, batch_axis: [0], channel_axis: 2

  • ndim: 4, batch_axis: [1], channel_axis: 0

  • ndim: 2, batch_axis: [], channel_axis: 1

  • ndim: 6, batch_axis: [0, 1, 3], channel_axis: 2
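The channel-first rule can be expressed as a short check. The following C++ sketch is a hypothetical re-implementation of the documented definition, not the nnabla source: it sorts a copy of `batch_axis` (the example [1, 0, 2] shows that order does not matter), requires it to equal 0, 1, ..., n-1, and requires `channel_axis == n`. Unlike the real `need_adaptor`, `need_adaptor_sketch` omits the `shape` argument and does not validate axis ranges.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical check of the documented channel-first definition.
bool is_channel_first(const std::vector<int> &batch_axis, int channel_axis) {
  std::vector<int> sorted(batch_axis);
  std::sort(sorted.begin(), sorted.end());
  // Batch axes must be exactly {0, 1, ..., n - 1}, in any listed order.
  for (std::size_t i = 0; i < sorted.size(); ++i)
    if (sorted[i] != static_cast<int>(i)) return false;
  // The channel axis must come right after the last batch axis.
  return channel_axis == static_cast<int>(sorted.size());
}

// Mirrors the documented return value: true means a transpose is needed.
bool need_adaptor_sketch(const std::vector<int> &batch_axis, int channel_axis) {
  return !is_channel_first(batch_axis, channel_axis);
}
```

Every example listed above gives the documented answer; for instance, `need_adaptor_sketch({1, 0, 2}, 3)` is `false`, while `need_adaptor_sketch({0, 1, 3}, 2)` is `true`.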

Parameters:
  • shape -- Shape of the original input.

  • batch_axis -- List of integers corresponding to the batch or outer axes. Each axis must be unique and in the range [0, ndim).

  • channel_axis -- An integer corresponding to the channel axis. This axis must be in the range [0, ndim).

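Returns:
  • true -- The input is not in channel-first format, so the adaptor is needed.

  • false -- The input is already in channel-first format, so the adaptor is not needed.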