mlx.nn.losses.binary_cross_entropy

binary_cross_entropy(inputs: array, targets: array, weights: array | None = None, with_logits: bool = True, reduction: Literal['none', 'mean', 'sum'] = 'mean')

Computes the binary cross entropy loss.

By default, this function takes pre-sigmoid logits, which makes the loss faster to compute and more numerically precise than passing probabilities. When with_logits=False, the inputs are treated as probabilities, and for numerical stability their logarithms are clipped to a minimum value of -100.
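The two input modes described above can be illustrated with a minimal plain-Python sketch of the per-element math. This is not MLX's implementation: the helper names bce_from_logit and bce_from_prob are hypothetical, the softplus-based logit formula is the standard stable form and is assumed here, and clipping log(1 - p) as well as log(p) is also an assumption.

```python
import math


def bce_from_logit(x: float, y: float) -> float:
    """Per-element BCE from a logit x and a target y in {0, 1} (hypothetical helper)."""
    # softplus(x) = log(1 + exp(x)), written so exp() never overflows.
    softplus = max(x, 0.0) + math.log1p(math.exp(-abs(x)))
    return softplus - x * y


def bce_from_prob(p: float, y: float, min_log: float = -100.0) -> float:
    """Per-element BCE from a probability p, clipping the logs at min_log (assumed behavior)."""
    log_p = max(math.log(p), min_log) if p > 0.0 else min_log
    log_q = max(math.log(1.0 - p), min_log) if p < 1.0 else min_log
    return -(y * log_p + (1.0 - y) * log_q)


# A confident but wrong probability gives a large, finite loss instead of infinity.
worst_case = bce_from_prob(0.0, 1.0)

# A very large logit is handled without overflow in the logit path.
large_logit_loss = bce_from_logit(1000.0, 1.0)
```

Working directly from logits avoids computing a sigmoid followed by a log, which is where precision is typically lost for inputs of large magnitude; the clipping is only needed in the probability path.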

Parameters:
  • inputs (array) – The predicted values. If with_logits is True, then inputs are unnormalized logits. Otherwise, inputs are probabilities.

  • targets (array) – The binary target values in {0, 1}.

  • weights (array, optional) – Optional weights for each target. Default: None.

  • with_logits (bool, optional) – Whether inputs are logits. Default: True.

  • reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'.

Returns:

The computed binary cross entropy loss.

Return type:

array

Examples

>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
>>> targets = mx.array([0, 0, 1, 1])
>>> loss = nn.losses.binary_cross_entropy(logits, targets, reduction="mean")
>>> loss
array(0.539245, dtype=float32)
>>> probs = mx.array([0.1, 0.1, 0.4, 0.4])
>>> targets = mx.array([0, 0, 1, 1])
>>> loss = nn.losses.binary_cross_entropy(probs, targets, with_logits=False, reduction="mean")
>>> loss
array(0.510826, dtype=float32)
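
To make the roles of weights and reduction concrete, here is a plain-Python sketch that applies them to the per-element losses from the probability example. It is not MLX code: reduce_losses is a hypothetical helper, and it assumes that weights multiply the per-element losses before reduction and that 'mean' divides by the number of elements rather than by the sum of the weights.

```python
import math


def reduce_losses(losses, weights=None, reduction="mean"):
    """Apply optional per-element weights, then reduce (hypothetical helper)."""
    if weights is not None:
        if len(weights) != len(losses):
            raise ValueError("weights and losses must have the same length")
        # Assumption: weights scale each element's loss before the reduction.
        losses = [w * l for w, l in zip(weights, losses)]
    if reduction == "none":
        return list(losses)
    if reduction == "sum":
        return math.fsum(losses)
    if reduction == "mean":
        # Assumption: plain average over the number of elements.
        return math.fsum(losses) / len(losses)
    raise ValueError(f"unknown reduction: {reduction!r}")


# Per-element losses for probs [0.1, 0.1, 0.4, 0.4] and targets [0, 0, 1, 1].
per_element = [-math.log(0.9), -math.log(0.9), -math.log(0.4), -math.log(0.4)]

unreduced = reduce_losses(per_element, reduction="none")
mean_loss = reduce_losses(per_element, reduction="mean")  # about 0.510826
# Doubling the weight of the positive targets raises their share of the loss.
weighted_mean = reduce_losses(per_element, weights=[1.0, 1.0, 2.0, 2.0])
```

With reduction='none' the result keeps one loss per input element, which is useful when a custom aggregation, such as masking, is applied afterwards.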