class nbla::ChannelFirstAdaptor
This class can be used to transform a variable memory format to channel-first.
A typical use case is a transformation from the channel-last to the channel-first memory format in a convolutional neural network.
For example, let the input variable shape be (32, 128, 128, 3), with batch_axis == [0] and channel_axis == 3 (channel-last). The adaptor then converts the shape and memory format to (32, 3, 128, 128) by applying a transpose function.
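The shape change in this example can be reproduced with plain axis bookkeeping. The sketch below uses hypothetical helpers that are not part of the nbla API, and it assumes the adaptor orders the axes as batch axes first (in the order given), then the channel axis, then the remaining axes in their original order.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of nbla): axis order that puts the batch axes
// first (in the given order), then the channel axis, then all remaining axes
// in their original order.
std::vector<int> channel_first_axes(int ndim, const std::vector<int> &batch_axis,
                                    int channel_axis) {
  std::vector<bool> used(ndim, false);
  std::vector<int> axes;
  for (int b : batch_axis) {
    assert(b >= 0 && b < ndim);
    axes.push_back(b);
    used[b] = true;
  }
  assert(channel_axis >= 0 && channel_axis < ndim);
  axes.push_back(channel_axis);
  used[channel_axis] = true;
  for (int i = 0; i < ndim; ++i)
    if (!used[i])
      axes.push_back(i);
  return axes;
}

// Shape after a transpose: output axis i takes the size of input axis axes[i].
std::vector<std::int64_t> permute_shape(const std::vector<std::int64_t> &shape,
                                        const std::vector<int> &axes) {
  assert(axes.size() == shape.size());
  std::vector<std::int64_t> out;
  for (int a : axes)
    out.push_back(shape[a]);
  return out;
}
```

For shape (32, 128, 128, 3), batch_axis [0] and channel_axis 3, `channel_first_axes(4, {0}, 3)` gives {0, 3, 1, 2}, and `permute_shape` then gives (32, 3, 128, 128).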
Another possible use case is forcing the use of a channel-first implementation (layer function) in a channel-last network by sandwiching the channel-first implementation with this adaptor. The conceptual workflow (forward prop) is as follows: channel-last layer -> forward_pre -> channel-first layer -> forward_post -> channel-last layer. Variables are omitted, and backward prop follows the same pattern.
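In this sandwich, the post transpose has to undo the pre transpose. A minimal sketch of that bookkeeping, assuming the channel-first layer in between keeps the non-channel layout unchanged and that the post transpose uses the inverse of the pre-transpose axis order; `inverse_axes` is a hypothetical helper, not an nbla function:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: inverse of an axis permutation. If the pre transpose
// reorders axes with `axes`, transposing its output with inverse_axes(axes)
// restores the original axis order.
std::vector<int> inverse_axes(const std::vector<int> &axes) {
  std::vector<int> inv(axes.size());
  for (std::size_t i = 0; i < axes.size(); ++i) {
    assert(axes[i] >= 0 && static_cast<std::size_t>(axes[i]) < axes.size());
    inv[axes[i]] = static_cast<int>(i);
  }
  return inv;
}
```

For the pre-transpose order {0, 3, 1, 2}, this yields {0, 2, 3, 1}, which maps a (32, 3, 128, 128) shape back to (32, 128, 128, 3).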
Public Functions
- NBLA_API void setup (Variable *input_pre, Variable *output_pre, Variable *input_post, Variable *output_post, const Shape_t &shape, const vector< int > &batch_axis, const int channel_axis, const Context &ctx)
Setup the adaptor.
This method must be called before using the `forward_(pre|post)` or `backward_(pre|post)` methods. The transpose functions (`(pre|post)_transpose_`) are set up internally.
- Parameters:
input_pre – Variable pointer for `pre_transpose_` input.
output_pre – Variable pointer for `pre_transpose_` output.
input_post – Variable pointer for `post_transpose_` input.
output_post – Variable pointer for `post_transpose_` output.
shape – Shape of original input. This should be the same as the shape of `input_pre`.
batch_axis – List of integers corresponding to the batch or outer axes. Each axis must be in the range [0, ndim).
channel_axis – An integer corresponding to the channel axis. This axis must be in the range [0, ndim).
ctx – A compute backend descriptor.
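The argument constraints for `setup` can be checked up front. This hypothetical validator combines the range requirements listed for `setup` with the uniqueness requirement stated for `need_adaptor`, which is assumed here to apply to `setup` as well.

```cpp
#include <cstddef>
#include <set>
#include <vector>

// Hypothetical argument check (not part of nbla): every batch axis and the
// channel axis must lie in [0, ndim), and batch axes must not repeat.
bool valid_adaptor_args(std::size_t ndim, const std::vector<int> &batch_axis,
                        int channel_axis) {
  const int n = static_cast<int>(ndim);
  std::set<int> seen;
  for (int b : batch_axis) {
    if (b < 0 || b >= n)
      return false;  // out of range
    if (!seen.insert(b).second)
      return false;  // duplicate batch axis
  }
  return channel_axis >= 0 && channel_axis < n;
}
```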
- NBLA_API void convert_to_channel_first (Variable *input, Variable *output)
Transform variable memory format to channel-first.
- Parameters:
input – `pre_transpose_` input.
output – `pre_transpose_` output.
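To illustrate what the memory-format change means for the data itself, here is a minimal sketch of a channel-last to channel-first transpose on a flat 4-D buffer. It assumes row-major storage and a concrete (N, H, W, C) layout; the adaptor itself relies on its internal transpose function and handles general shapes.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch: copy an (N, H, W, C) row-major buffer into (N, C, H, W) order.
template <typename T>
std::vector<T> nhwc_to_nchw(const std::vector<T> &x, std::size_t N,
                            std::size_t H, std::size_t W, std::size_t C) {
  assert(x.size() == N * H * W * C);
  std::vector<T> y(x.size());
  for (std::size_t n = 0; n < N; ++n)
    for (std::size_t h = 0; h < H; ++h)
      for (std::size_t w = 0; w < W; ++w)
        for (std::size_t c = 0; c < C; ++c)
          // Destination index in NCHW order, source index in NHWC order.
          y[((n * C + c) * H + h) * W + w] = x[((n * H + h) * W + w) * C + c];
  return y;
}
```

With N = 1, H = 1, W = 2, C = 3, the NHWC buffer {0, 1, 2, 3, 4, 5} becomes {0, 3, 1, 4, 2, 5}: each channel's values are now stored contiguously.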
- NBLA_API void convert_from_channel_first (Variable *input, Variable *output)
Transform variable memory format from channel-first back to the original format.
- Parameters:
input – `post_transpose_` input.
output – `post_transpose_` output.
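Converting back is a transpose with the inverse axis order. The hypothetical `transpose_nd` below is a generic row-major N-D transpose in which output axis i corresponds to input axis `axes[i]`; it is a sketch of the data movement under that convention, not the nbla transpose implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical generic transpose of a row-major N-D buffer:
// output axis i corresponds to input axis axes[i].
template <typename T>
std::vector<T> transpose_nd(const std::vector<T> &x,
                            const std::vector<std::int64_t> &shape,
                            const std::vector<int> &axes) {
  const std::size_t nd = shape.size();
  assert(axes.size() == nd);
  // Row-major strides of the input.
  std::vector<std::int64_t> in_strides(nd, 1);
  for (std::size_t i = nd; i-- > 1;)
    in_strides[i - 1] = in_strides[i] * shape[i];
  // Output shape: out_shape[i] = shape[axes[i]].
  std::vector<std::int64_t> out_shape(nd);
  for (std::size_t i = 0; i < nd; ++i)
    out_shape[i] = shape[axes[i]];
  std::vector<T> y(x.size());
  std::vector<std::int64_t> idx(nd, 0);  // current output multi-index
  for (std::size_t o = 0; o < y.size(); ++o) {
    std::int64_t src = 0;
    for (std::size_t i = 0; i < nd; ++i)
      src += idx[i] * in_strides[axes[i]];
    y[o] = x[static_cast<std::size_t>(src)];
    // Advance the output multi-index in row-major order.
    for (std::size_t i = nd; i-- > 0;) {
      if (++idx[i] < out_shape[i])
        break;
      idx[i] = 0;
    }
  }
  return y;
}
```

Calling it with axes {0, 2, 3, 1} on an (N, C, H, W) buffer yields (N, H, W, C), undoing a channel-last to channel-first transpose that used {0, 3, 1, 2}.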
- NBLA_API void convert_to_channel_first_backward (Variable *input, Variable *output, const bool propagate_down, const bool accum)
Backward execution for `forward_pre`.
- Parameters:
input – `pre_transpose_` input.
output – `pre_transpose_` output.
propagate_down – Whether to perform backward propagation.
accum – Whether to accumulate the gradient.
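Because a transpose only moves elements, its backward pass routes each output gradient back to the input position it was read from. The sketch below does this for the (N, H, W, C) -> (N, C, H, W) case and models the two flags under an assumed semantics: `propagate_down == false` skips the computation, and `accum` selects between adding into and overwriting the existing input gradient. The function name is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch of backward for an (N, H, W, C) -> (N, C, H, W) transpose.
// gx: gradient of the NHWC input (updated in place).
// gy: gradient of the NCHW output.
// Assumed flag semantics: propagate_down == false leaves gx untouched;
// accum == true adds into gx, otherwise gx is overwritten.
template <typename T>
void nhwc_to_nchw_backward(std::vector<T> &gx, const std::vector<T> &gy,
                           std::size_t N, std::size_t H, std::size_t W,
                           std::size_t C, bool propagate_down, bool accum) {
  if (!propagate_down)
    return;
  assert(gx.size() == N * H * W * C && gy.size() == gx.size());
  for (std::size_t n = 0; n < N; ++n)
    for (std::size_t h = 0; h < H; ++h)
      for (std::size_t w = 0; w < W; ++w)
        for (std::size_t c = 0; c < C; ++c) {
          const T g = gy[((n * C + c) * H + h) * W + w];
          T &dst = gx[((n * H + h) * W + w) * C + c];
          dst = accum ? dst + g : g;
        }
}
```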
- NBLA_API void convert_from_channel_first_backward (Variable *input, Variable *output, const bool propagate_down, const bool accum)
Backward execution for `forward_post`.
- Parameters:
input – `post_transpose_` input.
output – `post_transpose_` output.
propagate_down – Whether to perform backward propagation.
accum – Whether to accumulate the gradient.
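Another way to write the backward pass, sketched here for a post transpose from (N, C, H, W) back to (N, H, W, C), is to record the gather map of the forward direction and replay it as a scatter. Both helper names are hypothetical, and the flags are modeled with an assumed semantics: skip when `propagate_down` is false, add when `accum` is true, overwrite otherwise.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// For an (N, C, H, W) -> (N, H, W, C) transpose, src[o] is the flat index in
// the NCHW input that output element o (in NHWC order) is read from.
std::vector<std::size_t> nchw_to_nhwc_map(std::size_t N, std::size_t C,
                                          std::size_t H, std::size_t W) {
  std::vector<std::size_t> src;
  src.reserve(N * C * H * W);
  for (std::size_t n = 0; n < N; ++n)
    for (std::size_t h = 0; h < H; ++h)
      for (std::size_t w = 0; w < W; ++w)
        for (std::size_t c = 0; c < C; ++c)
          src.push_back(((n * C + c) * H + h) * W + w);
  return src;
}

// Backward of a permutation: each output gradient is routed back to exactly
// one input element, so gx is fully written when accum is false.
template <typename T>
void permutation_backward(std::vector<T> &gx, const std::vector<T> &gy,
                          const std::vector<std::size_t> &src,
                          bool propagate_down, bool accum) {
  if (!propagate_down)
    return;
  assert(gy.size() == src.size() && gx.size() == src.size());
  for (std::size_t o = 0; o < src.size(); ++o)
    gx[src[o]] = accum ? gx[src[o]] + gy[o] : gy[o];
}
```

Since a transpose is a permutation, every input element receives exactly one gradient value, so the overwrite path does not need to zero `gx` first.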
Public Static Functions
- static NBLA_API bool need_adaptor (const Shape_t &shape, const vector< int > &batch_axis, const int channel_axis)
Check whether this adaptor is needed for the input.
Returns `false` when the input is already in channel-first format, otherwise `true`. The memory format is channel-first when all of the following conditions are met:
* The axes in `batch_axis` must be contiguous from the beginning of the shape, i.e. exactly axes 0 to len(batch_axis) - 1, in any order.
* `channel_axis` must come right after the last batch axis, i.e. channel_axis == len(batch_axis).
Examples of channel-first:
* ndim: 4, batch_axis: [0], channel_axis: 1
* ndim: 2, batch_axis: [0], channel_axis: 1
* ndim: 2, batch_axis: [], channel_axis: 0
* ndim: 1, batch_axis: [], channel_axis: 0
* ndim: 6, batch_axis: [0, 1, 2], channel_axis: 3
* ndim: 6, batch_axis: [1, 0, 2], channel_axis: 3
Examples of non-channel-first:
* ndim: 4, batch_axis: [0], channel_axis: 3
* ndim: 4, batch_axis: [0], channel_axis: 2
* ndim: 4, batch_axis: [1], channel_axis: 0
* ndim: 2, batch_axis: [], channel_axis: 1
* ndim: 6, batch_axis: [0, 1, 3], channel_axis: 2
- Parameters:
shape – Shape of original input.
batch_axis – List of integers corresponding to the batch or outer axes. Each axis must be unique and in the range [0, ndim).
channel_axis – An integer corresponding to the channel axis. This axis must be in the range [0, ndim).
- Returns:
`true` if the adaptor is needed (the input is not in channel-first format), `false` otherwise.
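The channel-first rule behind `need_adaptor` fits in a few lines. This is a hypothetical re-implementation written only from the conditions listed above, not the library's code; `std::vector<std::int64_t>` stands in for `Shape_t`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch of the documented rule: the layout is channel-first
// when batch_axis covers exactly {0, ..., k-1} (any order, k = number of
// batch axes) and channel_axis == k. Returns true when an adaptor is needed.
bool needs_adaptor_sketch(const std::vector<std::int64_t> &shape,
                          const std::vector<int> &batch_axis,
                          int channel_axis) {
  const int ndim = static_cast<int>(shape.size());
  assert(channel_axis >= 0 && channel_axis < ndim);  // documented precondition
  const int k = static_cast<int>(batch_axis.size());
  std::vector<bool> seen(batch_axis.size(), false);
  for (int b : batch_axis) {
    assert(b >= 0 && b < ndim);  // documented precondition
    // A batch axis outside the leading k positions (or a repeated one)
    // breaks the "contiguous from the beginning" condition.
    if (b < 0 || b >= k || seen[b])
      return true;
    seen[b] = true;
  }
  // The channel axis must come right after the batch axes.
  return channel_axis != k;
}
```

On every example listed above, this sketch returns `false` for the channel-first cases and `true` for the non-channel-first cases.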