class nbla::Solver
-
class Solver
Solver interface which is extended to implement a new Solver class.
Solver takes care of the update rule, given the gradients of the parameters.
\[ w_{t+1} \leftarrow w_t - g_t(\Delta w_t) \]
The function \(g_t(\cdot)\) can have an internal state that is updated when it is called (e.g., Adam).
Subclassed by nbla::AMSBound< T >, nbla::AMSGRAD< T >, nbla::AdaBelief< T >, nbla::AdaBound< T >, nbla::Adadelta< T >, nbla::Adagrad< T >, nbla::Adam< T >, nbla::AdamW< T >, nbla::Adamax< T >, nbla::Lamb< T >, nbla::Lars< T >, nbla::Lion< T >, nbla::Momentum< T >, nbla::Nesterov< T >, nbla::RMSprop< T >, nbla::RMSpropGraves< T >, nbla::Sgd< T >, nbla::SgdW< T >
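A minimal sketch of how a concrete Solver is typically driven for one training step is shown below. The solver construction and the forward/backward passes of the computation graph are placeholders (assumptions for illustration); only the Solver calls are part of the interface documented here.

    #include <functional>
    #include <memory>
    #include <nbla/solver.hpp>

    // One training step with an already-constructed solver (e.g. an Adam
    // instance) whose parameters were registered via set_parameters().
    // `forward` and `backward` stand in for running the computation graph.
    void train_step(std::shared_ptr<nbla::Solver> solver,
                    const std::function<void()> &forward,
                    const std::function<void()> &backward) {
      solver->zero_grad(); // clear gradients of all registered parameters
      forward();           // forward pass (placeholder)
      backward();          // Function::backward() calls accumulate gradients
      solver->update();    // w_{t+1} <- w_t - g_t(dw_t) for every parameter
    }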
Public Functions
-
virtual float learning_rate() = 0
Get the learning rate.
-
bool weight_decay_is_fused() const
Whether weight decay is lazily evaluated (fused) inside update_impl().
-
void zero_grad()
Zero the gradients of all registered parameters (params_).
This is usually called before running a sequence of Function::backward() calls that propagate gradients through the whole computation graph.
-
void set_parameters(const vector<pair<string, VariablePtr>> &params, bool reset = true, bool retain_state = false)
Add parameters to be optimized by the solver.
It calls set_impl().
- Parameters:
params – Pairs of parameter names and shared pointers of Variables.
reset – If true, all previously registered parameters are removed before registration.
retain_state – Try to retain state (e.g. momentum) if a parameter is overwritten. Note that this will be ignored if reset=true.
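For illustration, a hedged sketch of registering parameters follows; how the named Variables are collected from a network is not part of this interface and is shown only as an input argument.

    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>
    #include <nbla/solver.hpp>
    #include <nbla/variable.hpp>

    // Register named parameters with the solver. `net_params` stands in for
    // parameters collected from a network (illustrative only).
    void register_params(
        std::shared_ptr<nbla::Solver> solver,
        const std::vector<std::pair<std::string, nbla::VariablePtr>> &net_params) {
      // First registration: start from a clean slate.
      solver->set_parameters(net_params, /*reset=*/true, /*retain_state=*/false);

      // If the same names are registered again later (e.g. the graph was
      // rebuilt), state such as momentum can be kept for overwritten entries.
      solver->set_parameters(net_params, /*reset=*/false, /*retain_state=*/true);
    }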
-
void remove_parameters(const vector<string> &keys)
Remove previously registered parameters by keys.
-
void clear_parameters()
Clear all parameters.
-
vector<pair<string, VariablePtr>> get_parameters()
Get all parameters.
-
vector<pair<string, SolverState>> get_states()
Get all states.
-
void set_states(const vector<pair<string, SolverState>> &params)
Set states.
-
inline void clear_state(const string &key)
Clear the state of the parameter specified by key.
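A hedged sketch of snapshotting and restoring solver state (such as momentum buffers and the iteration counter) through get_states() and set_states(); the snapshot is kept in memory, and writing it to disk is left out.

    #include <memory>
    #include <nbla/solver.hpp>

    // Take an in-memory snapshot of the solver state and restore it later,
    // e.g. around an experiment that should not disturb training.
    void snapshot_and_restore(std::shared_ptr<nbla::Solver> solver) {
      auto states = solver->get_states(); // one SolverState per registered parameter
      // ... run something that modifies parameters or state ...
      solver->set_states(states);         // keys must match registered parameter names
    }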
-
void update(update_hook_type pre_callback = nullptr, update_hook_type post_callback = nullptr)
Update all parameters using the gradients stored in params_ (accumulated by backpropagation).
This internally calls update_impl() which must be implemented in a derived class.
-
void weight_decay(float decay_rate, update_hook_type pre_callback = nullptr, update_hook_type post_callback = nullptr)
Apply weight decay to the raw gradients.
If used, it must be called before running update(). This internally calls weight_decay_impl(), which must be implemented in a derived class.
It is equivalent to adding the squared sum of the weight vectors to the original loss function.
\[ L_{\rm R}({\mathbf w}) = L_{\rm orig}({\mathbf w}) + {\rm decay\_rate} \times ||{\mathbf w}||_2^2 \]
- Parameters:
decay_rate – Coefficient of weight decay.
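A minimal sketch of the intended call order follows: the decay is requested before update() in every iteration. For solvers whose weight_decay_is_fused() returns true, the call records the decay rate and the actual evaluation happens lazily inside update_impl().

    #include <memory>
    #include <nbla/solver.hpp>

    // Apply weight decay to the current gradients, then update the parameters.
    void decayed_update(std::shared_ptr<nbla::Solver> solver, float decay_rate) {
      solver->weight_decay(decay_rate); // modifies (or schedules) decay on the raw gradients
      solver->update();                 // the update uses the decayed gradients
    }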
-
void clip_grad_by_norm(float norm, update_hook_type pre_callback = nullptr, update_hook_type post_callback = nullptr)
Clip gradients by norm.
The norm is calculated for each variable individually.
-
bool check_inf_grad(update_hook_type pre_callback = nullptr, update_hook_type post_callback = nullptr)
Check if there is any inf in the gradients of the registered parameters.
-
bool check_nan_grad(update_hook_type pre_callback = nullptr, update_hook_type post_callback = nullptr)
Check if there is any nan in the gradients of the registered parameters.
-
bool check_inf_or_nan_grad(update_hook_type pre_callback = nullptr, update_hook_type post_callback = nullptr)
Check if there is any inf or nan in the gradients of the registered parameters.
-
void scale_grad(float scale, update_hook_type pre_callback = nullptr, update_hook_type post_callback = nullptr)
Scale gradients by the given factor, then increase the loss scale.
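In a loss-scaling setup these checks are typically combined with scale_grad(): the gradients are unscaled, the update is skipped when an overflow produced inf or nan, and the caller adjusts the loss scale. The control flow below is an assumption about typical usage, not part of the interface.

    #include <memory>
    #include <nbla/solver.hpp>

    // One update step under loss scaling. `loss_scale` is the factor the loss
    // was multiplied by before backpropagation (illustrative).
    bool scaled_update(std::shared_ptr<nbla::Solver> solver, float loss_scale) {
      solver->scale_grad(1.0f / loss_scale); // undo the loss scaling on gradients
      if (solver->check_inf_or_nan_grad()) {
        // Overflow detected: skip this update; the caller may lower the loss scale.
        return false;
      }
      solver->update();
      return true;
    }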
-
struct SolverState
Struct for storing both the parameter state Variables and the iteration count.
Public Members
-
unordered_map<string, VariablePtr> pstate
Parameter state maps.
-
uint32_t t
Iteration as state.
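As a small illustration, the snapshot returned by get_states() can be inspected as follows; the concrete keys inside pstate depend on the derived solver (e.g. first/second moments for Adam-like solvers) and are not assumed here.

    #include <iostream>
    #include <memory>
    #include <nbla/solver.hpp>

    // Print, for every registered parameter, the iteration counter and the
    // names of its state Variables.
    void dump_states(std::shared_ptr<nbla::Solver> solver) {
      for (const auto &kv : solver->get_states()) {
        const auto &state = kv.second; // a SolverState
        std::cout << kv.first << ": t=" << state.t << ", states:";
        for (const auto &ps : state.pstate) // name -> VariablePtr
          std::cout << " " << ps.first;
        std::cout << "\n";
      }
    }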