mlx.core.fast.scaled_dot_product_attention

scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: array | None = None, stream: None | Stream | Device = None) -> array

A fast implementation of multi-head attention: O = softmax(Q @ K.T * scale, dim=-1) @ V.

Supports Multi-Head Attention, Grouped Query Attention, and Multi-Query Attention.

Note: The softmax operation is performed in float32 regardless of the input precision.

Note: For Grouped Query Attention and Multi-Query Attention, the k and v inputs should not be pre-tiled to match q.
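
For example, a Grouped Query Attention call might look like the following sketch (the batch size, head counts, and dimensions here are illustrative assumptions, not values from this documentation):

  import math
  import mlx.core as mx

  B, N_q, N_kv, T, D = 1, 8, 2, 16, 64  # four query heads share each key/value head

  q = mx.random.normal((B, N_q, T, D))
  k = mx.random.normal((B, N_kv, T, D))  # kept with N_kv heads, not tiled to N_q
  v = mx.random.normal((B, N_kv, T, D))  # kept with N_kv heads, not tiled to N_q

  out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(D))
  print(out.shape)  # (1, 8, 16, 64)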

In the following, the dimensions are given by:

  • B: The batch size.

  • N_q: The number of query heads.

  • N_kv: The number of key and value heads.

  • T_q: The number of queries per example.

  • T_kv: The number of keys and values per example.

  • D: The per-head dimension.

Parameters:
  • q (array) – Queries with shape [B, N_q, T_q, D].

  • k (array) – Keys with shape [B, N_kv, T_kv, D].

  • v (array) – Values with shape [B, N_kv, T_kv, D].

  • scale (float) – Scale for queries (typically 1.0 / sqrt(q.shape[-1])).

  • mask (array, optional) – A boolean or additive mask to apply to the query-key scores. The mask can have at most 4 dimensions and must be broadcast-compatible with the shape [B, N_q, T_q, T_kv]. If an additive mask is given, its type must promote to the promoted type of q, k, and v.
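
To illustrate the mask parameter, the sketch below builds a boolean causal mask and an equivalent additive mask; both broadcast against [B, N_q, T_q, T_kv]. The shapes are illustrative assumptions.

  import math
  import mlx.core as mx

  B, N, T, D = 1, 4, 8, 32
  q = mx.random.normal((B, N, T, D))
  k = mx.random.normal((B, N, T, D))
  v = mx.random.normal((B, N, T, D))
  scale = 1.0 / math.sqrt(D)

  # Boolean causal mask: True where a query may attend to a key.
  # Shape (T, T) broadcasts against [B, N_q, T_q, T_kv].
  causal = mx.tril(mx.ones((T, T), dtype=mx.bool_))
  out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=causal)

  # Equivalent additive mask: 0.0 where attention is allowed, a large
  # negative value where it is not.
  additive = mx.where(causal, mx.zeros((T, T)), mx.full((T, T), -1e9))
  out_additive = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=additive)
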

Returns:

The output array.

Return type:

array
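
As a rough cross-check of the computation described above, here is a minimal unfused sketch, assuming an additive (or no) mask and N_q == N_kv, that takes the softmax in float32 as in the note; it should agree with the fast kernel up to numerical tolerance:

  import math
  import mlx.core as mx

  def sdpa_reference(q, k, v, scale, mask=None):
      # scores: [B, N_q, T_q, T_kv]
      scores = (q * scale) @ mx.swapaxes(k, -1, -2)
      if mask is not None:
          scores = scores + mask
      # Softmax in float32 regardless of the input precision.
      weights = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q.dtype)
      return weights @ v

  B, N, T, D = 1, 4, 8, 32
  q = mx.random.normal((B, N, T, D))
  k = mx.random.normal((B, N, T, D))
  v = mx.random.normal((B, N, T, D))
  scale = 1.0 / math.sqrt(D)

  fast = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
  print(mx.allclose(fast, sdpa_reference(q, k, v, scale), atol=1e-4))  # True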