mlx.core.fast.metal_kernel
- class metal_kernel
A JIT-compiled custom Metal kernel defined from a source string.
- __init__(self, name: str, source: str, header: str = '', ensure_row_contiguous: bool = True, atomic_outputs: bool = False) -> None
Initialize a metal_kernel.
- Parameters:
name (str) – Name for the kernel.
source (str) – Source code. This is the body of a function in Metal; the function signature is generated for you. The names of the inputs and outputs are determined by the inputs and output_shapes/output_dtypes used when the kernel is called.
header (str) – Header source code to include before the main function. Useful for helper functions or includes that should live outside of the main function body.
ensure_row_contiguous (bool) – Whether to ensure the inputs are row contiguous before the kernel runs. Default: True.
atomic_outputs (bool) – Whether to use atomic outputs in the function signature, e.g. device atomic<float>. Default: False.
- Returns:
Callable metal_kernel.
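As a rough illustration of the atomic_outputs parameter, the sketch below accumulates every input element into a single output slot. It follows the dict-based call convention used in the Example on this page. Two things are assumptions rather than statements from this page: that the call accepts an init_value argument to zero the output buffer before the kernel runs, and that the device supports float atomics. The names SUM_SOURCE, atomic_sum and "myatomicsum" are hypothetical.

```python
# Kernel body only. With atomic_outputs=True the generated signature is
# expected to declare `out` as `device atomic<float>*`, so plain stores
# are replaced by an atomic add.
SUM_SOURCE = '''
    uint elem = thread_position_in_grid.x;
    atomic_fetch_add_explicit(&out[0], inp[elem], memory_order_relaxed);
'''


def atomic_sum(a):
    """Sum all elements of a float32 array `a` (needs MLX and a Metal device)."""
    import mlx.core as mx  # imported lazily; only needed when the kernel runs

    kernel = mx.fast.metal_kernel(
        name="myatomicsum",
        source=SUM_SOURCE,
        atomic_outputs=True,
    )
    outputs = kernel(
        inputs={"inp": a},
        grid=(a.size, 1, 1),          # one thread per input element
        threadgroup=(256, 1, 1),
        output_shapes={"out": (1,)},  # a single accumulator
        output_dtypes={"out": mx.float32},
        init_value=0,  # ASSUMPTION: not documented on this page; zeroes `out` first
    )
    return outputs["out"]
```

Without some form of initialization the accumulator would start from whatever the output buffer happens to contain, so the result of a kernel like this should be checked against mx.sum before relying on it.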
Example
import mlx.core as mx

def exp_elementwise(a: mx.array):
    # Kernel body only; `inp` and `out` are named by the call below.
    source = '''
        uint elem = thread_position_in_grid.x;
        T tmp = inp[elem];
        out[elem] = metal::exp(tmp);
    '''

    kernel = mx.fast.metal_kernel(
        name="myexp",
        source=source
    )
    outputs = kernel(
        inputs={"inp": a},
        template={"T": mx.float32},
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes={"out": a.shape},
        output_dtypes={"out": a.dtype},
        verbose=True,
    )
    return outputs["out"]

a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
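The header parameter can be sketched as a small variation of the example above: a helper function is defined outside the kernel body and called from inside it. This is a minimal sketch that reuses the same call convention as the example. The names HEADER, SOURCE, square_elementwise and "mysquare" are hypothetical and chosen only for illustration.

```python
# Helper that must live outside the kernel body, so it is passed as `header`.
HEADER = '''
inline float square(float x) {
    return x * x;
}
'''

# Kernel body only; the signature (including `inp` and `out`) is generated.
SOURCE = '''
    uint elem = thread_position_in_grid.x;
    T tmp = inp[elem];
    out[elem] = square(tmp);
'''


def square_elementwise(a):
    """Square each element of `a` (needs MLX and a Metal device)."""
    import mlx.core as mx  # imported lazily; only needed when the kernel runs

    kernel = mx.fast.metal_kernel(
        name="mysquare",
        source=SOURCE,
        header=HEADER,
    )
    outputs = kernel(
        inputs={"inp": a},
        template={"T": mx.float32},
        grid=(a.size, 1, 1),  # one thread per element
        threadgroup=(256, 1, 1),
        output_shapes={"out": a.shape},
        output_dtypes={"out": a.dtype},
    )
    return outputs["out"]
```

On a machine with a Metal device, the result could be compared against mx.square(a) with mx.allclose, in the same way the example above checks against mx.exp(a).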
Methods
__init__(self, name, source[, header, ...]) – Initialize a metal_kernel.