class nbla::Lamb

template<typename T>
class Lamb : public nbla::Solver

LAMB optimizer.

\[\begin{split}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
\hat{m} &= m_t / (1-\beta_1^t)\\
\hat{v} &= v_t / (1-\beta_2^t)\\
r &= \frac{\hat{m}}{\sqrt{\hat{v}}+\epsilon}\\
w_t &\leftarrow w_{t-1} - \eta_t \frac{\phi (\|w_{t-1}\|)}{\|r + \lambda w_{t-1}\|} \left(r + \lambda w_{t-1}\right)
\end{split}\]

where \(g_t\) denotes the gradient, \(m_t\) and \(v_t\) are the first- and second-order moments of the gradient initialized to 0 at \(t = 0\), \(\lambda\) is the decoupled weight decay rate set by the weight_decay method (lazy evaluation), \(\phi\) is a clamp scaling function defined as \(\phi(z)=\min\{\max\{z, \gamma_l\}, \gamma_u\}\), and the remaining symbols are described in the arguments below.
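
Read as code, the update above is a moment update, an optional bias correction, and a layer-wise trust-ratio step. The sketch below is illustrative only: plain C++ on std::vector rather than the nnabla kernel, and the guard against a zero update norm is an added assumption.

#include <cmath>
#include <cstddef>
#include <vector>

// One LAMB update step for a single parameter vector w with gradient g.
// m and v hold the running first- and second-order moments; t is the
// 1-based step count; lambda is the decoupled weight decay rate.
void lamb_update(std::vector<float> &w, const std::vector<float> &g,
                 std::vector<float> &m, std::vector<float> &v, int t,
                 float eta, float beta1, float beta2, float gamma_l,
                 float gamma_u, float eps, float lambda,
                 bool bias_correction) {
  const std::size_t n = w.size();
  std::vector<float> u(n); // r + lambda * w_{t-1}
  float w_norm = 0.f, u_norm = 0.f;
  for (std::size_t i = 0; i < n; ++i) {
    // Moment updates.
    m[i] = beta1 * m[i] + (1.f - beta1) * g[i];
    v[i] = beta2 * v[i] + (1.f - beta2) * g[i] * g[i];
    // Optional bias correction.
    const float m_hat = bias_correction ? m[i] / (1.f - std::pow(beta1, t)) : m[i];
    const float v_hat = bias_correction ? v[i] / (1.f - std::pow(beta2, t)) : v[i];
    const float r = m_hat / (std::sqrt(v_hat) + eps);
    u[i] = r + lambda * w[i]; // decoupled weight decay term
    w_norm += w[i] * w[i];
    u_norm += u[i] * u[i];
  }
  w_norm = std::sqrt(w_norm);
  u_norm = std::sqrt(u_norm);
  // phi(z) = min(max(z, gamma_l), gamma_u), applied to ||w_{t-1}||.
  const float phi = std::fmin(std::fmax(w_norm, gamma_l), gamma_u);
  // Layer-wise trust ratio; the zero-norm guard is an added assumption.
  const float ratio = (u_norm > 0.f) ? phi / u_norm : 1.f;
  for (std::size_t i = 0; i < n; ++i)
    w[i] -= eta * ratio * u[i];
}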

See also

See the following paper for more details: Yang You, Jing Li, Sashank Reddi, et al. Large Batch Optimization for Deep Learning: Training BERT in 76 minutes. https://arxiv.org/abs/1904.00962

Param eta:

Learning rate (\(\eta_t\)).

Param beta1:

Decay rate of the first-order moment (\(\beta_1\)).

Param beta2:

Decay rate of the second-order moment (\(\beta_2\)).

Param gamma_l:

Lower bound of the clamp scaling function \(\phi\) (\(\gamma_l\)).

Param gamma_u:

Upper bound of the clamp scaling function \(\phi\) (\(\gamma_u\)).

Param eps:

Small value added to avoid division by zero (\(\epsilon\)).

Param bias_correction:

Whether to apply bias correction when computing the moments \(\hat{m}\) and \(\hat{v}\).
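
As a usage note, the arguments above map onto solver construction roughly as follows. This is a sketch under assumptions: the factory name create_LambSolver, the header paths, the argument order, and the shown values follow nnabla's usual solver registration pattern and the parameter list above, and are not confirmed by this page.

// Usage sketch (assumed API). The factory name, headers, argument order,
// and the values shown are assumptions for illustration only.
#include <nbla/context.hpp>
#include <nbla/solver/lamb.hpp>

using namespace nbla;

int main() {
  Context ctx{{"cpu:float"}, "CpuCachedArray", "0"};
  // eta, beta1, beta2, gamma_l, gamma_u, eps, bias_correction
  auto solver = create_LambSolver(ctx, 1e-3f, 0.9f, 0.999f, 0.0f, 10.0f,
                                  1e-6f, true);
  // The decoupled weight decay rate (lambda) is set separately through the
  // weight_decay method and applied lazily, as noted above.
  solver->weight_decay(0.01f);
  return 0;
}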

Public Functions

inline virtual float learning_rate()

Get the learning rate.