mlx.core.fast.scaled_dot_product_attention#

scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: None | array = None, stream: None | Stream | Device = None) → array#

A fast implementation of multi-head attention: O = softmax((Q * scale) @ K.T, dim=-1) @ V.

Supports:

  • Multi-Head Attention

  • Grouped Query Attention

  • Multi-Query Attention

Note: The softmax operation is performed in float32 regardless of the input precision.

Note: For Grouped Query Attention and Multi-Query Attention, the k and v inputs should not be pre-tiled to match q.
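For example, in the Grouped Query Attention case the key and value arrays simply carry their own (smaller) head count. A minimal sketch, assuming a (batch, num_heads, sequence_length, head_dim) layout and arbitrary example sizes:

import math

import mlx.core as mx

# 8 query heads share 2 key/value heads; k and v are not tiled to match q.
B, L, D = 1, 64, 64
q = mx.random.normal(shape=(B, 8, L, D))
k = mx.random.normal(shape=(B, 2, L, D))
v = mx.random.normal(shape=(B, 2, L, D))

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(D))
print(out.shape)  # (1, 8, 64, 64)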

Parameters:
  • q (array) – Input query array.

  • k (array) – Input keys array.

  • v (array) – Input values array.

  • scale (float) – Scale for queries (typically 1.0 / sqrt(q.shape[-1])).

  • mask (array, optional) – An additive mask to apply to the query-key scores.

Returns:

The output array.

Return type:

array
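
As an illustration, the fused call can be checked against an unfused reference computation. A minimal sketch, assuming a (batch, num_heads, sequence_length, head_dim) layout and a broadcastable additive causal mask (example sizes chosen arbitrarily):

import math

import mlx.core as mx

B, H, L, D = 1, 4, 16, 32
q = mx.random.normal(shape=(B, H, L, D))
k = mx.random.normal(shape=(B, H, L, D))
v = mx.random.normal(shape=(B, H, L, D))
scale = 1.0 / math.sqrt(D)

# Additive causal mask: 0 where attention is allowed, -inf above the diagonal.
mask = mx.triu(mx.full((L, L), float("-inf")), k=1)

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)

# Unfused reference: scale the queries, add the mask to the scores,
# softmax over the key axis, then apply the weights to the values.
scores = (q * scale) @ mx.swapaxes(k, -1, -2) + mask
ref = mx.softmax(scores, axis=-1) @ v
print(mx.allclose(out, ref, atol=1e-4))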