mlx.core.block_masked_mm
- block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: array | None = None, mask_lhs: array | None = None, mask_rhs: array | None = None, *, stream: None | Stream | Device = None) → array
Matrix multiplication with block masking.
Perform the (possibly batched) matrix multiplication of two arrays, with blocks of size block_size x block_size optionally masked out.
Assuming a with shape (…, M, K) and b with shape (…, K, N):
- mask_lhs must have shape (…, \(\lceil\) M / block_size \(\rceil\), \(\lceil\) K / block_size \(\rceil\))
- mask_rhs must have shape (…, \(\lceil\) K / block_size \(\rceil\), \(\lceil\) N / block_size \(\rceil\))
- mask_out must have shape (…, \(\lceil\) M / block_size \(\rceil\), \(\lceil\) N / block_size \(\rceil\))
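The shape rules above amount to a ceiling division of each matrix dimension by block_size. The following is a minimal sketch of that arithmetic in plain Python; `expected_mask_shapes` is a hypothetical helper, not part of MLX, and it assumes the masks carry the same leading batch dimensions as a.

```python
def ceil_div(x, block_size):
    """Integer ceiling division: number of blocks needed to cover x elements."""
    return -(-x // block_size)


def expected_mask_shapes(a_shape, b_shape, block_size=64):
    """Hypothetical helper (not an MLX function): mask shapes implied by the rules above."""
    *batch, m, k = a_shape
    k_b, n = b_shape[-2:]
    if k != k_b:
        raise ValueError(f"inner dimensions differ: {k} vs {k_b}")
    # Assumption: the masks use the same leading batch dimensions as `a`.
    batch = tuple(batch)
    tm, tk, tn = (ceil_div(d, block_size) for d in (m, k, n))
    return {
        "mask_lhs": batch + (tm, tk),
        "mask_rhs": batch + (tk, tn),
        "mask_out": batch + (tm, tn),
    }


# M=100 needs 2 blocks of 64, K=130 needs 3, N=64 needs 1.
shapes = expected_mask_shapes((2, 100, 130), (2, 130, 64), block_size=64)
```

Dimensions that are not a multiple of block_size still get a final, partially filled block, which is why the ceiling is used rather than floor division.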
Note: Only block_size=64 and block_size=32 are currently supported.
- Parameters:
a (array) – Input array or scalar.
b (array) – Input array or scalar.
block_size (int) – Size of blocks to be masked. Must be 32 or 64. Default: 64.
mask_out (array, optional) – Mask for output. Default: None.
mask_lhs (array, optional) – Mask for a. Default: None.
mask_rhs (array, optional) – Mask for b. Default: None.
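To make the role of the three masks concrete, here is a rough NumPy reference sketch. It is not MLX code and not the library's implementation. It assumes that each mask entry scales (or, for boolean masks, keeps or zeroes) the corresponding block_size x block_size block of a, b, or the output. The names `expand_block_mask` and `block_masked_mm_reference` are hypothetical.

```python
import numpy as np


def expand_block_mask(mask, rows, cols, block_size):
    """Expand a per-block mask to element level, then crop to (rows, cols)."""
    full = np.repeat(np.repeat(mask, block_size, axis=-2), block_size, axis=-1)
    return full[..., :rows, :cols]


def block_masked_mm_reference(a, b, block_size=64,
                              mask_out=None, mask_lhs=None, mask_rhs=None):
    """Dense reference under the stated assumption: masks multiply whole blocks."""
    m, k = a.shape[-2:]
    n = b.shape[-1]
    if mask_lhs is not None:
        a = a * expand_block_mask(mask_lhs, m, k, block_size)
    if mask_rhs is not None:
        b = b * expand_block_mask(mask_rhs, k, n, block_size)
    out = a @ b
    if mask_out is not None:
        out = out * expand_block_mask(mask_out, m, n, block_size)
    return out


# a is (32, 64): one block row and two block columns when block_size=32.
a = np.ones((32, 64), dtype=np.float32)
b = np.ones((64, 32), dtype=np.float32)
mask_lhs = np.array([[True, False]])  # drop the second K-block of `a`
out = block_masked_mm_reference(a, b, block_size=32, mask_lhs=mask_lhs)
```

With the second K-block of a masked out, only 32 of the 64 products contribute to each output element in this sketch. A dense reference like this can serve as a sanity check for mask shapes before calling the real function.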
- Returns:
The output array.
- Return type: