mlx.core.tensordot

tensordot(a: array, b: array, /, axes: int | list[Sequence[int]] = 2, *, stream: None | Stream | Device = None) → array

Compute the tensor dot product along the specified axes.

Parameters:
  • a (array) – Input array

  • b (array) – Input array

  • axes (int or list(list(int)), optional) – The axes to sum over. If an integer n is provided, sum over the last n dimensions of a and the first n dimensions of b. If a list of two lists is provided, sum over the dimensions of a named in the first list paired with the dimensions of b named in the second list. Default: 2.

Returns:

The tensor dot product.

Return type:

array