mlx.nn.losses.cross_entropy#
- class cross_entropy(logits: array, targets: array, weights: array | None = None, axis: int = -1, label_smoothing: float = 0.0, reduction: Literal['none', 'mean', 'sum'] = 'none')#
Computes the cross entropy loss.
- Parameters:
logits (array) – The unnormalized logits.
targets (array) – The ground truth values. These can be class indices or probabilities for each class. If the targets are class indices, then the targets shape should match the logits shape with the axis dimension removed. If the targets are probabilities (or one-hot encoded), then the targets shape should be the same as the logits shape.
weights (array, optional) – Optional weights for each target. Default: None.
axis (int, optional) – The axis over which to compute softmax. Default: -1.
label_smoothing (float, optional) – Label smoothing factor. Default: 0.
reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'none'.
- Returns:
The computed cross entropy loss.
- Return type:
array
Examples
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>>
>>> # Class indices as targets
>>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
>>> targets = mx.array([0, 1])
>>> nn.losses.cross_entropy(logits, targets)
array([0.0485873, 0.0485873], dtype=float32)
>>>
>>> # Probabilities (or one-hot vectors) as targets
>>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
>>> targets = mx.array([[0.9, 0.1], [0.1, 0.9]])
>>> nn.losses.cross_entropy(logits, targets)
array([0.348587, 0.348587], dtype=float32)
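The examples above use the default 'none' reduction. As a minimal sketch of the remaining parameters in the signature above (not part of the upstream examples, and with output values omitted): weights scales each per-target loss, reduction='mean' averages the result into a scalar, and label_smoothing softens the one-hot targets.
>>> # Per-target weights combined with mean reduction
>>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
>>> targets = mx.array([0, 1])
>>> weights = mx.array([1.0, 0.5])
>>> loss = nn.losses.cross_entropy(logits, targets, weights=weights, reduction='mean')
>>>
>>> # Label smoothing spreads a small amount of probability mass over non-target classes
>>> loss = nn.losses.cross_entropy(logits, targets, label_smoothing=0.1)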