class nbla::Variable
-
class Variable : public nbla::BaseVariable
User interface for Array and passed to Function. Shared pointer of Variable.
Users create arrays via Variable and pass them to Function. Variable has two array regions internally, data and grad. The data region is used as an input and/or output of Function::forward(), while the grad region stores the backpropagated error of Function::backward().
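The split between the data and grad regions can be pictured with a small standalone sketch. It does not use the nbla headers: `ToyVariable`, `square_forward`, and `square_backward` are hypothetical names made up for illustration, and plain `std::vector<float>` buffers stand in for the real array regions.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a variable that owns two regions of equal size.
// This is NOT nbla::Variable; it only mirrors the data/grad split.
struct ToyVariable {
  std::vector<float> data; // read or written by the forward pass
  std::vector<float> grad; // holds the backpropagated error
  explicit ToyVariable(std::size_t n) : data(n, 0.0f), grad(n, 0.0f) {}
};

// Forward pass of y = x * x: reads x.data and writes y.data.
void square_forward(const ToyVariable &x, ToyVariable &y) {
  for (std::size_t i = 0; i < x.data.size(); ++i)
    y.data[i] = x.data[i] * x.data[i];
}

// Backward pass: reads y.grad and x.data, accumulates dy/dx * y.grad into x.grad.
void square_backward(ToyVariable &x, const ToyVariable &y) {
  for (std::size_t i = 0; i < x.data.size(); ++i)
    x.grad[i] += 2.0f * x.data[i] * y.grad[i];
}
```

In this sketch the forward function touches only the data buffers, and the backward function writes only to the input's grad buffer, which is the division of roles described above.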
Public Functions
-
NBLA_API Variable(NdArrayPtr data)
Constructor given NdArray.
- Parameters:
data -- A reference to an NdArray created elsewhere can be passed.
- NBLA_API void reshape (const vector< int64_t > &shape, bool force)
- NBLA_API Ptr view ()
Create a new view object without copying data.
- NBLA_API Ptr view (const Shape_t &shape)
Create a new view object given shape without copying data.
- Parameters:
shape -- Shape. The total size of the shape must match the size of this instance.
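What "a view without copying data" means can be sketched outside of nbla. The `ToyView` type below is hypothetical: it shares one buffer through a `std::shared_ptr` and only swaps the shape. Throwing on a size mismatch is an assumption of this sketch, not documented behavior of `Variable::view`.

```cpp
#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <vector>

using ToyShape = std::vector<std::int64_t>; // hypothetical shape type

// Product of all dimensions in a shape.
std::int64_t toy_total_size(const ToyShape &shape) {
  std::int64_t n = 1;
  for (std::int64_t d : shape) n *= d;
  return n;
}

// Hypothetical view type: a shape plus a shared handle to one buffer.
struct ToyView {
  ToyShape shape;
  std::shared_ptr<std::vector<float>> buffer;

  explicit ToyView(const ToyShape &s)
      : shape(s),
        buffer(std::make_shared<std::vector<float>>(
            static_cast<std::size_t>(toy_total_size(s)), 0.0f)) {}

  // Returns a view with a new shape that refers to the same buffer.
  ToyView view(const ToyShape &new_shape) const {
    if (toy_total_size(new_shape) != toy_total_size(shape))
      throw std::invalid_argument("total size of the new shape must match");
    ToyView v(*this); // copies the shared_ptr, not the floats
    v.shape = new_shape;
    return v;
  }
};
```

Because both objects hold the same buffer, a write through one view is visible through the other; only the shape metadata differs.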
- NBLA_API Size_t size (Size_t axis=-1) const
Size of Array (Product of shape dimensions).
- Parameters:
axis -- The size is computed over the dimensions from the given axis onward.
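One reading of the axis parameter is that the result is the product of the shape entries from the given axis to the end, with the default of -1 yielding the total size. The helper below is a standalone sketch of that reading: `toy_size_from_axis` is a hypothetical name, and treating a negative axis as 0 is an assumption rather than a statement about the library.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: product of shape[axis], shape[axis + 1], ...
// A negative axis is treated as 0 here, which gives the total size.
std::int64_t toy_size_from_axis(const std::vector<std::int64_t> &shape,
                                std::int64_t axis = -1) {
  if (axis < 0) axis = 0;
  std::int64_t n = 1;
  for (std::size_t i = static_cast<std::size_t>(axis); i < shape.size(); ++i)
    n *= shape[i];
  return n;
}
```

For a shape of (2, 3, 4), this sketch gives 24 with the default axis, 12 for axis 1, and 4 for axis 2.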
-
template<typename T>
inline T *cast_data_and_get_pointer(const Context &ctx, bool write_only = false)
A shortcut function to cast data and get a pointer.
See also
-
template<typename T>
inline T *cast_grad_and_get_pointer(const Context &ctx, bool write_only = false)
A shortcut function to cast grad and get a pointer.
See also