class nbla::CgVariable

class CgVariable : public nbla::BaseCgVariable

Computation graph variable.

A Variable object is held in this object as a data container. In addition, a CgVariable object keeps information about the computation graph it belongs to, such as the pointer to the parent function that creates this variable and some performance optimization hints.

Public Functions

NBLA_API CgVariable()

Create a 0-shaped variable without a need_grad flag.

NBLA_API CgVariable(bool need_grad)

Create a 0-shaped variable with the need_grad option.

Parameters:

need_grad[in] Whether this variable requires gradient computation or not

NBLA_API CgVariable(Shape_t shape)

Create a variable by shape.

Parameters:

shape[in] Shape passed to Variable object held in the created instance.

NBLA_API CgVariable(Shape_t shape, bool need_grad)

Create a variable by shape with the need_grad option.

Parameters:
  • shape[in] Shape passed to Variable object held in the created instance.

  • need_grad[in] Whether this variable requires gradient computation or not

NBLA_API CgVariable(VariablePtr var)

Create by a Variable instance.

Parameters:

var[in] Reference of an existing Variable object.

NBLA_API CgVariable(VariablePtr var, bool need_grad)

Create by a Variable instance with the need_grad option.

Parameters:
  • var[in] Reference of an existing Variable object.

  • need_grad[in] Whether this variable requires gradient computation or not

inline bool need_grad() const

Get need grad flag.

inline bool need_grad_is_set() const

Check if need grad flag is set.

inline void set_need_grad(bool b)

Set need grad flag.

inline void unset_need_grad()

Unset need grad flag.

inline bool need_grad_state() const

Get need grad state flag.

inline bool need_grad_state_is_set() const

Check if need grad state is set.

inline void set_need_grad_state(bool b)

Set need grad state flag.

inline void unset_need_grad_state()

Unset need grad state flag.
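The set/unset/is_set accessors above describe a flag that can be true, false, or not set at all. The following is a minimal standalone C++ sketch of that tri-state idea, not nnabla's implementation; the class name TriStateFlag and the choice that an unset flag reads as false are assumptions made for illustration.

```cpp
#include <optional>

// Hypothetical model of a tri-state flag mirroring the get / set / unset /
// is_set accessor pattern documented above. Not nnabla's implementation.
class TriStateFlag {
 public:
  // Returns the stored value; an unset flag reads as false here (assumption).
  bool get() const { return value_.value_or(false); }
  // True only after set() has been called and not undone by unset().
  bool is_set() const { return value_.has_value(); }
  void set(bool b) { value_ = b; }
  void unset() { value_.reset(); }

 private:
  std::optional<bool> value_;  // empty optional == "not set"
};
```

Keeping "not set" distinct from "false" lets a caller tell an explicit user choice apart from a default, which is presumably why both need_grad() and need_grad_is_set() exist.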

inline bool recompute() const

Get recompute flag.

inline void set_recompute(bool b)

Set recompute flag.

inline bool prohibit_clear_data()

Get prohibit_clear_data_ flag.

inline void set_prohibit_clear_data(bool b)

Set prohibit_clear_data_ flag.

inline void set_parent(CgFunctionPtr func)

Set parent function.

Note

Users usually don’t call this directly. It is used in the connect function.

Parameters:

func[in] Function.

inline CgFunctionPtr parent()

Get parent function which produces outputs to this variable.

inline bool has_parent()

Query if a parent function is set.

inline int rank() const

Longest path from root variable.

Holds weak function references. https://stackoverflow.com/a/22110715
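As a rough illustration of "longest path from root variable", the sketch below computes a rank recursively on a toy graph. The Node type and rank_of() are hypothetical, and the sketch assumes that root variables have rank 0 and that each function application adds 1; nnabla's own bookkeeping may differ.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical toy variable: it records the variables consumed by the
// function that produced it. Root variables have no inputs.
struct Node {
  std::vector<const Node*> parent_inputs;
};

// Rank = length of the longest path from any root variable
// (assumption: roots are rank 0, each function application adds 1).
// No memoization: shared subgraphs are recomputed, which is fine for a
// toy example but exponential in the worst case.
int rank_of(const Node& v) {
  int r = 0;
  for (const Node* in : v.parent_inputs) {
    r = std::max(r, rank_of(*in) + 1);
  }
  return r;
}
```

For example, with roots x and y, a = f(x), and b = g(a, y), the longest path to b passes through a, so b gets rank 2 even though y has rank 0.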

inline void set_rank_(int rank)

Set rank.

Note

Users shouldn’t call this directly.

NBLA_API void forward (bool clear_buffer=false, bool clear_no_need_grad=false, unordered_set< CgFunctionPtr > *fclosed=nullptr, function_hook_type pre_callback=nullptr, function_hook_type post_callback=nullptr)

Forward propagation from root inputs to this variable.

The predecessor functions are executed in ascending order of rank until this variable is reached.
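A minimal sketch of executing functions from lower to higher rank, assuming each function already carries an integer rank. FuncRecord and forward_order() are hypothetical names; the real forward() additionally honours fclosed and the clear options.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical record of a function in the graph together with its rank.
struct FuncRecord {
  std::string name;
  int rank;
};

// Returns function names in execution order: lower rank first.
// std::stable_sort keeps the original order among equal ranks
// (an illustrative choice, not a documented guarantee).
std::vector<std::string> forward_order(std::vector<FuncRecord> funcs) {
  std::stable_sort(funcs.begin(), funcs.end(),
                   [](const FuncRecord& a, const FuncRecord& b) {
                     return a.rank < b.rank;
                   });
  std::vector<std::string> order;
  for (const auto& f : funcs) order.push_back(f.name);
  return order;
}
```

Because a function's rank is larger than the rank of anything it depends on, sorting by rank is enough to guarantee that every input is computed before it is consumed.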

See also set_persistent() to prevent a specific variable from being cleared during forward propagation.

Parameters:
  • clear_buffer[in] Clear the SyncedArray object of any variable that will never be used during the rest of forward propagation. This option significantly reduces memory consumption. It is not usually used in the training phase because backward computation requires data computed during forward propagation.

  • clear_no_need_grad[in] Clear the unreferenced variables with need_grad=False during forward propagation. True is usually used when calling this during training. This is ignored when clear_buffer=True.

  • fclosed[in] Set arbitrary fclosed flags to control forward computation. This is used by the forward_all function.
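The memory saving of clear_buffer can be understood through reference counting: once no pending function reads a variable, its buffer can be released. The sketch below is a hypothetical model (Step, cleared_buffers(), and the name-keyed persistent map are illustrative, not nnabla API) that also skips variables marked persistent, as set_persistent() describes.

```cpp
#include <map>
#include <string>
#include <vector>

// Hypothetical function step: the names of the variables it reads.
struct Step {
  std::vector<std::string> inputs;
};

// Simulates clear_buffer: counts remaining readers of each variable and
// reports, in order, the variables whose buffers would be released once
// their last reader has run. Persistent variables are never released.
std::vector<std::string> cleared_buffers(
    const std::vector<Step>& steps,
    const std::map<std::string, bool>& persistent) {
  std::map<std::string, int> refs;
  for (const auto& s : steps)
    for (const auto& in : s.inputs) ++refs[in];

  std::vector<std::string> cleared;
  for (const auto& s : steps) {
    // (The function itself would execute here.)
    for (const auto& in : s.inputs) {
      if (--refs[in] == 0) {  // last reader just finished
        auto it = persistent.find(in);
        bool keep = it != persistent.end() && it->second;
        if (!keep) cleared.push_back(in);
      }
    }
  }
  return cleared;
}
```

This also shows why the option is usually off during training: a released buffer is exactly the data that backward() would later need.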

NBLA_API void backward (NdArrayPtr grad=nullptr, bool clear_buffer=false, vector< CommunicatorBackwardCallbackPtr > communicator_callbacks={}, function_hook_type pre_callback=nullptr, function_hook_type post_callback=nullptr, const bool clear_initial_grad=false)

Performs a backward propagation starting from this variable until the root variable(s) is/are reached in the computation graph. The propagation stops at a variable with need_grad=false. Gradients are propagated backward through the predecessors of this variable.

See also set_persistent() to prevent a specific variable from being cleared during backward propagation.

Parameters:
  • grad[in] The backward error signal of this variable. If nullptr is given, its gradients are set to 1.

  • clear_buffer[in] Clears the no longer referenced variables during backpropagation to save memory.

  • communicator_callbacks – The callback functions invoked when 1) backward computation of each function is finished and 2) all backward computation is finished.

  • clear_initial_grad – If true, the input parameter, grad, will be cleared during backward propagation. This flag is only activated when grad is set.
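To make the grad=nullptr default and the need_grad=false stopping rule concrete, here is a toy scalar chain, not nnabla code. ChainVar and backward_chain() are hypothetical; each step multiplies by a constant weight, so an input's gradient is that weight times the output's gradient.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical toy variable holding only what backward needs here.
struct ChainVar {
  bool need_grad;
  double grad = 0.0;
};

// Chain: vars[i + 1] = weights[i] * vars[i], vars[0] is the root.
// Preconditions: vars is non-empty and weights.size() == vars.size() - 1.
// seed == nullptr mimics backward(grad=nullptr): the output gradient is 1.
void backward_chain(std::vector<ChainVar>& vars,
                    const std::vector<double>& weights,
                    const double* seed = nullptr) {
  vars.back().grad = seed ? *seed : 1.0;
  for (std::size_t i = vars.size() - 1; i > 0; --i) {
    ChainVar& in = vars[i - 1];
    if (!in.need_grad) break;  // propagation stops at need_grad=false
    in.grad += weights[i - 1] * vars[i].grad;  // accumulate d(out)/d(in)
  }
}
```

With weights {3, 4} and no seed, the middle variable receives gradient 4, and a root with need_grad=false keeps gradient 0 because propagation stops there.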

void mark_need_setup()

Mark need_setup flag for all function references.

bool check_and_unmark_need_setup(CgFunctionPtr func)

Check the need_setup signal and unmark it.
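The mark_need_setup() / check_and_unmark_need_setup() pair reads like a per-function test-and-clear flag: mark every referencing function at once, then let each one consume its own mark exactly once. A hedged sketch of that pattern follows; NeedSetupFlags is hypothetical and identifies functions by name instead of by CgFunctionPtr.

```cpp
#include <map>
#include <string>

// Hypothetical per-function "needs setup" bits, keyed by function name.
class NeedSetupFlags {
 public:
  // Register a function that references the variable (initially unmarked).
  void add_reference(const std::string& func) { flags_[func] = false; }

  // Counterpart of mark_need_setup(): mark every referencing function.
  void mark_all() {
    for (auto& kv : flags_) kv.second = true;
  }

  // Counterpart of check_and_unmark_need_setup(): report and clear the bit.
  bool check_and_unmark(const std::string& func) {
    auto it = flags_.find(func);
    if (it == flags_.end() || !it->second) return false;
    it->second = false;
    return true;
  }

 private:
  std::map<std::string, bool> flags_;
};
```

Clearing the bit on read means a second check by the same function returns false, so each function reacts to a given mark only once.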

inline bool allow_modify_data() const

Whether the data can be modified in place.

inline void set_allow_modify_data(bool allow)

Note

Users shouldn’t call this directly.

inline void set_persistent(bool p)

Set persistent flag.

If it’s true, the variable data and grad are never cleared during forward or backward propagation with clear options. It is useful for visualization and debugging purposes.

Parameters:

p[in] Persistent flag.

inline bool persistent() const

Get persistent flag.

NBLA_API void clear_during_auto_forward ()

Clear the memory during forward propagation in auto-forward mode.

inline void set_name(string name)

Set variable name.

inline string name() const

Get variable name.

NBLA_API Ptr create_deep_copy (Context ctx, bool copy_grad=true)

Deep-copy method.

void visit_function_recursive(CgFunctionPtr func, unordered_set<CgFunctionPtr> &fclosed, const bool recomputation, function<void(CgFunctionPtr)> forward_callback)

Execute a callback on the functions of a graph in forward order.
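Visiting functions in forward order can be done with a depth-first, post-order walk in which producers are visited before their consumers, while a closed set (compare the fclosed argument) prevents a shared function from being visited twice. This is a sketch under those assumptions; FuncNode and visit_forward() are hypothetical, and the recomputation flag of the real signature is omitted.

```cpp
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical function node: the functions that produce its inputs.
struct FuncNode {
  std::string name;
  std::vector<std::shared_ptr<FuncNode>> producers;
};
using FuncNodePtr = std::shared_ptr<FuncNode>;

// Depth-first post-order walk: every producer is visited before the
// function that consumes its output; `fclosed` records visited functions
// so that a function reachable via several paths runs the callback once.
void visit_forward(const FuncNodePtr& func,
                   std::unordered_set<FuncNodePtr>& fclosed,
                   const std::function<void(const FuncNodePtr&)>& callback) {
  if (!fclosed.insert(func).second) return;  // already visited
  for (const auto& p : func->producers) visit_forward(p, fclosed, callback);
  callback(func);
}
```

Passing the closed set in from the caller, rather than creating it inside, lets several output variables share one traversal state, which matches how fclosed is described for forward_all.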

void visit_function_backward(CgFunctionPtr func, function<void(CgFunctionPtr)> backward_callback, vector<CommunicatorBackwardCallbackPtr> communicator_callbacks)

Execute a callback on the functions of a graph in backward order.