class nbla::Variable
class Variable : public nbla::BaseVariable
User interface for Array; Variables are passed to Function.
Ptr: Shared pointer of Variable.
Users create arrays via Variable and pass them to Function. Internally, Variable has two array regions, data and grad. The data region is used as the input and/or output of Function::forward(), while the grad region stores the backpropagated error computed by Function::backward().
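To make the split between the two regions concrete, here is a minimal sketch. It does not use nbla: ToyVariable, scale_forward and scale_backward are hypothetical names for a toy variable that holds a data buffer and a grad buffer of equal size, and for a toy function y = 2x whose forward pass writes y's data and whose backward pass writes x's grad from y's grad.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy model, not the nbla API: a variable owning two
// same-sized regions, one for forward values and one for backprop errors.
struct ToyVariable {
  std::vector<float> data;  // read/written by forward()
  std::vector<float> grad;  // written by backward()
  explicit ToyVariable(std::size_t n) : data(n, 0.0f), grad(n, 0.0f) {}
};

// Hypothetical function y = 2 * x: forward reads x.data, writes y.data.
void scale_forward(const ToyVariable &x, ToyVariable &y) {
  for (std::size_t i = 0; i < x.data.size(); ++i) y.data[i] = 2.0f * x.data[i];
}

// Backward reads y.grad and writes dL/dx = 2 * dL/dy into x.grad.
void scale_backward(ToyVariable &x, const ToyVariable &y) {
  for (std::size_t i = 0; i < x.grad.size(); ++i) x.grad[i] = 2.0f * y.grad[i];
}
```

The forward pass never touches grad and the backward pass never touches data, which is the division of labour described above.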
Public Functions
- NBLA_API Variable(NdArrayPtr data)
Constructor given NdArray.
- Parameters:
data – A reference to an NdArray. An NdArray created elsewhere (for example, by another Variable) can be passed.
- NBLA_API void reshape(const vector<int64_t> &shape, bool force)
- NBLA_API Ptr view()
Create a new view object without copying data.
- NBLA_API Ptr view(const Shape_t &shape)
Create a new view object given shape without copying data.
- Parameters:
shape – Shape. The total size of the shape must match the size of this instance.
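- NBLA_API Size_t size(Size_t axis=-1) const
Size of the array (product of the shape dimensions).
- Parameters:
axis – If given, the size is computed as the product of the dimensions from axis onward. With the default (-1), the size of the whole array is returned.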
- template<typename T>
inline T *cast_data_and_get_pointer(const Context &ctx, bool write_only = false)
A shortcut function to cast the data array to type T and get a pointer to it.
See also
- template<typename T>
inline T *cast_grad_and_get_pointer(const Context &ctx, bool write_only = false)
A shortcut function to cast the grad array to type T and get a pointer to it.
See also
- template<typename T>
inline const T *get_data_pointer(const Context &ctx)
A shortcut function to get a read-only (const) pointer to the data as type T.
See also
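The constructor taking an NdArrayPtr and the view() overloads share storage instead of copying it. The sketch below models that idea with std::shared_ptr; ToyVar, ToyArrayPtr and total_size are hypothetical names, not part of nbla, and the size check in view(shape) mirrors the documented requirement that the total size of the new shape match the size of this instance.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

using ToyShape = std::vector<std::int64_t>;
using ToyArrayPtr = std::shared_ptr<std::vector<float>>;

// Product of all dimensions of a shape.
std::int64_t total_size(const ToyShape &shape) {
  return std::accumulate(shape.begin(), shape.end(), std::int64_t{1},
                         std::multiplies<std::int64_t>());
}

// Hypothetical stand-in for Variable: a shape plus a shared handle to storage.
class ToyVar {
public:
  // Wraps an existing array (shares it, no copy); the shape must fit it.
  ToyVar(ToyArrayPtr array, ToyShape shape)
      : array_(std::move(array)), shape_(std::move(shape)) {
    if (total_size(shape_) != static_cast<std::int64_t>(array_->size()))
      throw std::invalid_argument("shape does not match array size");
  }
  // New object over the same storage with the same shape.
  ToyVar view() const { return ToyVar(array_, shape_); }
  // New object over the same storage with a new shape (size must match).
  ToyVar view(const ToyShape &shape) const { return ToyVar(array_, shape); }

  float *data() { return array_->data(); }
  const ToyShape &shape() const { return shape_; }

private:
  ToyArrayPtr array_;
  ToyShape shape_;
};
```

A write through one view is visible through every other view, because all of them hold the same std::shared_ptr; only the shape metadata differs.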
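The axis argument of size() can be read as follows. This is a standalone helper (size_from_axis is a hypothetical name, not an nbla function), assuming the product starts at axis inclusive and that the default of -1 returns the size of the whole array.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: product of the dimensions from `axis` onward.
// Assumption: a negative axis (the default -1) yields the total size.
std::int64_t size_from_axis(const std::vector<std::int64_t> &shape,
                            std::int64_t axis = -1) {
  std::size_t start = axis < 0 ? 0 : static_cast<std::size_t>(axis);
  std::int64_t n = 1;
  for (std::size_t i = start; i < shape.size(); ++i) n *= shape[i];
  return n;
}
```

For an (N, C, H, W) shape, size_from_axis(shape, 1) gives the number of elements per sample, C * H * W.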
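The cast_*_and_get_pointer shortcuts combine two steps: converting the stored array to the element type T, then returning a typed pointer. The toy below (ToyCastable is hypothetical and ignores Context entirely) shows only that idea, exposing a float buffer as double*. It assumes that write_only = true means the caller will overwrite every element, so the old values need not be converted.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical toy of the "cast, then get a typed pointer" shortcut.
// Values are stored as float; callers ask for a double* view of them.
class ToyCastable {
public:
  explicit ToyCastable(std::vector<float> values) : f32_(std::move(values)) {}

  // Assumption: write_only == true means the caller overwrites all
  // elements, so converting the existing float values is skipped.
  double *cast_and_get_pointer(bool write_only = false) {
    f64_.resize(f32_.size());  // new elements are value-initialized to 0.0
    if (!write_only) {
      for (std::size_t i = 0; i < f32_.size(); ++i) f64_[i] = f32_[i];
    }
    return f64_.data();
  }

private:
  std::vector<float> f32_;
  std::vector<double> f64_;
};
```

Skipping the conversion when the result will be overwritten anyway avoids a pointless pass over the buffer, which is the kind of saving a write_only flag is meant to expose.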