mlx.nn.losses.cross_entropy
- cross_entropy(logits: array, targets: array, weights: array | None = None, axis: int = -1, label_smoothing: float = 0.0, reduction: Literal['none', 'mean', 'sum'] = 'none') → array
Computes the cross entropy loss.

- Parameters:
  - logits (array) – The unnormalized logits.
  - targets (array) – The ground truth values. These can be class indices or probabilities for each class. If the `targets` are class indices, then the `targets` shape should match the `logits` shape with the `axis` dimension removed. If the `targets` are probabilities (or one-hot encoded), then the `targets` shape should be the same as the `logits` shape.
  - weights (array, optional) – Optional weights for each target. Default: `None`.
  - axis (int, optional) – The axis over which to compute the softmax. Default: `-1`.
  - label_smoothing (float, optional) – Label smoothing factor. Default: `0.0`.
  - reduction (str, optional) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`.
 
- Returns:
  - The computed cross entropy loss.
- Return type:
  - array

Examples

>>> import mlx.core as mx
>>> import mlx.nn as nn
>>>
>>> # Class indices as targets
>>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
>>> targets = mx.array([0, 1])
>>> nn.losses.cross_entropy(logits, targets)
array([0.0485873, 0.0485873], dtype=float32)
>>>
>>> # Probabilities (or one-hot vectors) as targets
>>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
>>> targets = mx.array([[0.9, 0.1], [0.1, 0.9]])
>>> nn.losses.cross_entropy(logits, targets)
array([0.348587, 0.348587], dtype=float32)
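The optional parameters compose in a single call. The sketch below is illustrative rather than part of the reference examples: it reuses the `logits` and `targets` from above and combines hypothetical per-example `weights` with `label_smoothing` and a `'mean'` reduction.

>>> # Illustrative sketch: the weight values are made up for demonstration
>>> weights = mx.array([1.0, 0.5])  # down-weight the second example
>>> loss = nn.losses.cross_entropy(
...     logits,
...     targets,
...     weights=weights,
...     label_smoothing=0.1,
...     reduction="mean",
... )

With `reduction='mean'` the per-example losses are averaged into a scalar; with the default `'none'`, the result matches the `logits` shape with the `axis` dimension removed.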
