mlx.nn.layers.distributed.shard_inplace
- shard_inplace(module: Module, sharding: str | Callable, *, segments: int | list = 1, group: Group | None = None)
Shard a module in-place by updating its parameter dictionary with the sharded parameter dictionary.
The sharding argument can be any callable that, given the path and the weight, returns the sharding axis and optionally also the segments that comprise the unsharded weight. For instance, if the weight is a fused QKV matrix, the segments should be 3.
Note
The module itself doesn’t change, so for distributed communication to happen the module needs to natively support it and have it enabled.
- Parameters:
module (Module) – The parameters of this module will be sharded in-place.
sharding (str or callable) – Either “all-to-sharded” or “sharded-to-all”, or a callable that returns the sharding axis and, optionally, the segments.
segments (int or list) – The segments to use if sharding is a string. Default: 1.
group (Group) – The distributed group to shard across. If not set, the global group will be used. Default: None.
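As a rough illustration of the callable form of sharding described above, the sketch below returns a sharding axis and a segment count for each parameter. The function name qkv_aware_sharding, the "qkv_proj" path fragment, and the choice of axis 0 are assumptions made for this example and are not taken from the MLX source. The weight argument is accepted but unused here.

```python
def qkv_aware_sharding(path, weight):
    """Hypothetical sharding callable: return (axis, segments) for one parameter.

    `path` is the parameter path and `weight` is the parameter itself;
    this sketch decides based on the path alone.
    """
    # Assumption: fused QKV projections have "qkv_proj" in their parameter path.
    if "qkv_proj" in path:
        # The unsharded weight is made of three segments (Q, K and V).
        return 0, 3
    # Every other parameter is treated as a single segment.
    return 0, 1


# Intended use, assuming MLX is installed and `model` is an mlx.nn.Module:
#   from mlx.nn.layers.distributed import shard_inplace
#   shard_inplace(model, qkv_aware_sharding)
```

When sharding is passed as a string instead, the same segment count would be supplied through the segments keyword argument.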