mlx.core.block_masked_mm

block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: Optional[array] = None, mask_lhs: Optional[array] = None, mask_rhs: Optional[array] = None, *, stream: Union[None, Stream, Device] = None) → array

Matrix multiplication with block masking.

Perform the (possibly batched) matrix multiplication of two arrays, with blocks of size block_size x block_size optionally masked out.

Assuming a has shape (…, M, K) and b has shape (…, K, N):

  • mask_lhs must have shape (…, \(\lceil\) M / block_size \(\rceil\), \(\lceil\) K / block_size \(\rceil\))

  • mask_rhs must have shape (…, \(\lceil\) K / block_size \(\rceil\), \(\lceil\) N / block_size \(\rceil\))

  • mask_out must have shape (…, \(\lceil\) M / block_size \(\rceil\), \(\lceil\) N / block_size \(\rceil\))

Note: Only block_size=64 and block_size=32 are currently supported.
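The shape rules above can be checked ahead of time with a small helper. The following is a minimal sketch in plain Python, not part of MLX: the function name block_mask_shapes is hypothetical, and it assumes that a and b share identical batch dimensions (broadcasting is not handled).

```python
def block_mask_shapes(a_shape, b_shape, block_size=64):
    """Return the mask shapes implied by the rules above.

    Hypothetical helper for illustration only; it is not an MLX function.
    Assumes a_shape == (..., M, K) and b_shape == (..., K, N) with
    identical leading batch dimensions.
    """
    if block_size not in (32, 64):
        raise ValueError("block_size must be 32 or 64")
    if len(a_shape) < 2 or len(b_shape) < 2:
        raise ValueError("both operands need at least two dimensions")

    *batch, m, k = a_shape
    *b_batch, k_b, n = b_shape
    if k != k_b:
        raise ValueError("inner dimensions do not match")
    if batch != b_batch:
        raise ValueError("this sketch requires identical batch dimensions")

    def ceil_div(x):
        # Integer ceiling of x / block_size.
        return -(-x // block_size)

    tm, tk, tn = ceil_div(m), ceil_div(k), ceil_div(n)
    return {
        "mask_lhs": (*batch, tm, tk),
        "mask_rhs": (*batch, tk, tn),
        "mask_out": (*batch, tm, tn),
    }


shapes = block_mask_shapes((2, 100, 130), (2, 130, 64), block_size=64)
print(shapes["mask_lhs"])  # (2, 2, 3)
print(shapes["mask_rhs"])  # (2, 3, 1)
print(shapes["mask_out"])  # (2, 2, 1)
```

Because the block counts use a ceiling, dimensions that are not multiples of block_size still get a mask entry for the final, partially filled block.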

Parameters:
  • a (array) – Input array or scalar.

  • b (array) – Input array or scalar.

  • block_size (int) – Size of blocks to be masked. Must be 32 or 64 (default: 64).

  • mask_out (array, optional) – Boolean mask for output (default: None).

  • mask_lhs (array, optional) – Boolean mask for a (default: None).

  • mask_rhs (array, optional) – Boolean mask for b (default: None).
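To make the role of the three masks concrete, here is a pure-Python reference sketch for the unbatched 2-D case. It is not the MLX implementation, and the function names are hypothetical. It assumes that a False mask entry zeroes the corresponding block of a, of b, or of the output, and it uses block_size=2 only to keep the example small; the real function accepts 32 or 64 only.

```python
def expand_block_mask(mask, rows, cols, block_size):
    """Expand a per-block boolean mask to a per-element rows x cols mask."""
    return [
        [mask[i // block_size][j // block_size] for j in range(cols)]
        for i in range(rows)
    ]


def apply_block_mask(x, mask, block_size):
    """Zero every element of x whose block is marked False in mask."""
    rows, cols = len(x), len(x[0])
    keep = expand_block_mask(mask, rows, cols, block_size)
    return [
        [x[i][j] if keep[i][j] else 0 for j in range(cols)]
        for i in range(rows)
    ]


def block_masked_mm_reference(a, b, block_size,
                              mask_out=None, mask_lhs=None, mask_rhs=None):
    """Assumed semantics: mask the operands, multiply, then mask the output."""
    if mask_lhs is not None:
        a = apply_block_mask(a, mask_lhs, block_size)
    if mask_rhs is not None:
        b = apply_block_mask(b, mask_rhs, block_size)

    m, k, n = len(a), len(b), len(b[0])
    out = [
        [sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
        for i in range(m)
    ]

    if mask_out is not None:
        out = apply_block_mask(out, mask_out, block_size)
    return out


ones = [[1] * 4 for _ in range(4)]
result = block_masked_mm_reference(
    ones, ones, block_size=2,
    mask_lhs=[[True, False], [True, True]],   # drop the top-right block of a
    mask_out=[[True, True], [False, True]],   # drop the bottom-left output block
)
for row in result:
    print(row)
# [2, 2, 2, 2]
# [2, 2, 2, 2]
# [0, 0, 4, 4]
# [0, 0, 4, 4]
```

A reference like this is only useful for checking small cases by hand; the point of the block-level masks is that masked blocks can be skipped rather than multiplied and discarded.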