mlx.core.fast.metal_kernel

metal_kernel(name: str, input_names: Sequence[str], output_names: Sequence[str], source: str, header: str = '', ensure_row_contiguous: bool = True, atomic_outputs: bool = False) → object

A JIT-compiled custom Metal kernel defined from a source string.

Parameters:
  • name (str) – Name for the kernel.

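  • input_names (Sequence[str]) – Names of the input parameters in the generated function signature. The source refers to each input by this name.

  • output_names (Sequence[str]) – Names of the output parameters in the generated function signature. The source refers to each output by this name.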

  • source (str) – Source code for the body of the Metal kernel function. Supply only the body; the function signature is generated automatically, using input_names and output_names as parameter names.

  • header (str) – Header source code to include before the main function. Useful for helper functions or includes that should live outside of the main function body.

  • ensure_row_contiguous (bool) – Whether to ensure the inputs are row-contiguous before the kernel runs. Default: True.

  • atomic_outputs (bool) – Whether to declare the outputs as atomic types in the generated function signature, e.g. device atomic<float>. Default: False.
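A short illustration of what "row-contiguous" means, written in pure Python so it runs anywhere. The helpers row_major_strides and is_row_contiguous are hypothetical names for this sketch, not part of MLX; they compare an array's element strides against C-order (row-major) strides. With row-contiguous inputs a kernel can address element [i, j] of a (4, 16) array as flat index 16*i + j, which is what inp[elem] in the example relies on.

```python
def row_major_strides(shape):
    """Element strides of a row-contiguous (C-order) array with this shape."""
    strides = []
    step = 1
    # Walk dimensions from last to first: the last dimension has stride 1.
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def is_row_contiguous(shape, strides):
    """Strict check: strides must equal the C-order strides exactly.

    Size-1 dimensions are a special case this sketch ignores: their stride
    never affects addressing, so any value would be harmless there.
    """
    return tuple(strides) == row_major_strides(shape)


# A (4, 16) array stored row by row: element [i, j] is at flat index 16*i + j.
print(row_major_strides((4, 16)))           # (16, 1)
# A transposed view of it has shape (16, 4) but strides (1, 16).
print(is_row_contiguous((16, 4), (1, 16)))  # False
```

A transposed view like the one above is the kind of input that would not satisfy a plain flat-index read such as inp[elem].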

Returns:

A callable that runs the kernel and returns a list of output arrays.
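The returned callable takes grid and threadgroup sizes, as in the example. The sketch below shows the launch arithmetic, assuming (as the MLX documentation describes) that these values are forwarded to Metal's dispatchThreads, so grid counts threads rather than threadgroups and the threadgroups at the edge of the grid may be only partially filled. dispatch_summary is a hypothetical helper, not an MLX API.

```python
import math


def dispatch_summary(grid, threadgroup):
    """Summarize a dispatchThreads-style launch (illustrative sketch only).

    Assumes grid counts threads, so exactly prod(grid) threads run.
    Returns (total_threads, threadgroups_per_dim, threads_in_last_x_group).
    """
    if len(grid) != 3 or len(threadgroup) != 3:
        raise ValueError("grid and threadgroup must be 3-tuples")
    total_threads = math.prod(grid)
    # Threadgroups needed per dimension, rounding up to cover the whole grid.
    groups = tuple((g + t - 1) // t for g, t in zip(grid, threadgroup))
    # The last threadgroup along x may hold fewer than threadgroup[0] threads.
    last_x = grid[0] - (groups[0] - 1) * threadgroup[0]
    return total_threads, groups, last_x


# The example's launch for a (4, 16) array: 64 threads in one partial group.
print(dispatch_summary((64, 1, 1), (256, 1, 1)))    # (64, (1, 1, 1), 64)
# A larger 1-D launch: four groups, the last holding 232 threads.
print(dispatch_summary((1000, 1, 1), (256, 1, 1)))  # (1000, (4, 1, 1), 232)
```

Because only prod(grid) threads run under this assumption, a one-thread-per-element kernel like the example's does not index past the end of its buffers.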

Example

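import mlx.core as mx

def exp_elementwise(a: mx.array):
    # Body of the Metal kernel; the signature is generated automatically.
    source = '''
        uint elem = thread_position_in_grid.x;
        T tmp = inp[elem];
        out[elem] = metal::exp(tmp);
    '''

    kernel = mx.fast.metal_kernel(
        name="myexp",
        input_names=["inp"],
        output_names=["out"],
        source=source,
    )
    outputs = kernel(
        inputs=[a],
        # T is instantiated as float32, so exp is computed in float32 even
        # though inp and out use a's dtype (float16 below).
        template=[("T", mx.float32)],
        # One thread per element; grid counts threads, not threadgroups.
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
        verbose=True,  # print the generated Metal source
    )
    return outputs[0]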

a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
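To make the per-thread behavior of the example concrete, here is a sequential CPU emulation of its kernel body. It rests on stated assumptions: one thread per element, elem indexing the flattened row-major input, float16 storage emulated with Python's struct "e" format, and Python's double-precision math.exp standing in for the float32 intermediate selected by template=[("T", mx.float32)]. emulate_exp_kernel and to_half are hypothetical names used only in this sketch.

```python
import math
import struct


def to_half(x):
    """Round a Python float to the nearest IEEE float16 value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def emulate_exp_kernel(inp_flat):
    """Run the example kernel body once per 'thread', sequentially on the CPU.

    inp_flat holds the input in row-major order, as the kernel sees it.
    """
    grid_x = len(inp_flat)           # mirrors grid=(a.size, 1, 1)
    out = [0.0] * grid_x             # output_dtypes=[a.dtype]: float16 storage
    for elem in range(grid_x):       # elem plays thread_position_in_grid.x
        tmp = float(inp_flat[elem])  # T tmp = inp[elem]; double here, float32 on GPU
        out[elem] = to_half(math.exp(tmp))  # the store rounds back to float16
    return out


print(emulate_exp_kernel([0.0, 1.0]))  # [1.0, 2.71875]
```

The second value shows the float16 rounding of e on store; the real kernel performs the same narrowing when it writes the float32 result into the float16 output.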