mlx.optimizers.MultiOptimizer

class MultiOptimizer(optimizers, filters: list = [])

Wraps a list of optimizers with corresponding weight predicates/filters to make it easy to use different optimizers for different weights.

Each predicate takes the full “path” of a weight and the weight itself, and returns True if that weight should be handled by the corresponding optimizer. The last optimizer in the list is the fallback for weights that no predicate accepts, so no predicate should be given for it.
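The routing described above can be sketched in plain Python. This is an illustration rather than the library's implementation: `route_weights` is a hypothetical helper, the paths are made-up examples, and the sketch assumes that the first matching predicate wins when several would accept the same weight.

```python
def route_weights(flat_weights, filters, num_optimizers):
    """Group weight paths by the index of the optimizer that would handle them.

    Hypothetical helper for illustration only; assumes the first matching
    predicate wins and the last optimizer is the fallback.
    """
    if len(filters) != num_optimizers - 1:
        raise ValueError("expected one fewer filter than optimizers")
    groups = [[] for _ in range(num_optimizers)]
    for path, weight in flat_weights:
        for i, predicate in enumerate(filters):
            if predicate(path, weight):
                groups[i].append(path)
                break
        else:
            # No predicate accepted this weight: use the fallback optimizer.
            groups[-1].append(path)
    return groups


# Made-up paths; the weights themselves are irrelevant to these predicates.
flat = [("embed.weight", None), ("layers.0.weight", None), ("layers.0.bias", None)]
filters = [
    lambda path, w: path.startswith("embed"),
    lambda path, w: path.endswith("bias"),
]
groups = route_weights(flat, filters, num_optimizers=3)
```

Here the first optimizer would receive `embed.weight`, the second `layers.0.bias`, and the fallback `layers.0.weight`.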

Parameters:
  • optimizers (list[Optimizer]) – A list of optimizers to delegate to.

  • filters (list[Callable[[str, array], bool]]) – A list of predicates, one per optimizer except the last, so it should contain one fewer entry than optimizers.

Methods

__init__(optimizers[, filters])

apply_gradients(gradients, parameters)

Apply the gradients to the parameters and return the updated parameters.

init(parameters)

Initialize the optimizer's state.