mlx.core.fast.scaled_dot_product_attention
- scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: array | None = None, stream: None | Stream | Device = None) -> array
A fast implementation of multi-head attention:
O = softmax(Q @ K.T, dim=-1) @ V
Supports:

- Multi-Head Attention
- Grouped Query Attention
- Multi-Query Attention
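As a point of reference, the fused operation corresponds to the unfused sketch below, assuming inputs of shape (batch, heads, sequence, head_dim); the helper name sdpa_reference is illustrative, not part of the API, and the fused kernel additionally runs the softmax in float32 as noted further down.

```python
import mlx.core as mx

def sdpa_reference(q, k, v, *, scale, mask=None):
    # Unfused equivalent: O = softmax(scale * Q @ K^T (+ mask), axis=-1) @ V
    # Assumed shapes: q, k, v are (B, H, L, D).
    scores = (q * scale) @ mx.swapaxes(k, -1, -2)
    if mask is not None:
        scores = scores + mask  # additive mask applied to the scores
    return mx.softmax(scores, axis=-1) @ v
```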
Note: The softmax operation is performed in float32 regardless of the input precision.

Note: For Grouped Query Attention and Multi-Query Attention, the k and v inputs should not be pre-tiled to match q.
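A minimal usage sketch of the grouped-query case, with assumed shapes and head counts chosen for illustration: k and v keep their own (smaller) number of heads and are not tiled to match q.

```python
import math
import mlx.core as mx

B, L, D = 1, 128, 64
n_q_heads, n_kv_heads = 32, 8  # grouped-query attention: fewer key/value heads

q = mx.random.normal((B, n_q_heads, L, D)).astype(mx.float16)
k = mx.random.normal((B, n_kv_heads, L, D)).astype(mx.float16)
v = mx.random.normal((B, n_kv_heads, L, D)).astype(mx.float16)

# Pass k and v with their native head count; do not repeat them to n_q_heads.
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(D))
print(out.shape)  # (1, 32, 128, 64)
```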