mlx_rs::ops

Function tensordot

pub fn tensordot<'a>(
    a: impl AsRef<Array>,
    b: impl AsRef<Array>,
    axes: impl Into<TensorDotDims<'a>>,
) -> Result<Array>

Compute the tensor dot product along the specified axes.
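When `axes` is an integer k, the contraction can be written out index by index. In the equation below, the ranks m (of a) and n (of b) and the index names i, j, s are notation introduced here for illustration; each s index runs over one of the k contracted dimensions, which must have matching sizes in a and b.

```latex
C_{i_1 \dots i_{m-k}\, j_{k+1} \dots j_n}
  = \sum_{s_1, \dots, s_k}
    a_{i_1 \dots i_{m-k}\, s_1 \dots s_k}\;
    b_{s_1 \dots s_k\, j_{k+1} \dots j_n}
```

The result therefore has rank (m - k) + (n - k). With k = 1 and two matrices this reduces to the ordinary matrix product, and with k = 0 it is the outer product.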

Params

  • a: Input array.
  • b: Input array.
  • axes: The dimensions to sum over. If an integer k is provided, sum over the last k dimensions of a and the first k dimensions of b. If a pair of axis lists is provided (one list for a, one for b), sum over the corresponding dimensions of a and b: the i-th axis in the first list is paired with the i-th axis in the second. Paired dimensions must have the same size.
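The two forms of `axes` can be illustrated with a plain-Rust reference sketch. It does not use mlx-rs at all: `Dense`, `tensordot_int`, and `tensordot_axes` are hypothetical names introduced here, data is stored as `f32` in row-major order, and only non-negative axis indices are handled (negative indices are not supported in this sketch). The axis-list form is reduced to the integer form by moving the contracted axes of a to the end and those of b to the front, then performing a single matrix product.

```rust
/// Hypothetical dense, row-major n-d array used only for this sketch
/// (not mlx-rs's `Array`).
#[derive(Debug, Clone, PartialEq)]
struct Dense {
    shape: Vec<usize>,
    data: Vec<f32>,
}

impl Dense {
    fn new(shape: Vec<usize>, data: Vec<f32>) -> Self {
        assert_eq!(shape.iter().product::<usize>(), data.len(), "shape/data mismatch");
        Dense { shape, data }
    }

    /// Row-major strides for the current shape.
    fn strides(&self) -> Vec<usize> {
        let mut strides = vec![1; self.shape.len()];
        for i in (0..self.shape.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * self.shape[i + 1];
        }
        strides
    }

    /// Copy with axes reordered so that new axis `i` is old axis `perm[i]`.
    /// `perm` is assumed to be a valid permutation of 0..rank.
    fn transpose(&self, perm: &[usize]) -> Dense {
        let old_strides = self.strides();
        let new_shape: Vec<usize> = perm.iter().map(|&p| self.shape[p]).collect();
        let total = self.data.len();
        let mut data = Vec::with_capacity(total);
        let mut idx = vec![0usize; new_shape.len()]; // multi-index into the new array
        for _ in 0..total {
            // Map the new multi-index back to a flat offset in the old data.
            let offset: usize = idx.iter().zip(perm).map(|(&i, &p)| i * old_strides[p]).sum();
            data.push(self.data[offset]);
            // Advance the multi-index in row-major order.
            for d in (0..idx.len()).rev() {
                idx[d] += 1;
                if idx[d] < new_shape[d] {
                    break;
                }
                idx[d] = 0;
            }
        }
        Dense::new(new_shape, data)
    }
}

/// Integer form: contract the last `k` axes of `a` with the first `k` axes of `b`.
fn tensordot_int(a: &Dense, b: &Dense, k: usize) -> Dense {
    let (ra, rb) = (a.shape.len(), b.shape.len());
    assert!(k <= ra && k <= rb, "k exceeds an input's rank");
    assert_eq!(&a.shape[ra - k..], &b.shape[..k], "contracted dimensions must match");
    // View `a` as an (m x s) matrix and `b` as an (s x n) matrix.
    let m: usize = a.shape[..ra - k].iter().product();
    let s: usize = a.shape[ra - k..].iter().product();
    let n: usize = b.shape[k..].iter().product();
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for t in 0..s {
            let av = a.data[i * s + t];
            for j in 0..n {
                out[i * n + j] += av * b.data[t * n + j];
            }
        }
    }
    // Result shape: free axes of `a` followed by free axes of `b`.
    let mut shape = a.shape[..ra - k].to_vec();
    shape.extend_from_slice(&b.shape[k..]);
    Dense::new(shape, out)
}

/// Axis-list form: contract axis `axes_a[i]` of `a` with axis `axes_b[i]` of `b`.
fn tensordot_axes(a: &Dense, b: &Dense, axes_a: &[usize], axes_b: &[usize]) -> Dense {
    assert_eq!(axes_a.len(), axes_b.len(), "axis lists must have equal length");
    // Move a's contracted axes to the end, in the order given...
    let mut perm_a: Vec<usize> = (0..a.shape.len()).filter(|d| !axes_a.contains(d)).collect();
    perm_a.extend_from_slice(axes_a);
    // ...and b's contracted axes to the front, then reuse the integer form.
    let mut perm_b: Vec<usize> = axes_b.to_vec();
    perm_b.extend((0..b.shape.len()).filter(|d| !axes_b.contains(d)));
    tensordot_int(&a.transpose(&perm_a), &b.transpose(&perm_b), axes_a.len())
}

fn main() {
    // 2x3 and 3x2 with k = 1: an ordinary matrix product.
    let a = Dense::new(vec![2, 3], vec![1., 2., 3., 4., 5., 6.]);
    let b = Dense::new(vec![3, 2], vec![7., 8., 9., 10., 11., 12.]);
    let c = tensordot_int(&a, &b, 1);
    println!("{:?} {:?}", c.shape, c.data); // prints: [2, 2] [58.0, 64.0, 139.0, 154.0]

    // Same product via explicit axis pairs on b's transpose.
    let bt = b.transpose(&[1, 0]);
    assert_eq!(tensordot_axes(&a, &bt, &[1], &[1]), c);
}
```

Calling `tensordot_int(&a, &a, 2)` on the 2x3 matrix above contracts every axis and yields a rank-0 result holding the sum of squares of its entries.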