mlx.core.fast.scaled_dot_product_attention
- scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: array | None = None, stream: None | Stream | Device = None) -> array
A fast implementation of multi-head attention:
O = softmax(Q @ K.T, dim=-1) @ V
Supports Multi-Head Attention, Grouped Query Attention, and Multi-Query Attention.
Note: The softmax operation is performed in float32 regardless of the input precision.
Note: For Grouped Query Attention and Multi-Query Attention, the k and v inputs should not be pre-tiled to match q; pass them with their own (smaller) number of heads, as in the sketch after the dimension list below.
In the following the dimensions are given by:
- B: The batch size.
- N_q: The number of query heads.
- N_kv: The number of key and value heads.
- T_q: The number of queries per example.
- T_kv: The number of keys and values per example.
- D: The per-head dimension.
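For example, a Grouped Query Attention call passes k and v with their own head count rather than tiling them up to N_q. This is a minimal sketch, assuming mlx is installed; the shapes (8 query heads sharing 2 key/value heads) are illustrative choices, not values from this documentation:

```python
import math
import mlx.core as mx

B, N_q, N_kv, T_q, T_kv, D = 1, 8, 2, 16, 16, 64
q = mx.random.normal((B, N_q, T_q, D), dtype=mx.float16)
# k and v keep their own, smaller head count; they are NOT tiled to 8 heads.
k = mx.random.normal((B, N_kv, T_kv, D), dtype=mx.float16)
v = mx.random.normal((B, N_kv, T_kv, D), dtype=mx.float16)

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(D))
print(out.shape)  # (1, 8, 16, 64), i.e. [B, N_q, T_q, D]
```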
- Parameters:
  - q (array) – Queries with shape [B, N_q, T_q, D].
  - k (array) – Keys with shape [B, N_kv, T_kv, D].
  - v (array) – Values with shape [B, N_kv, T_kv, D].
  - scale (float) – Scale for queries (typically 1.0 / sqrt(q.shape[-1])).
  - mask (array, optional) – A boolean or additive mask to apply to the query-key scores. The mask can have at most 4 dimensions and must be broadcast-compatible with the shape [B, N_q, T_q, T_kv]. If an additive mask is given, its type must promote to the promoted type of q, k, and v. See the sketch at the end of this page for an example.
- Returns: The output array.
- Return type: array
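As a sketch of how the mask parameter and the formula above fit together, the example below builds a boolean causal mask that broadcasts to [B, N_q, T_q, T_kv], then compares the fused kernel against an unfused version that applies the query scale, adds an equivalent additive mask to the scores, and takes the softmax in float32, as described in the notes and parameter list. The shapes, the tolerances, and the helper name sdpa_reference are illustrative assumptions, not part of the MLX API:

```python
import math
import mlx.core as mx


def sdpa_reference(q, k, v, scale, additive_mask=None):
    # Unfused sketch: scale the queries, add the mask to the query-key
    # scores, softmax in float32, then weight the values.
    scores = (q * scale) @ k.transpose(0, 1, 3, 2)
    if additive_mask is not None:
        scores = scores + additive_mask
    weights = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q.dtype)
    return weights @ v


B, N_q, T, D = 1, 4, 8, 32
q = mx.random.normal((B, N_q, T, D), dtype=mx.float16)
k = mx.random.normal((B, N_q, T, D), dtype=mx.float16)
v = mx.random.normal((B, N_q, T, D), dtype=mx.float16)
scale = 1.0 / math.sqrt(D)

# Boolean causal mask with shape (T_q, T_kv); it broadcasts to
# [B, N_q, T_q, T_kv] as required.
bool_mask = mx.tril(mx.ones((T, T), dtype=mx.bool_))
fast = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=bool_mask)

# Equivalent additive mask: 0 where attention is allowed, -inf where it is
# not; cast to float16 so it promotes with the float16 inputs.
additive = mx.where(bool_mask, mx.array(0.0), mx.array(-math.inf)).astype(mx.float16)
ref = sdpa_reference(q, k, v, scale, additive)

print(mx.allclose(fast, ref, atol=1e-2, rtol=1e-2))
```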