Function mlx_rs::ops::block_masked_mm_device

pub fn block_masked_mm_device<'mo, 'lhs, 'rhs>(
    a: impl AsRef<Array>,
    b: impl AsRef<Array>,
    block_size: impl Into<Option<i32>>,
    mask_out: impl Into<Option<&'mo Array>>,
    mask_lhs: impl Into<Option<&'lhs Array>>,
    mask_rhs: impl Into<Option<&'rhs Array>>,
    stream: impl AsRef<Stream>,
) -> Result<Array>

Matrix multiplication with block masking, evaluated on the given stream.

See the Python API documentation (mlx.core.block_masked_mm) for more information.
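The following plain-Rust reference sketch illustrates what block masking computes. It does not call mlx-rs. It assumes the semantics described in MLX's Python documentation as we understand them: `a` is `M x K`, `b` is `K x N`, and each mask is a row-major boolean grid over `block_size x block_size` tiles. The output mask has `ceil(M/bs) x ceil(N/bs)` tiles, the left mask `ceil(M/bs) x ceil(K/bs)`, and the right mask `ceil(K/bs) x ceil(N/bs)`. A `false` entry removes that tile's contribution. The names `block_masked_matmul_ref` and `keep_tile` and the flat `Vec<f32>` layout are hypothetical. Batching, float-valued masks and the real kernel's tiling are not modeled.

```rust
/// Hypothetical helper: is tile (row, col) of a row-major mask grid with
/// `cols` columns kept? A missing mask (None) keeps every tile.
fn keep_tile(mask: Option<&[bool]>, row: usize, col: usize, cols: usize) -> bool {
    mask.map_or(true, |grid| grid[row * cols + col])
}

/// Hypothetical reference for block-masked matmul on row-major `m x k` and
/// `k x n` matrices; masks are boolean grids over `bs x bs` tiles.
fn block_masked_matmul_ref(
    a: &[f32],
    b: &[f32],
    (m, k, n): (usize, usize, usize),
    bs: usize,
    mask_out: Option<&[bool]>,
    mask_lhs: Option<&[bool]>,
    mask_rhs: Option<&[bool]>,
) -> Vec<f32> {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    // Number of tiles along K and N (ceiling division).
    let kb = (k + bs - 1) / bs;
    let nb = (n + bs - 1) / bs;
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            // A masked-out output tile is left at zero.
            if !keep_tile(mask_out, i / bs, j / bs, nb) {
                continue;
            }
            let mut acc = 0.0f32;
            for p in 0..k {
                // Skip products whose lhs tile (i, p) or rhs tile (p, j) is masked out.
                if keep_tile(mask_lhs, i / bs, p / bs, kb)
                    && keep_tile(mask_rhs, p / bs, j / bs, nb)
                {
                    acc += a[i * k + p] * b[p * n + j];
                }
            }
            out[i * n + j] = acc;
        }
    }
    out
}
```

In the real operator these three grids would correspond to mask_out, mask_lhs and mask_rhs. The sketch only mirrors their shapes and effect, not how MLX schedules the work on the device.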