mlx.core.fast.metal_kernel

class metal_kernel

A JIT-compiled custom Metal kernel defined from a source string.

__init__(self, name: str, source: str, header: str = '', ensure_row_contiguous: bool = True, atomic_outputs: bool = False) → None

Initialize a metal_kernel.

Parameters:
  • name (str) – Name for the kernel.

  • source (str) – Source code. This is the body of a Metal function; the function signature is generated for you. The names of the inputs and outputs are determined by the inputs and output_shapes/output_dtypes passed when the kernel is called.

  • header (str) – Header source code to include before the main function. Useful for helper functions or includes that should live outside of the main function body.

  • ensure_row_contiguous (bool) – Whether to ensure the inputs are row contiguous before the kernel runs. Default: True.

  • atomic_outputs (bool) – Whether to use atomic outputs in the function signature, e.g. device atomic<float>. Default: False.

Returns:

Callable metal_kernel.
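
To make the relationship between the source body and the generated signature more concrete, here is a minimal pure-Python sketch of how a body string, a header, and named inputs/outputs could be assembled into a complete Metal function. The helper sketch_kernel_source and the METAL_TYPES table are hypothetical and exist only for illustration; this is not MLX's code generator, and the text MLX actually emits may differ (for instance in type spellings, function naming, and which thread attributes are included).

```python
# Illustrative only: a hypothetical helper, not MLX's actual code generator.
# Assumed mapping from dtype names to Metal scalar type names.
METAL_TYPES = {"float32": "float", "float16": "half", "int32": "int"}


def sketch_kernel_source(name, source, inputs, outputs, header="", atomic_outputs=False):
    """Wrap a kernel body in a generated Metal function signature.

    inputs and outputs map argument names to dtype names (e.g. "float16").
    """
    params = []
    # Inputs become read-only device pointers, bound in order.
    for arg, dtype in inputs.items():
        params.append(
            f"const device {METAL_TYPES[dtype]}* {arg} [[buffer({len(params)})]]"
        )
    # Outputs become writable device pointers, optionally wrapped in atomic<>.
    for arg, dtype in outputs.items():
        metal_type = METAL_TYPES[dtype]
        if atomic_outputs:
            metal_type = f"atomic<{metal_type}>"
        params.append(f"device {metal_type}* {arg} [[buffer({len(params)})]]")
    # This sketch always exposes the thread position (an assumption made here).
    params.append("uint3 thread_position_in_grid [[thread_position_in_grid]]")
    signature = f"[[kernel]] void {name}(\n    " + ",\n    ".join(params) + ")"
    # The header is placed before the function; the body goes inside the braces.
    return f"{header}\n{signature} {{\n{source}\n}}\n"


generated = sketch_kernel_source(
    name="myexp",
    source="    out[thread_position_in_grid.x] = metal::exp(inp[thread_position_in_grid.x]);",
    inputs={"inp": "float16"},
    outputs={"out": "float16"},
)
print(generated)
```

The point of the sketch is that the names used inside source (inp and out here) must match the names supplied when the kernel is called, and that anything placed in header ends up outside the function body.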

Example

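import mlx.core as mx

def exp_elementwise(a: mx.array):
    # Body of the Metal function; the signature is generated automatically.
    source = '''
        uint elem = thread_position_in_grid.x;
        T tmp = inp[elem];
        out[elem] = metal::exp(tmp);
    '''

    kernel = mx.fast.metal_kernel(
        name="myexp",
        source=source
    )
    outputs = kernel(
        inputs={"inp": a},                 # key is the name used in the source
        template={"T": mx.float32},        # type of the intermediate value tmp
        grid=(a.size, 1, 1),               # one thread per element
        threadgroup=(256, 1, 1),
        output_shapes={"out": a.shape},
        output_dtypes={"out": a.dtype},
        verbose=True,                      # print the generated Metal source
    )
    return outputs["out"]

a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
# Compare against the built-in exponential.
assert mx.allclose(b, mx.exp(a))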

Methods

__init__(self, name, source[, header, ...])

Initialize a metal_kernel.