class nbla::Lamb
template&lt;typename T&gt;
class Lamb : public nbla::Solver

LAMB.
\[\begin{split}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
\hat{m} &= m_t / (1 - \beta_1^t)\\
\hat{v} &= v_t / (1 - \beta_2^t)\\
r &= \frac{\hat{m}}{\sqrt{\hat{v}} + \epsilon}\\
w_t &\leftarrow w_{t-1} - \eta_t \frac{\phi(\|w_{t-1}\|)}{\|r + \lambda w_{t-1}\|} \left(r + \lambda w_{t-1}\right)
\end{split}\]

where \(g_t\) denotes the gradient, \(m_t\) and \(v_t\) are the first- and second-order momentum of the gradient (initialized to 0 at \(t = 0\)), \(\lambda\) is the decoupled weight decay rate set by the weight_decay method (lazily evaluated), \(\phi\) is a clamp scaling function defined as \(\phi(z) = \min\{\max\{z, \gamma_l\}, \gamma_u\}\), and the remaining symbols are described in the arguments below.
See also
See the paper linked below for more details.
Yang You, Jing Li, Sashank Reddi, et al. Large Batch Optimization for Deep Learning: Training BERT in 76 minutes. https://arxiv.org/abs/1904.00962
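To make the update rule above concrete, here is a minimal, self-contained C++ sketch of a single LAMB step on a flat parameter vector. It illustrates the equations only and is not the nbla implementation: the names LambState and lamb_step are hypothetical, and the decoupled weight decay \(\lambda\) is passed as an explicit argument instead of being set lazily via weight_decay.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical per-parameter state: 1st/2nd order momentum, zero-initialized.
struct LambState {
  std::vector<float> m, v;
  long t = 0; // update count
  explicit LambState(std::size_t n) : m(n, 0.f), v(n, 0.f) {}
};

// One LAMB update of weights `w` given gradient `g` (illustrative only).
void lamb_step(std::vector<float> &w, const std::vector<float> &g,
               LambState &s, float eta, float beta1, float beta2,
               float gamma_l, float gamma_u, float eps, float lambda,
               bool bias_correction) {
  ++s.t;
  std::vector<float> u(w.size()); // update direction r + lambda * w_{t-1}
  float w_norm = 0.f, u_norm = 0.f;
  for (std::size_t i = 0; i < w.size(); ++i) {
    s.m[i] = beta1 * s.m[i] + (1.f - beta1) * g[i];
    s.v[i] = beta2 * s.v[i] + (1.f - beta2) * g[i] * g[i];
    float m_hat = s.m[i], v_hat = s.v[i];
    if (bias_correction) { // \hat{m} and \hat{v} in the equations above
      m_hat /= 1.f - std::pow(beta1, static_cast<float>(s.t));
      v_hat /= 1.f - std::pow(beta2, static_cast<float>(s.t));
    }
    float r = m_hat / (std::sqrt(v_hat) + eps);
    u[i] = r + lambda * w[i];
    w_norm += w[i] * w[i];
    u_norm += u[i] * u[i];
  }
  w_norm = std::sqrt(w_norm);
  u_norm = std::sqrt(u_norm);
  // Trust-ratio scaling: phi(||w_{t-1}||) / ||r + lambda w_{t-1}||,
  // with phi(z) = min(max(z, gamma_l), gamma_u).
  float phi = std::min(std::max(w_norm, gamma_l), gamma_u);
  float scale = (u_norm > 0.f) ? phi / u_norm : 1.f;
  for (std::size_t i = 0; i < w.size(); ++i)
    w[i] -= eta * scale * u[i];
}

int main() {
  // Example hyperparameter values only; not necessarily nbla's defaults.
  std::vector<float> w = {0.5f, -0.3f}, g = {0.1f, 0.2f};
  LambState s(w.size());
  lamb_step(w, g, s, /*eta=*/1e-3f, /*beta1=*/0.9f, /*beta2=*/0.999f,
            /*gamma_l=*/0.f, /*gamma_u=*/10.f, /*eps=*/1e-6f,
            /*lambda=*/0.01f, /*bias_correction=*/true);
  return 0;
}

Note how the trust ratio rescales the raw Adam-style direction by the layer's weight norm, which is what lets LAMB keep per-layer updates well-scaled at very large batch sizes.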
- Param eta:
Learning rate (\(\eta_t\)).
- Param beta1:
Decay rate of the first-order momentum (\(\beta_1\)).
- Param beta2:
Decay rate of the second-order momentum (\(\beta_2\)).
- Param gamma_l:
Lower bound of the clamp scaling function \(\phi\) (\(\gamma_l\)).
- Param gamma_u:
Upper bound of the clamp scaling function \(\phi\) (\(\gamma_u\)); a one-line sketch of \(\phi\) follows this list.
- Param eps:
Small value to avoid division by zero (\(\epsilon\)).
- Param bias_correction:
Whether to apply bias correction to the momentum estimates \(\hat{m}\) and \(\hat{v}\).
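For concreteness, the scaling function \(\phi\) bounded by gamma_l and gamma_u is an ordinary clamp; a one-line illustrative sketch (not part of the nbla API), matching the phi computation in the step sketch above:

#include <algorithm>

// phi(z) = min(max(z, gamma_l), gamma_u); std::clamp requires C++17.
inline float phi(float z, float gamma_l, float gamma_u) {
  return std::clamp(z, gamma_l, gamma_u);
}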
Public Functions
inline virtual float learning_rate()

Get the learning rate.
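A hedged usage sketch: assuming `solver` is an already-constructed Lamb&lt;float&gt; instance (construction is not covered in this section), the current learning rate \(\eta_t\) can be queried as follows.

float lr = solver.learning_rate(); // returns the current learning rate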