mlx.nn.losses.binary_cross_entropy
- binary_cross_entropy(inputs: array, targets: array, weights: array | None = None, with_logits: bool = True, reduction: Literal['none', 'mean', 'sum'] = 'mean')
Computes the binary cross entropy loss.
By default, this function takes the pre-sigmoid logits, which results in a faster and more precise loss. For improved numerical stability when with_logits=False, the loss calculation clips the input probabilities (in log-space) to a minimum value of -100.
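For instance, a predicted probability of exactly 0 or 1 would otherwise produce an infinite log term (or a NaN from 0 * inf); the clip keeps the loss finite. A hand-rolled sketch of the probability path under that clipping rule (not the library's internal code):

>>> import mlx.core as mx
>>> p = mx.array([0.0, 1.0])                   # degenerate probabilities
>>> t = mx.array([1, 0])                       # binary targets
>>> log_p = mx.maximum(mx.log(p), -100)        # clip log(p) at -100
>>> log_1mp = mx.maximum(mx.log(1 - p), -100)  # clip log(1 - p) at -100
>>> loss = -(t * log_p + (1 - t) * log_1mp)    # finite: array([100, 100]) rather than inf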
- Parameters:
  - inputs (array) – The predicted values. If with_logits is True, then inputs are unnormalized logits. Otherwise, inputs are probabilities.
  - targets (array) – The binary target values in {0, 1}.
  - with_logits (bool, optional) – Whether inputs are logits. Default: True.
  - weights (array, optional) – Optional weights for each target. Default: None.
  - reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'.
- Returns:
  The computed binary cross entropy loss.
- Return type:
  array
Examples
>>> import mlx.core as mx
>>> import mlx.nn as nn

>>> logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
>>> targets = mx.array([0, 0, 1, 1])
>>> loss = nn.losses.binary_cross_entropy(logits, targets, reduction="mean")
>>> loss
array(0.539245, dtype=float32)

>>> probs = mx.array([0.1, 0.1, 0.4, 0.4])
>>> targets = mx.array([0, 0, 1, 1])
>>> loss = nn.losses.binary_cross_entropy(probs, targets, with_logits=False, reduction="mean")
>>> loss
array(0.510826, dtype=float32)
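A minimal sketch of the weights argument, assuming each weight scales the corresponding element's loss before the reduction is applied (the weight values here are illustrative):

>>> probs = mx.array([0.1, 0.1, 0.4, 0.4])
>>> targets = mx.array([0, 0, 1, 1])
>>> weights = mx.array([1.0, 1.0, 2.0, 2.0])  # illustrative per-target weights
>>> loss = nn.losses.binary_cross_entropy(
...     probs, targets, weights=weights, with_logits=False, reduction="mean"
... )
>>> # hand-rolled cross-check under the same assumption:
>>> per_elem = -(targets * mx.log(probs) + (1 - targets) * mx.log(1 - probs))
>>> manual = mx.mean(weights * per_elem)  # expected to match loss

Per-target weights of this kind are commonly used to counter class imbalance, for example by up-weighting the rarer class.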