mlx.nn.layers.distributed.shard_inplace

shard_inplace(module: Module, sharding: str | Callable, *, segments: int | list = 1, group: Group | None = None)

Shard a module in-place by updating its parameter dictionary with the sharded parameter dictionary.
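
The in-place update can be pictured with plain Python containers: walk the nested parameter tree, build a path for each leaf, ask a sharding function what to keep, and write the result back into the same containers. This is a minimal sketch, not MLX's implementation. The helpers `shard_tree_inplace` and `keep_rank_rows` are hypothetical, tuples stand in for arrays, the dotted path format is an assumption, and the rank and group size are fixed instead of coming from a distributed group.

```python
def shard_tree_inplace(tree, shard_leaf, prefix=""):
    """Replace each leaf of a nested dict/list tree with shard_leaf(path, leaf), in place.

    Dicts and lists are treated as containers; anything else (tuples here) is a leaf.
    """
    keys = list(tree.keys()) if isinstance(tree, dict) else list(range(len(tree)))
    for key in keys:
        path = f"{prefix}.{key}" if prefix else str(key)
        value = tree[key]
        if isinstance(value, (dict, list)):
            shard_tree_inplace(value, shard_leaf, path)  # recurse into a sub-module
        else:
            tree[key] = shard_leaf(path, value)  # overwrite the leaf in its container


seen_paths = []


def keep_rank_rows(path, weight, rank=1, size=2):
    """Hypothetical sharding function: keep this rank's contiguous block along axis 0."""
    seen_paths.append(path)
    step = len(weight) // size
    return weight[rank * step:(rank + 1) * step]


params = {"layers": [{"weight": ((1, 2), (3, 4), (5, 6), (7, 8)), "bias": (10, 20, 30, 40)}]}
layer = params["layers"][0]
shard_tree_inplace(params, keep_rank_rows)
print(params)      # {'layers': [{'weight': ((5, 6), (7, 8)), 'bias': (30, 40)}]}
print(seen_paths)  # ['layers.0.weight', 'layers.0.bias']
```

The containers themselves are reused (`params["layers"][0]` is still the same dict), which is what makes the update in-place rather than a rebuilt tree.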

The sharding argument can be any callable that, given the path and the weight, returns the sharding axis and, optionally, the segments that make up the unsharded weight. For instance, if the weight is a fused QKV matrix, the segments should be 3.
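
Segments matter because a fused weight must be split per segment, not as one block. The sketch below uses a list of row labels as a stand-in for a fused QKV weight sharded along axis 0 over 2 ranks. It assumes segments are equal-sized and that each segment divides evenly by the group size; `split_even` and `shard_rows` are hypothetical helpers, not MLX functions.

```python
def split_even(seq, n):
    """Split a sequence into n equal contiguous chunks (length must divide evenly)."""
    if len(seq) % n:
        raise ValueError(f"length {len(seq)} is not divisible by {n}")
    step = len(seq) // n
    return [seq[i * step:(i + 1) * step] for i in range(n)]


def shard_rows(weight, rank, size, segments=1):
    """Return this rank's rows of `weight`, splitting each segment separately.

    The fused weight is first cut into `segments` pieces (e.g. Q, K, V), each
    piece is cut into `size` shards, and this rank's shards are concatenated.
    """
    shard = []
    for part in split_even(weight, segments):
        shard.extend(split_even(part, size)[rank])
    return shard


# A fused QKV weight with 2 rows per projection (rows named for readability).
qkv = ["q0", "q1", "k0", "k1", "v0", "v1"]

print(shard_rows(qkv, rank=0, size=2))              # ['q0', 'q1', 'k0']
print(shard_rows(qkv, rank=0, size=2, segments=3))  # ['q0', 'k0', 'v0']
```

With one segment, rank 0 receives all of Q and part of K; with 3 segments, every rank receives an equal slice of each of Q, K and V, which is what a sharded attention layer expects.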

Note

The module's type and forward pass do not change; only its parameters are replaced. For distributed communication to happen, the module must natively support it, and that support must be enabled.
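
Why the note matters can be seen with plain arithmetic. If a linear layer's weight is split along its input axis, each rank only computes a partial result, and the full output exists only after the partials are summed across ranks. An unmodified module never performs that sum. The sketch below is pure Python with a hypothetical 2-rank split, and the final summation stands in for an all-reduce; it is an illustration, not MLX code.

```python
def matvec(w, x):
    """Dense matrix-vector product for a row-major list-of-lists matrix."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


W = [[1, 2, 3, 4],
     [5, 6, 7, 8]]  # shape (out=2, in=4)
x = [1, 1, 1, 1]
size = 2
step = len(x) // size

full = matvec(W, x)  # the unsharded result: [10, 26]

# Shard the input (last) axis across ranks; each rank sees only its columns.
partials = []
for rank in range(size):
    cols = slice(rank * step, (rank + 1) * step)
    w_r = [row[cols] for row in W]
    partials.append(matvec(w_r, x[cols]))
# partials == [[3, 11], [7, 15]]: neither rank alone holds the answer.

reduced = [sum(vals) for vals in zip(*partials)]  # what an all-reduce (sum) computes
print(reduced)  # [10, 26]
```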

Parameters:
  • module (Module) – The parameters of this module will be sharded in-place.

  • sharding (str or callable) – Either “all-to-sharded” or “sharded-to-all”, or a callable that returns the sharding axis and, optionally, the segments.

  • segments (int or list) – The segments to use if sharding is a string. Default: 1.

  • group (Group) – The distributed group to shard across. If not set, the global group will be used. Default: None.
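
The two string options can be read as the standard tensor-parallel split of a linear layer whose weight is laid out as (out, in), as in mlx.nn.Linear. The sketch below writes both as callables of the kind the sharding parameter accepts, plus a hypothetical custom one for a fused QKV projection. It is an interpretation, not MLX's source: the axis choices follow from the (out, in) layout, returning None to leave a parameter unsharded is an assumption, and `FakeArray` and the parameter name `qkv_proj` are stand-ins.

```python
from collections import namedtuple

# Stand-in for an array: the callables only read .shape (an mx.array has one too).
FakeArray = namedtuple("FakeArray", ["shape"])


def all_to_sharded(segments=1):
    """Split output features: the second-to-last axis of a (..., out, in) weight, axis 0 of a bias."""
    def predicate(path, weight):
        return max(len(weight.shape) - 2, 0), segments
    return predicate


def sharded_to_all(segments=1):
    """Split input features (last axis); keep the bias whole, since it is added after the sum."""
    def predicate(path, weight):
        if path.endswith("bias"):
            return None  # assumption: None means "do not shard this parameter"
        return -1, segments
    return predicate


def fused_qkv_sharding(path, weight):
    """Hypothetical custom callable: fused QKV weights and biases hold 3 segments."""
    segments = 3 if "qkv_proj" in path else 1  # "qkv_proj" is a hypothetical name
    return all_to_sharded(segments)(path, weight)


print(all_to_sharded()("q_proj.weight", FakeArray((8, 4))))            # (0, 1)
print(sharded_to_all()("o_proj.bias", FakeArray((4,))))                # None
print(fused_qkv_sharding("attn.qkv_proj.weight", FakeArray((12, 4))))  # (0, 3)
```

With MLX available, a callable like `fused_qkv_sharding` would be passed in place of a string, for example `shard_inplace(attention, fused_qkv_sharding)` where `attention` is a hypothetical module; how closely these predicates match MLX's own handling of the two strings should be checked against the library source.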