class nbla::Solver

class Solver

Solver interface which is extended to implement a new Solver class.

Solver takes care of the update rule, given the gradients of parameters.

\[ w_{t+1} \leftarrow w_t - g_t(\Delta w_t) \]

The function \(g_t(\cdot)\) can have an internal state that is updated when it is called (e.g. Adam).
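For example, plain SGD uses a stateless \(g_t\) that simply scales the gradient \(\Delta w_t\) by the learning rate \(\eta\):

\[ g_t(\Delta w_t) = \eta \, \Delta w_t \]

whereas Adam's \(g_t\) also updates running estimates of the first and second gradient moments each time it is called.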

Subclassed by nbla::AMSBound< T >, nbla::AMSGRAD< T >, nbla::AdaBelief< T >, nbla::AdaBound< T >, nbla::Adadelta< T >, nbla::Adagrad< T >, nbla::Adam< T >, nbla::AdamW< T >, nbla::Adamax< T >, nbla::Lamb< T >, nbla::Lars< T >, nbla::Lion< T >, nbla::Momentum< T >, nbla::Nesterov< T >, nbla::RMSprop< T >, nbla::RMSpropGraves< T >, nbla::Sgd< T >, nbla::SgdW< T >

Public Functions

virtual ~Solver() = 0

virtual string name() = 0

Name of Solver class, usually class name.

virtual float learning_rate() = 0

Get learning rate.

void set_learning_rate(float learning_rate)

Set learning rate.

bool weight_decay_is_fused() const

Whether the weight decay is lazily evaluated at update_impl.

void zero_grad()

Zero the gradients of all parameters registered in params_.

This is usually called before running a sequence of Function::backward() calls that backpropagate through the whole computation graph.

void set_parameters(const vector<pair<string, VariablePtr>> &params, bool reset = true, bool retain_state = false)

Add parameters to be optimized by the solver.

This internally calls set_impl().

Parameters:
  • params – Shared pointers of Variables

  • reset – Reset all previously registered parameters, replacing them with the given ones.

  • retain_state – Try to retain state (e.g. momentum) if a parameter is overwritten. Note that this will be ignored if reset=true.
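A minimal usage sketch, using only the signature above; the parameter names and the VariablePtr values (w1, b1, more_params) are hypothetical:

  vector<pair<string, VariablePtr>> params = {{"conv1/W", w1},
                                              {"conv1/b", b1}};
  solver->set_parameters(params);  // reset=true: replaces any previous set
  // Append further parameters without dropping the existing ones, and try
  // to keep solver state (e.g. momentum) for parameters being overwritten.
  solver->set_parameters(more_params, /*reset=*/false, /*retain_state=*/true);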

void remove_parameters(const vector<string> &keys)

Remove previously registered parameters by keys.

void clear_parameters()

Clear all parameters.

vector<pair<string, VariablePtr>> get_parameters()

Get all parameters.

vector<pair<string, SolverState>> get_states()

Get all states.

void set_states(const vector<pair<string, SolverState>> &params)

Set states.

inline void clear_state(const string &key)

Clear the state of the parameter specified by key.
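A brief checkpointing sketch built only on the state calls above; how the returned states are serialized to disk is left out:

  // Snapshot optimizer state (e.g. momentum buffers and iteration count).
  auto states = solver->get_states();  // vector<pair<string, SolverState>>
  // ... later, on a solver with the same parameters registered ...
  solver->set_states(states);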

void update(update_hook_type pre_callback = nullptr, update_hook_type post_callback = nullptr)

Update all parameters using the gradients stored in params_ (accumulated by backpropagation).

This internally calls update_impl() which must be implemented in a derived class.
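A minimal training-step sketch using only the calls documented here; the forward/backward execution of the graph is assumed to happen elsewhere, and params and max_iter are hypothetical:

  solver->set_parameters(params);  // register once, before the loop
  for (int iter = 0; iter < max_iter; ++iter) {
    solver->zero_grad();           // clear previously accumulated gradients
    // ... forward pass and Function::backward() sequence here ...
    solver->weight_decay(1e-4f);   // optional; must be called before update()
    solver->update();              // runs the derived class's update_impl()
  }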

void weight_decay(float decay_rate, update_hook_type pre_callback = nullptr, update_hook_type post_callback = nullptr)

Apply weight decay to raw gradient.

It must be called before running update() if necessary. This internally calls weight_decay_impl() which must be implemented in a derived class.

It is equivalent to adding the squared L2 norm of the weight vector to the original loss function:

\[ L_{\rm R}({\mathbf w}) = L_{\rm orig}({\mathbf w}) + {\rm decay\_rate } \times ||{\mathbf w}||_2^2 \]

Parameters:
  • decay_rate – Coefficient of weight decay.
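At the gradient level, this amounts to adding the weight itself, scaled by decay_rate, to the raw gradient before the update (shown here up to a constant factor, which may differ between implementations):

\[ \Delta w_t \leftarrow \Delta w_t + {\rm decay\_rate} \times w_t \]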

void clip_grad_by_norm(float norm, update_hook_type pre_callback = nullptr, update_hook_type post_callback = nullptr)

Clip gradients by norm.

The norm is computed per variable, not jointly over all parameters.
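A standard per-variable clipping rule consistent with this description (a sketch only; the exact formulation is implementation-defined) rescales a gradient \(\Delta w\) whose L2 norm exceeds the threshold \(N\):

\[ \Delta w \leftarrow \frac{N}{\max(N, \|\Delta w\|_2)} \, \Delta w \]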

bool check_inf_grad(update_hook_type pre_callback = nullptr, update_hook_type post_callback = nullptr)

Check if there is any inf in the gradients of the registered parameters.

bool check_nan_grad(update_hook_type pre_callback = nullptr, update_hook_type post_callback = nullptr)

Check if there is any nan in the gradients of the registered parameters.

bool check_inf_or_nan_grad(update_hook_type pre_callback = nullptr, update_hook_type post_callback = nullptr)

Check if there is any inf or nan in the gradients of the registered parameters.

void scale_grad(float scale, update_hook_type pre_callback = nullptr, update_hook_type post_callback = nullptr)

Scale gradients, then increase the loss scale.
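These checks combine naturally with scale_grad() in a dynamic loss-scaling step for mixed-precision training. A sketch, in which the loss_scale bookkeeping and the graph execution are assumptions rather than part of this API:

  float loss_scale = 1024.0f;
  // ... forward pass with the loss multiplied by loss_scale; backward ...
  solver->scale_grad(1.0f / loss_scale);  // bring gradients back to scale
  if (solver->check_inf_or_nan_grad()) {
    loss_scale *= 0.5f;  // overflow detected: shrink the scale, skip update
  } else {
    solver->update();
    // optionally grow loss_scale again after a run of clean iterations
  }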

virtual vector<string> allowed_array_classes()

Get array classes that are allowed to be specified by Context.

struct SolverState

Struct for storing both parameter state Variable and iteration.

Public Members

unordered_map<string, VariablePtr> pstate

Map from state name to parameter state Variable.

uint32_t t

Iteration as state.
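A short sketch of inspecting stored states; the state names "m" and "v" follow Adam's convention and are given only as an illustration:

  for (auto &kv : solver->get_states()) {
    const SolverState &state = kv.second;
    // state.pstate maps state names (e.g. "m", "v" for Adam's moment
    // estimates) to the corresponding Variables; state.t is the iteration
    // count, used e.g. for Adam's bias correction.
  }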