mlx.optimizers.AdamW

class AdamW(learning_rate: float | Callable[[array], array], betas: List[float] = [0.9, 0.999], eps: float = 1e-08, weight_decay: float = 0.01, bias_correction: bool = False)

The AdamW optimizer [1]. We update the weights with a weight_decay (λ) value:

[1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay regularization. ICLR 2019.

\[
\begin{aligned}
m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
w_{t+1} &= w_t - \alpha \left( \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} + \lambda w_t \right)
\end{aligned}
\]
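
As a concrete illustration of this update, the sketch below applies one step in plain mlx.core code (without bias correction). The function name adamw_step and its default values are illustrative only and not part of the library API.

    import mlx.core as mx

    def adamw_step(w, g, m, v, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, wd=0.01):
        # One step of the update above, without bias correction.
        m = b1 * m + (1 - b1) * g                        # running mean of the gradient
        v = b2 * v + (1 - b2) * mx.square(g)             # running mean of the squared gradient
        w = w - lr * (m / (mx.sqrt(v) + eps) + wd * w)   # Adam step plus decoupled weight decay
        return w, m, v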
Parameters:
  • learning_rate (float or callable) – The learning rate α.

  • betas (Tuple[float, float], optional) – The coefficients (β1,β2) used for computing running averages of the gradient and its square. Default: (0.9, 0.999)

  • eps (float, optional) – The term ϵ added to the denominator to improve numerical stability. Default: 1e-8

  • weight_decay (float, optional) – The weight decay λ. Default: 0.01

  • bias_correction (bool, optional) – If set to True, bias correction is applied. Default: False
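
A brief usage sketch follows; the toy model, data, and hyperparameter values are illustrative, not prescriptive.

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim

    # Toy model and data, chosen only to illustrate the call pattern.
    model = nn.Linear(4, 1)
    x = mx.random.normal((8, 4))
    y = mx.random.normal((8, 1))

    def loss_fn(model, x, y):
        return mx.mean(mx.square(model(x) - y))

    optimizer = optim.AdamW(learning_rate=1e-3, betas=[0.9, 0.999], weight_decay=0.01)
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)  # apply one AdamW step to the model parameters
    mx.eval(model.parameters(), optimizer.state)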

Methods

__init__(learning_rate[, betas, eps, ...])

apply_single(gradient, parameter, state)

  Performs the AdamW parameter update by applying the decoupled weight decay to the parameter and then delegating to the Adam update.
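
Consistent with that description, one way to realize the update is to fold the decoupled decay into the parameter before handing off to Adam's apply_single. The subclass below is an illustrative sketch under that assumption, not the library's implementation.

    import mlx.optimizers as optim

    class SketchAdamW(optim.Adam):
        # Hypothetical subclass for illustration only; not the library's AdamW.
        def __init__(self, learning_rate, betas=[0.9, 0.999], eps=1e-8, weight_decay=0.01):
            super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
            self.weight_decay = weight_decay

        def apply_single(self, gradient, parameter, state):
            # Shrink the parameter by lr * weight_decay, then apply the Adam step.
            lr = self.learning_rate
            return super().apply_single(
                gradient, parameter * (1 - lr * self.weight_decay), state
            )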