mlx.optimizers.join_schedules

join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable

Join multiple schedules to create a new schedule.

Parameters:
  • schedules (list(Callable)) – A list of schedules. The first schedule receives the global step count; schedule \(i+1\) receives the number of steps elapsed since the \(i\)-th boundary.

  • boundaries (list(int)) – A list of len(schedules) - 1 integers giving the step counts at which to transition from one schedule to the next.
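
The step-offset rule described in the parameters can be illustrated with a minimal pure-Python sketch. This is not the MLX implementation: `join_schedules_sketch`, `warmup`, and `cosine` are hypothetical stand-ins that operate on plain floats, and the sketch assumes that a later schedule takes over at the step equal to its boundary.

```python
import math
from typing import Callable, List


def join_schedules_sketch(
    schedules: List[Callable[[int], float]], boundaries: List[int]
) -> Callable[[int], float]:
    """Hypothetical pure-Python illustration, not the MLX implementation."""
    if len(schedules) != len(boundaries) + 1:
        raise ValueError("Expected len(boundaries) == len(schedules) - 1.")

    def schedule(step: int) -> float:
        # The first schedule sees the global step count.
        value = schedules[0](step)
        # Each later schedule takes over once its boundary is reached
        # (assumed: at step == boundary) and sees the number of steps
        # elapsed since that boundary.
        for boundary, later in zip(boundaries, schedules[1:]):
            if step >= boundary:
                value = later(step - boundary)
        return value

    return schedule


# Stand-in schedules (assumptions for illustration only).
def warmup(step: int) -> float:
    # Linear ramp from 0 to 0.1 over 10 steps.
    return min(step, 10) * (1e-1 / 10)


def cosine(step: int) -> float:
    # Cosine decay from 0.1 toward 0 over 200 steps.
    return 1e-1 * 0.5 * (1.0 + math.cos(math.pi * min(step, 200) / 200))


lr = join_schedules_sketch([warmup, cosine], [10])
for step in (0, 5, 10, 11):
    print(step, lr(step))
```

With one boundary at 10, steps 0 through 9 are handled by `warmup`, and step 11 is passed to `cosine` as local step 1.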

Example

>>> import mlx.optimizers as optim
>>> warmup = optim.linear_schedule(0, 1e-1, steps=10)
>>> cosine = optim.cosine_decay(1e-1, 200)
>>> lr_schedule = optim.join_schedules([warmup, cosine], [10])
>>> optimizer = optim.Adam(learning_rate=lr_schedule)
>>> optimizer.learning_rate
array(0.0, dtype=float32)
>>> for _ in range(12): optimizer.update({}, {})
...
>>> optimizer.learning_rate
array(0.0999938, dtype=float32)
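
The final value can be checked by hand. The sketch below assumes that cosine_decay follows the standard form init * 0.5 * (1 + cos(pi * step / decay_steps)) and that the learning rate shown after 12 updates is the one evaluated for the most recent update, that is, global step 11 when counting from 0. Both are inferences from the printed number rather than statements made in this section.

```python
import math

init, decay_steps = 1e-1, 200  # arguments passed to cosine_decay in the example
boundary = 10                  # the single boundary passed to join_schedules
global_step = 11               # assumed: step of the most recent of the 12 updates

# The second schedule receives the number of steps since the boundary.
local_step = global_step - boundary
expected = init * 0.5 * (1.0 + math.cos(math.pi * local_step / decay_steps))
print(f"{expected:.7f}")  # prints 0.0999938
```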