mlx.nn.losses.cross_entropy
- cross_entropy(logits: array, targets: array, weights: array | None = None, axis: int = -1, label_smoothing: float = 0.0, reduction: Literal['none', 'mean', 'sum'] = 'none')
Computes the cross entropy loss.
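With class-index targets and no weights, smoothing, or reduction, the per-example value is the standard negative log-softmax likelihood. Stated here for reference (this is a restatement of the usual definition, not quoted from the MLX source):

$$\ell_i = \log\sum_j e^{z_{ij}} - z_{i,\,t_i}$$

where $z_i$ is the row of logits for example $i$ and $t_i$ is its target class index. This matches the first example below: $\log(e^{2} + e^{-1}) - 2 \approx 0.0485873$.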
- Parameters:
  - logits (array) – The unnormalized logits.
  - targets (array) – The ground truth values. These can be class indices or probabilities for each class. If the targets are class indices, then the targets shape should match the logits shape with the axis dimension removed. If the targets are probabilities (or one-hot encoded), then the targets shape should be the same as the logits shape.
  - weights (array, optional) – Optional weights for each target. Default: None.
  - axis (int, optional) – The axis over which to compute softmax. Default: -1.
  - label_smoothing (float, optional) – Label smoothing factor. Default: 0.
  - reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'none'. (A usage sketch for weights and reduction follows the examples below.)
- Returns:
The computed cross entropy loss.
- Return type:
array
Examples
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>>
>>> # Class indices as targets
>>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
>>> targets = mx.array([0, 1])
>>> nn.losses.cross_entropy(logits, targets)
array([0.0485873, 0.0485873], dtype=float32)
>>>
>>> # Probabilities (or one-hot vectors) as targets
>>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
>>> targets = mx.array([[0.9, 0.1], [0.1, 0.9]])
>>> nn.losses.cross_entropy(logits, targets)
array([0.348587, 0.348587], dtype=float32)
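The weights, reduction, and label_smoothing parameters compose with the calls above. A minimal sketch continuing the class-index example; the two printed outputs are hand-derived (the per-sample loss 0.0485873 scaled by each weight, and the unweighted mean), so the printed precision may differ slightly from an actual run:

>>> # Class indices again, now with per-target weights
>>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
>>> targets = mx.array([0, 1])
>>> weights = mx.array([1.0, 2.0])
>>> nn.losses.cross_entropy(logits, targets, weights=weights)
array([0.0485873, 0.0971746], dtype=float32)
>>>
>>> # reduction='mean' (or 'sum') collapses the per-sample losses to a scalar
>>> nn.losses.cross_entropy(logits, targets, reduction='mean')
array(0.0485873, dtype=float32)
>>>
>>> # label_smoothing softens the hard targets before the loss is computed
>>> smoothed = nn.losses.cross_entropy(logits, targets, label_smoothing=0.1)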