mlx.nn.layers.distributed.shard_linear

shard_linear(module: Module, sharding: str, *, segments: int | list = 1, group: Group | None = None)

Create a new linear layer whose parameters are sharded across the distributed group and which performs the required distributed communication in either the forward or the backward pass.
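As a rough mental model, the two kinds of sharding can be simulated in a single process with NumPy. This sketch is not MLX code. It assumes that "all-to-sharded" splits the weight along its output dimension and that "sharded-to-all" splits it along its input dimension and sums the partial results, as in the usual column-parallel and row-parallel linear layers. The process count N and the array shapes are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 2                          # hypothetical number of processes in the group
x = rng.normal(size=(4, 8))    # a batch of 4 inputs with 8 features
W = rng.normal(size=(6, 8))    # full weight, (out, in), so that y = x @ W.T

# "all-to-sharded" (assumed semantics): every process sees the full input
# and holds a slice of the output rows, so it produces a slice of the output.
# No communication is needed in the forward pass; the gradient with respect
# to x would have to be summed across processes in the backward pass.
W_rows = np.split(W, N, axis=0)
y_shards = [x @ w.T for w in W_rows]            # each shard is (4, 3)
y_from_rows = np.concatenate(y_shards, axis=-1)

# "sharded-to-all" (assumed semantics): every process sees a slice of the
# input features and holds the matching slice of the weight columns. The
# partial outputs are summed (an all-reduce) in the forward pass, so every
# process ends up with the full result.
W_cols = np.split(W, N, axis=1)
x_shards = np.split(x, N, axis=1)
partials = [xs @ w.T for xs, w in zip(x_shards, W_cols)]  # each is (4, 6)
y_from_cols = sum(partials)

# Both layouts reproduce the unsharded linear layer.
assert np.allclose(y_from_rows, x @ W.T)
assert np.allclose(y_from_cols, x @ W.T)
```

Under these assumptions, an all-to-sharded layer followed by a sharded-to-all layer needs only one reduction in the forward pass, which is why the two are commonly used as a pair.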

Note

Unlike shard_inplace, this function does not modify the original layer; it returns a new layer.

Parameters:
  • module (Module) – The linear layer to be sharded.

  • sharding (str) – Either “all-to-sharded” or “sharded-to-all”; selects the type of sharding to perform.

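  • segments (int or list) – The segments that make up the unsharded weight, each of which is sharded separately (for example, 3 for a fused QKV weight). Default: 1.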

  • group (Group) – The distributed group to shard across. If not set, the global group will be used. Default: None.