mlx.nn.AllToShardedLinear
- class AllToShardedLinear(input_dims: int, output_dims: int, bias: bool = True, group: Group | None = None)
Each member of the group applies part of the affine transformation such that the result is sharded across the group.
The gradients are automatically aggregated from each member of the group.
- Parameters:
input_dims (int) – The dimensionality of the input features
output_dims (int) – The dimensionality of the output features
bias (bool, optional) – If set to False then the layer will not use a bias. Default is True.
group (mx.distributed.Group, optional) – The sharding will happen across this group. If not set then the global group is used. Default is None.
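To illustrate what "the result is sharded across the group" means, here is a minimal conceptual sketch in plain Python. It does not use MLX and is not the library's implementation. It assumes that each member holds a contiguous block of output rows of the weight and bias, and that output_dims is divisible by the group size. The helper names `linear` and `shard_rows`, and the variables `rank` and `world_size`, are hypothetical and exist only for this example.

```python
def linear(x, weight, bias):
    """Compute y = W x + b for one input vector x (weight is a list of rows)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def shard_rows(weight, bias, rank, world_size):
    """Return the block of output rows that `rank` would own (assumed layout)."""
    output_dims = len(weight)
    if output_dims % world_size != 0:
        raise ValueError("this sketch assumes output_dims is divisible by world_size")
    step = output_dims // world_size
    start = rank * step
    return weight[start:start + step], bias[start:start + step]


# A small affine layer: 3 input features, 4 output features, 2 group members.
input_dims, output_dims, world_size = 3, 4, 2
weight = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [3.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
bias = [0.5, -1.0, 0.0, 2.0]
x = [1.0, 2.0, 3.0]

# Reference: the full, unsharded affine transformation.
full = linear(x, weight, bias)

# Each simulated member sees the whole input but computes only its output slice.
shards = [linear(x, *shard_rows(weight, bias, r, world_size)) for r in range(world_size)]

# Concatenating the slices in rank order recovers the full output.
gathered = [value for shard in shards for value in shard]
```

In this sketch every member receives the full input ("all") and produces output_dims // world_size features ("sharded"). No communication is needed in the forward pass under this layout, because each output feature depends only on its own weight row and bias entry.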
Methods
from_linear(linear_layer, *[, segments, group])