mlx.nn.ShardedToAllLinear


class ShardedToAllLinear(input_dims: int, output_dims: int, bias: bool = True, group: Group | None = None)

Each member of the group applies part of the affine transformation, and the partial results are then aggregated across the group.

After this layer, all nodes hold exactly the same result.
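The forward pass can be simulated on a single machine with NumPy. This is a sketch under stated assumptions, not MLX's implementation. It assumes:

- The weight is split into equal contiguous slices along its input dimension.
- Each rank holds the matching slice of the input features.
- "Aggregates" means an elementwise sum of the partial products.
- The bias is added once, after aggregation.

The function name sharded_to_all_forward is hypothetical.

```python
import numpy as np


def sharded_to_all_forward(x, weight, bias, group_size):
    """Hypothetical single-process simulation of a sharded-to-all linear layer.

    x: (batch, input_dims), weight: (output_dims, input_dims),
    bias: (output_dims,) or None. Returns one output per simulated rank.
    """
    input_dims = weight.shape[1]
    if input_dims % group_size != 0:
        raise ValueError("input_dims must be divisible by group_size")
    # Assumption: rank r holds the r-th contiguous slice of the input
    # features and the matching block of weight columns.
    x_shards = np.split(x, group_size, axis=-1)
    w_shards = np.split(weight, group_size, axis=1)
    # Each rank computes only its partial product.
    partials = [xs @ ws.T for xs, ws in zip(x_shards, w_shards)]
    # Assumed aggregation step: an elementwise sum over all ranks.
    total = sum(partials)
    if bias is not None:
        total = total + bias  # bias applied once, after aggregation
    # Every rank receives the same aggregated result.
    return [total.copy() for _ in range(group_size)]


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))
W = rng.standard_normal((3, 8))
b = rng.standard_normal(3)
outs = sharded_to_all_forward(x, W, b, group_size=4)
print(all(np.allclose(o, x @ W.T + b) for o in outs))  # True
```

The product x @ W.T decomposes into a sum of per-slice products. Summing the partial results therefore reproduces the unsharded layer, which is why every node ends up with the same output.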

ShardedToAllLinear provides a classmethod, from_linear(), that converts Linear layers into sharded ShardedToAllLinear layers.
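Here is a rough illustration of what converting an existing Linear layer might involve. The hypothetical helper shard_linear_weights selects the part of a full weight matrix that one rank would keep. It makes two assumptions:

- The weight is split into equal contiguous slices along the input dimension, so input_dims must be divisible by the group size.
- The bias is kept whole on every rank.

MLX's actual from_linear() may differ in its details, including how the segments argument is handled.

```python
import numpy as np


def shard_linear_weights(weight, bias, rank, group_size):
    """Hypothetical helper: the slice of a full Linear layer that `rank` keeps.

    weight: (output_dims, input_dims); bias: (output_dims,) or None.
    """
    input_dims = weight.shape[1]
    if input_dims % group_size != 0:
        raise ValueError(
            f"input_dims={input_dims} is not divisible by group_size={group_size}"
        )
    step = input_dims // group_size
    # Assumption: each rank keeps one contiguous block of input columns.
    weight_shard = weight[:, rank * step:(rank + 1) * step]
    # Assumption: the bias acts on the aggregated output, so it is not split.
    return weight_shard, bias


W = np.arange(24, dtype=np.float32).reshape(3, 8)
b = np.zeros(3, dtype=np.float32)
shards = [shard_linear_weights(W, b, r, 4)[0] for r in range(4)]
print([s.shape for s in shards])  # [(3, 2), (3, 2), (3, 2), (3, 2)]
```

Concatenating the shards along the input axis recovers the original weight, so under these assumptions no information is lost by the conversion.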

Parameters:
  • input_dims (int) – The dimensionality of the input features.

  • output_dims (int) – The dimensionality of the output features.

  • bias (bool, optional) – If set to False, the layer will not use a bias. Default is True.

  • group (mx.distributed.Group, optional) – The group across which the layer is sharded. If not set, the global group is used. Default is None.

Methods

from_linear(linear_layer, *[, segments, group])