Function mlx_rs::ops::conv3d

pub fn conv3d(
    array: impl AsRef<Array>,
    weight: impl AsRef<Array>,
    stride: impl Into<Option<(i32, i32, i32)>>,
    padding: impl Into<Option<(i32, i32, i32)>>,
    dilation: impl Into<Option<(i32, i32, i32)>>,
    groups: impl Into<Option<i32>>,
) -> Result<Array>
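The stride, padding and dilation arguments control the spatial size of the output through the usual convolution arithmetic. The sketch below spells that arithmetic out in plain Rust. It assumes that padding is added on both sides of each axis and that a None argument falls back to stride 1, padding 0 and dilation 1. These defaults are assumptions, not read from the mlx-rs source. conv_out_len and conv3d_out_shape are hypothetical helpers and are not part of mlx-rs.

```rust
/// Output length along one spatial axis (standard convolution arithmetic),
/// assuming `padding` zeros are added on both sides of the axis.
fn conv_out_len(input: i32, kernel: i32, stride: i32, padding: i32, dilation: i32) -> i32 {
    assert!(kernel >= 1 && stride >= 1 && dilation >= 1 && padding >= 0);
    // Dilation spreads the kernel taps apart, enlarging its footprint.
    let effective_kernel = dilation * (kernel - 1) + 1;
    let padded = input + 2 * padding;
    if padded < effective_kernel {
        return 0; // the kernel does not fit anywhere
    }
    (padded - effective_kernel) / stride + 1
}

/// Hypothetical helper: (D_out, H_out, W_out) for a 3D convolution.
/// `None` falls back to the assumed defaults: stride 1, padding 0, dilation 1.
fn conv3d_out_shape(
    input: (i32, i32, i32),
    kernel: (i32, i32, i32),
    stride: Option<(i32, i32, i32)>,
    padding: Option<(i32, i32, i32)>,
    dilation: Option<(i32, i32, i32)>,
) -> (i32, i32, i32) {
    let (sd, sh, sw) = stride.unwrap_or((1, 1, 1));
    let (pd, ph, pw) = padding.unwrap_or((0, 0, 0));
    let (dd, dh, dw) = dilation.unwrap_or((1, 1, 1));
    (
        conv_out_len(input.0, kernel.0, sd, pd, dd),
        conv_out_len(input.1, kernel.1, sh, ph, dh),
        conv_out_len(input.2, kernel.2, sw, pw, dw),
    )
}

fn main() {
    // 8x8x8 volume, 3x3x3 kernel, all defaults: each axis shrinks by 2.
    println!("{:?}", conv3d_out_shape((8, 8, 8), (3, 3, 3), None, None, None)); // (6, 6, 6)
    // Padding 1 with stride 2 halves each axis.
    println!("{:?}", conv3d_out_shape((8, 8, 8), (3, 3, 3), Some((2, 2, 2)), Some((1, 1, 1)), None)); // (4, 4, 4)
    // Dilation 2 turns the 3-tap kernel into a 5-wide footprint.
    println!("{:?}", conv3d_out_shape((8, 8, 8), (3, 3, 3), None, None, Some((2, 2, 2)))); // (4, 4, 4)
}
```

Each optional parameter of conv3d is typed impl Into<Option<...>>, so a caller can pass either None or a bare tuple such as (2, 2, 2).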

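Applies a 3D convolution over an input composed of several input channels. The input is expected in channels-last layout with shape (N, D, H, W, C_in), and the weight with shape (C_out, KD, KH, KW, C_in).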

Only the default groups=1 is currently supported.
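The sketch below is a naive plain-Rust reference that makes the computation concrete. It respects the groups = 1 restriction, so every output channel reads every input channel. It assumes MLX's channels-last layouts, with input (N, D, H, W, C_in) and weight (C_out, KD, KH, KW, C_in). It also assumes the cross-correlation convention (no kernel flip) that deep-learning frameworks usually call convolution. conv3d_ref is a hypothetical helper, not part of mlx-rs and not how mlx-rs computes the result. To keep it short, it supports a stride but fixes padding at 0 and dilation at 1.

```rust
/// Naive reference 3D convolution in plain Rust (illustration only).
/// Assumed channels-last layouts:
///   input:  (N, D, H, W, C_in)    weight: (C_out, KD, KH, KW, C_in)
/// groups = 1, configurable stride, padding fixed at 0, dilation fixed at 1.
/// Computes a cross-correlation: the kernel is not flipped.
fn conv3d_ref(
    input: &[f32],
    in_shape: [usize; 5],
    weight: &[f32],
    w_shape: [usize; 5],
    stride: (usize, usize, usize),
) -> (Vec<f32>, [usize; 5]) {
    let [n, d, h, w, c_in] = in_shape;
    let [c_out, kd, kh, kw, w_c_in] = w_shape;
    let (sd, sh, sw) = stride;
    // groups = 1: input and weight must agree on the number of input channels.
    assert_eq!(c_in, w_c_in, "input and weight disagree on C_in");
    assert_eq!(input.len(), n * d * h * w * c_in);
    assert_eq!(weight.len(), c_out * kd * kh * kw * c_in);
    assert!(kd <= d && kh <= h && kw <= w, "kernel larger than input");
    assert!(sd > 0 && sh > 0 && sw > 0, "stride must be positive");

    let (od, oh, ow) = ((d - kd) / sd + 1, (h - kh) / sh + 1, (w - kw) / sw + 1);
    let mut out = vec![0.0f32; n * od * oh * ow * c_out];

    for b in 0..n {
        for z in 0..od {
            for y in 0..oh {
                for x in 0..ow {
                    for co in 0..c_out {
                        // Dot product of kernel `co` with the window starting at (z, y, x).
                        let mut acc = 0.0f32;
                        for i in 0..kd {
                            for j in 0..kh {
                                for k in 0..kw {
                                    let (iz, iy, ix) = (z * sd + i, y * sh + j, x * sw + k);
                                    for ci in 0..c_in {
                                        let in_idx = (((b * d + iz) * h + iy) * w + ix) * c_in + ci;
                                        let w_idx = (((co * kd + i) * kh + j) * kw + k) * c_in + ci;
                                        acc += input[in_idx] * weight[w_idx];
                                    }
                                }
                            }
                        }
                        // Output is channels-last too: (N, D_out, H_out, W_out, C_out).
                        out[(((b * od + z) * oh + y) * ow + x) * c_out + co] = acc;
                    }
                }
            }
        }
    }
    (out, [n, od, oh, ow, c_out])
}

fn main() {
    // Input: one 4x4x4 volume with a single channel, all ones.
    let input = vec![1.0f32; 64];
    // Weight: two 3x3x3 kernels; output channel 0 is all ones, channel 1 all twos.
    let mut weight = vec![1.0f32; 27];
    weight.extend(vec![2.0f32; 27]);
    let (out, shape) = conv3d_ref(&input, [1, 4, 4, 4, 1], &weight, [2, 3, 3, 3, 1], (1, 1, 1));
    println!("{:?}", shape); // [1, 2, 2, 2, 2]
    println!("{:?}", &out[..4]); // [27.0, 54.0, 27.0, 54.0]
}
```

The nested loops favor readability over speed. Because the layout is channels-last, the two output channels interleave in the flat buffer.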