mlx.nn.AllToShardedLinear


class AllToShardedLinear(input_dims: int, output_dims: int, bias: bool = True, group: Group | None = None)

Each member of the group applies part of the affine transformation so that the result is sharded across the group: every member receives the full input and computes only its own slice of the output features.

The gradients with respect to the input are automatically aggregated (summed) across all members of the group.
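The pure-Python sketch below simulates this behaviour on a single machine. It assumes a group of `n` members where member `r` owns a contiguous block of `output_dims // n` rows of the weight and bias; all helper names are hypothetical and are not part of MLX. It checks two things: concatenating the members' output shards reproduces the unsharded layer, and summing each member's partial input gradient reproduces the unsharded input gradient, which is why that aggregation is needed.

```python
# Single-machine simulation of an all-to-sharded linear layer.
# Hypothetical helpers for illustration only, not the MLX implementation.
# Assumption: member r owns rows r*k .. (r+1)*k - 1, with k = output_dims // n.

def matvec(W, x):
    """Compute W @ x for a list-of-rows matrix W and a vector x."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in W]


def matTvec(W, g):
    """Compute W.T @ g: the gradient of (W @ x) w.r.t. x for upstream gradient g."""
    return [sum(W[i][j] * g[i] for i in range(len(W))) for j in range(len(W[0]))]


def shard_rows(rows, n, r):
    """Return member r's contiguous block out of n equal blocks of rows."""
    k = len(rows) // n
    return rows[r * k:(r + 1) * k]


def sharded_forward(W, b, x, n):
    """Each member sees the full input x but computes only its slice of outputs."""
    shards = []
    for r in range(n):
        W_r, b_r = shard_rows(W, n, r), shard_rows(b, n, r)
        shards.append([y + c for y, c in zip(matvec(W_r, x), b_r)])
    return shards  # shards[r] is the output shard held by member r


def sharded_input_grad(W, g, n):
    """Sum every member's partial input gradient, as an all-reduce (sum) would."""
    total = [0.0] * len(W[0])
    for r in range(n):
        partial = matTvec(shard_rows(W, n, r), shard_rows(g, n, r))
        total = [t + p for t, p in zip(total, partial)]
    return total


W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]  # output_dims=4, input_dims=2
b = [0.5, -0.5, 1.0, -1.0]
x = [1.0, -1.0]
g = [1.0, 0.0, 2.0, -1.0]  # upstream gradient w.r.t. the full output

shards = sharded_forward(W, b, x, n=2)
full = [y + c for y, c in zip(matvec(W, x), b)]
print(shards)                           # [[-0.5, -1.5], [0.0, -2.0]]
print(sharded_input_grad(W, g, n=2))    # [4.0, 6.0]
print(matTvec(W, g))                    # [4.0, 6.0], the unsharded gradient
```

No single member holds the complete input gradient, since each only sees the part of the upstream gradient that belongs to its output shard; the sum over members is what makes the result match the unsharded layer.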

Parameters:
  • input_dims (int) – The dimensionality of the input features.

  • output_dims (int) – The dimensionality of the output features.

  • bias (bool, optional) – If set to False, the layer will not use a bias. Default is True.

  • group (mx.distributed.Group, optional) – The sharding will happen across this group. If not set then the global group is used. Default is None.
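A minimal sketch of how these arguments might translate into the parameters each member stores. It assumes that each member keeps `output_dims // N` rows of an `(output_dims, input_dims)` weight (the layout used by `mlx.nn.Linear`), where `N` is the size of `group`, and that `output_dims` must therefore be divisible by `N`. The helper `local_parameter_shapes` and its `group_size` argument are hypothetical and only illustrate the bookkeeping.

```python
def local_parameter_shapes(input_dims, output_dims, bias=True, group_size=1):
    """Hypothetical helper: per-member parameter shapes for an all-to-sharded linear.

    Assumes the output features are split evenly across `group_size` members.
    """
    if output_dims % group_size != 0:
        raise ValueError(
            f"output_dims ({output_dims}) must be divisible by the group size ({group_size})"
        )
    local_out = output_dims // group_size
    shapes = {"weight": (local_out, input_dims)}  # full input, sliced output
    if bias:
        shapes["bias"] = (local_out,)
    return shapes


print(local_parameter_shapes(512, 2048, group_size=4))
# {'weight': (512, 512), 'bias': (512,)}
print(local_parameter_shapes(512, 2048, bias=False, group_size=8))
# {'weight': (256, 512)}
```

The input dimension is never divided, because every member consumes the full input; only the output dimension shrinks as the group grows.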

Methods

from_linear(linear_layer, *[, segments, group])
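As a rough illustration of what converting an existing layer with from_linear involves, the sketch below cuts a full weight and bias into the block that a given member would keep. It assumes contiguous blocks of `output_dims // group_size` output rows per member and does not model the `segments` option; `slice_linear` is a hypothetical helper, not part of MLX.

```python
def slice_linear(weight, bias, rank, group_size):
    """Hypothetical helper: return the (weight, bias) block kept by member `rank`.

    `weight` is a list of output rows (output_dims x input_dims); `bias` may be None.
    Assumes each member keeps a contiguous block of output_dims // group_size rows.
    """
    output_dims = len(weight)
    if output_dims % group_size != 0:
        raise ValueError("output_dims must be divisible by the group size")
    k = output_dims // group_size
    rows = slice(rank * k, (rank + 1) * k)
    local_bias = None if bias is None else bias[rows]
    return weight[rows], local_bias


weight = [[0, 1], [2, 3], [4, 5], [6, 7]]  # output_dims=4, input_dims=2
bias = [10, 11, 12, 13]
for rank in range(2):
    print(rank, slice_linear(weight, bias, rank, group_size=2))
# 0 ([[0, 1], [2, 3]], [10, 11])
# 1 ([[4, 5], [6, 7]], [12, 13])
```

Stacking the blocks from every rank in order recovers the original weight and bias, so no parameters are lost or duplicated by the conversion.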