mlx.nn.losses.triplet_loss


class triplet_loss(anchors: array, positives: array, negatives: array, axis: int = -1, p: int = 2, margin: float = 1.0, eps: float = 1e-06, reduction: Literal['none', 'mean', 'sum'] = 'none')

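Computes the triplet loss for a set of anchor, positive, and negative samples. The loss is zero once each negative is farther from its anchor than the positive by at least the margin. In the formula below, A, P, and N are the anchor, positive, and negative samples, and the margin is denoted by α.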

\[\max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right)\]
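As a worked illustration of the formula above, here is a minimal pure-Python sketch for a single triplet. The helper names `p_norm` and `triplet_loss_single` are hypothetical and not part of MLX. The sketch follows the formula literally and leaves out `eps`, so it is not a description of the library's implementation.

```python
def p_norm(u, v, p=2):
    """p-norm of the difference between two equal-length vectors."""
    return sum(abs(a - b) ** p for a, b in zip(u, v)) ** (1.0 / p)


def triplet_loss_single(anchor, positive, negative, p=2, margin=1.0):
    """max(||A - P||_p - ||A - N||_p + margin, 0) for one triplet (hypothetical helper)."""
    d_pos = p_norm(anchor, positive, p)  # distance anchor -> positive
    d_neg = p_norm(anchor, negative, p)  # distance anchor -> negative
    return max(d_pos - d_neg + margin, 0.0)


anchor = [0.0, 0.0]
near = [3.0, 4.0]  # Euclidean distance 5 from the anchor
far = [6.0, 8.0]   # Euclidean distance 10 from the anchor

# Positive is closer than the negative by more than the margin: 5 - 10 + 1 < 0, so the loss is clamped to 0.
easy = triplet_loss_single(anchor, near, far)

# Roles swapped: the "positive" is farther away, so the loss is 10 - 5 + 1 = 6.
hard = triplet_loss_single(anchor, far, near)
```

The first triplet already satisfies the margin and contributes nothing; the second violates it and is penalized in proportion to the violation.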
Parameters:
  • anchors (array) – The anchor samples.

  • positives (array) – The positive samples.

  • negatives (array) – The negative samples.

  • axis (int, optional) – The axis along which the pairwise distance is computed, typically the feature axis. Default: -1.

  • p (int, optional) – The norm degree for pairwise distance. Default: 2.

  • margin (float, optional) – Margin for the triplet loss. Default: 1.0.

  • eps (float, optional) – Small positive constant to prevent numerical instability. Default: 1e-6.

  • reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'none'.

Returns:

Computed triplet loss. If reduction is “none”, returns an array with one loss value per triplet, i.e. the shape of the inputs with axis reduced; if reduction is “mean” or “sum”, returns a scalar array.

Return type:

array