mlx.core.block_masked_mm
- block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: Optional[array] = None, mask_lhs: Optional[array] = None, mask_rhs: Optional[array] = None, *, stream: Union[None, Stream, Device] = None) → array
Matrix multiplication with block masking.
Perform the (possibly batched) matrix multiplication of two arrays, with blocks of size block_size x block_size optionally masked out.

Assuming a with shape (…, M, K) and b with shape (…, K, N):

- mask_lhs must have shape (…, \(\lceil M / \mathrm{block\_size} \rceil\), \(\lceil K / \mathrm{block\_size} \rceil\))
- mask_rhs must have shape (…, \(\lceil K / \mathrm{block\_size} \rceil\), \(\lceil N / \mathrm{block\_size} \rceil\))
- mask_out must have shape (…, \(\lceil M / \mathrm{block\_size} \rceil\), \(\lceil N / \mathrm{block\_size} \rceil\))
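
The shape rules above can be checked before calling the function. The following is a minimal sketch in plain Python; the helper name `expected_mask_shapes` is hypothetical (it is not part of MLX), and it assumes for simplicity that the masks use the batch dimensions of a.

```python
def expected_mask_shapes(a_shape, b_shape, block_size=64):
    """Return the mask shapes required for operands of the given shapes.

    Hypothetical helper for illustration only; not part of MLX.
    Assumes the masks share the leading (batch) dimensions of `a`.
    """
    *batch, M, K = a_shape
    K_b, N = b_shape[-2:]
    if K != K_b:
        raise ValueError(f"inner dimensions differ: {K} vs {K_b}")

    def ceil_div(n):
        # Integer ceiling division, i.e. ceil(n / block_size).
        return -(-n // block_size)

    tm, tk, tn = ceil_div(M), ceil_div(K), ceil_div(N)
    return {
        "mask_lhs": (*batch, tm, tk),
        "mask_rhs": (*batch, tk, tn),
        "mask_out": (*batch, tm, tn),
    }


# a: (2, 100, 70), b: (2, 70, 130), block_size = 32
shapes = expected_mask_shapes((2, 100, 70), (2, 70, 130), block_size=32)
print(shapes["mask_lhs"])  # (2, 4, 3)
print(shapes["mask_rhs"])  # (2, 3, 5)
print(shapes["mask_out"])  # (2, 4, 5)
```

Dimensions that are not a multiple of block_size still get a mask entry for the final partial block, which is why the ceiling is used rather than floor division.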
Note: Only block_size=64 and block_size=32 are currently supported.

- Parameters:
a (array) – Input array or scalar.
b (array) – Input array or scalar.
block_size (int) – Size of blocks to be masked. Must be 32 or 64 (default: 64).
mask_out (array, optional) – Boolean mask for output (default: None).
mask_lhs (array, optional) – Boolean mask for a (default: None).
mask_rhs (array, optional) – Boolean mask for b (default: None).
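
To illustrate how the three masks interact, here is a NumPy reference sketch. It is not MLX's implementation: the function names `expand_block_mask` and `block_masked_mm_reference` are hypothetical, and the sketch assumes that a `False` entry in a Boolean mask means the corresponding block is treated as zeros.

```python
import numpy as np


def expand_block_mask(mask, block_size, rows, cols):
    """Repeat each mask entry over its block, then trim to (rows, cols)."""
    full = np.repeat(np.repeat(mask, block_size, axis=-2), block_size, axis=-1)
    return full[..., :rows, :cols]


def block_masked_mm_reference(a, b, block_size=64,
                              mask_out=None, mask_lhs=None, mask_rhs=None):
    """Reference semantics (assumed): masked-out blocks contribute zeros."""
    if block_size not in (32, 64):
        raise ValueError("block_size must be 32 or 64")
    M, K = a.shape[-2:]
    N = b.shape[-1]
    if mask_lhs is not None:
        a = a * expand_block_mask(mask_lhs, block_size, M, K)
    if mask_rhs is not None:
        b = b * expand_block_mask(mask_rhs, block_size, K, N)
    out = a @ b
    if mask_out is not None:
        out = out * expand_block_mask(mask_out, block_size, M, N)
    return out


a = np.ones((64, 64), dtype=np.float32)
b = np.ones((64, 64), dtype=np.float32)
# With block_size=32, each 64 x 64 operand is a 2 x 2 grid of blocks.
mask_lhs = np.array([[True, False], [True, True]])  # drop top-right block of a
mask_out = np.array([[True, True], [False, True]])  # drop bottom-left output block

out = block_masked_mm_reference(a, b, block_size=32,
                                mask_lhs=mask_lhs, mask_out=mask_out)
print(out[0, 0], out[63, 0], out[63, 63])  # 32.0 0.0 64.0
```

With MLX arrays of the same shapes, the corresponding call would take the form `mx.block_masked_mm(a, b, block_size=32, mask_out=mask_out, mask_lhs=mask_lhs)`, where `mx` is `mlx.core`.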