mlx.nn.ShardedToAllLinear
- class ShardedToAllLinear(input_dims: int, output_dims: int, bias: bool = True, group: Group | None = None)
Each member of the group applies its part of the affine transformation, and the partial results are then aggregated (summed) across the group.
All nodes have exactly the same result after this layer.
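The computation can be illustrated with a minimal NumPy sketch that simulates the whole group inside one process. It assumes each node holds a contiguous slice of the weight's input columns and receives the matching slice of the input features. The Python sum over the partial products stands in for the cross-node reduction. The function name sharded_to_all_forward is hypothetical, and this is not MLX's implementation.

```python
import numpy as np

def sharded_to_all_forward(x, weight, bias, num_nodes):
    """Simulate a sharded-to-all linear layer on `num_nodes` nodes (hypothetical helper).

    Assumption: node i holds weight columns [i*k, (i+1)*k) and receives the
    matching slice of the input features, with k = input_dims // num_nodes.
    """
    input_dims = weight.shape[1]
    if input_dims % num_nodes != 0:
        raise ValueError("input_dims must be divisible by the number of nodes")
    k = input_dims // num_nodes
    # Each node computes a partial product from its own slice of the inputs.
    partials = [
        x[..., i * k:(i + 1) * k] @ weight[:, i * k:(i + 1) * k].T
        for i in range(num_nodes)
    ]
    # Summing the partials stands in for a sum-reduction across the group;
    # afterwards every node would hold this same array.
    out = sum(partials)
    if bias is not None:
        out = out + bias  # added once, after the reduction
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))  # batch of 2, input_dims = 8
w = rng.standard_normal((3, 8))  # output_dims = 3
b = rng.standard_normal(3)
y = sharded_to_all_forward(x, w, b, num_nodes=4)
print(np.allclose(y, x @ w.T + b))  # True
```

The check at the end shows why every node ends up with the same output: the sum of the per-slice products equals the full matrix product, up to floating-point rounding.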
ShardedToAllLinear provides a classmethod from_linear() to convert linear layers to sharded ShardedToAllLinear layers.
- Parameters:
input_dims (int) – The dimensionality of the input features.
output_dims (int) – The dimensionality of the output features.
bias (bool, optional) – If set to False, the layer will not use a bias. Default is True.
group (mx.distributed.Group, optional) – The sharding will happen across this group. If not set, the global group is used. Default is None.
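How input_dims, output_dims, bias and the group size translate into per-node storage can be sketched with a small pure-Python helper. It assumes the weight follows nn.Linear's (output_dims, input_dims) layout with the input axis split evenly across the N members of the group, and that the bias is kept whole on every node because it is added after the results are aggregated. The helper local_param_shapes is hypothetical, not part of MLX.

```python
def local_param_shapes(input_dims, output_dims, group_size, bias=True):
    """Per-node parameter shapes of a sharded-to-all linear layer (hypothetical helper).

    Assumes the weight uses an (output_dims, input_dims) layout with the input
    axis split evenly across the group, and that the bias is replicated.
    """
    if input_dims % group_size != 0:
        raise ValueError(
            f"input_dims={input_dims} is not divisible by group_size={group_size}"
        )
    shapes = {"weight": (output_dims, input_dims // group_size)}
    if bias:
        shapes["bias"] = (output_dims,)
    return shapes

print(local_param_shapes(1024, 512, group_size=4))
# {'weight': (512, 256), 'bias': (512,)}
print(local_param_shapes(1024, 512, group_size=4, bias=False))
# {'weight': (512, 256)}
```

With group_size=1 (a single process) the local weight has the full (output_dims, input_dims) shape, the same as an unsharded linear layer.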
Methods
from_linear(linear_layer, *[, segments, group])
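The kind of conversion from_linear() performs can be sketched as slicing a full layer's parameters per rank. This is a minimal NumPy illustration under stated assumptions, not MLX's implementation: it assumes the weight is split evenly along its input (last) axis, that the bias is replicated on every rank because it is added once after aggregation, and it ignores the segments argument entirely. The helper shard_linear_params is hypothetical.

```python
import numpy as np

def shard_linear_params(weight, bias, rank, group_size):
    """Return the parameters that `rank` would keep (hypothetical helper).

    Assumption: the (output_dims, input_dims) weight is split along its input
    axis and the bias is copied unchanged to every rank. `segments` is ignored.
    """
    input_dims = weight.shape[-1]
    if input_dims % group_size != 0:
        raise ValueError("input_dims must be divisible by the group size")
    k = input_dims // group_size
    local_weight = weight[:, rank * k:(rank + 1) * k]
    local_bias = None if bias is None else bias.copy()
    return local_weight, local_bias

w = np.arange(24, dtype=np.float32).reshape(3, 8)  # full weight of a (8 -> 3) layer
b = np.ones(3, dtype=np.float32)
shards = [shard_linear_params(w, b, r, group_size=2) for r in range(2)]
print(shards[0][0].shape)  # (3, 4)
# Concatenating the rank-local weights along the input axis recovers the original.
print(np.array_equal(np.concatenate([s[0] for s in shards], axis=1), w))  # True
```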