mlx_rs::ops

Function conv_transpose2d_device

pub fn conv_transpose2d_device(
    array: impl AsRef<Array>,
    weight: impl AsRef<Array>,
    stride: impl Into<Option<(i32, i32)>>,
    padding: impl Into<Option<(i32, i32)>>,
    dilation: impl Into<Option<(i32, i32)>>,
    groups: impl Into<Option<i32>>,
    stream: impl AsRef<Stream>,
) -> Result<Array>

2D transposed convolution over an input with several channels.

Only the default groups=1 is currently supported.
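
As an illustration of what a transposed convolution computes, here is a minimal single-channel sketch in plain Rust. It does not use mlx_rs: the function name `conv_transpose2d_naive` is hypothetical, and it assumes the usual scatter-add definition with no padding and no dilation. The multi-channel operation would additionally sum over input channels, which this sketch omits.

```rust
/// Naive single-channel 2D transposed convolution (illustrative only).
/// Assumes zero padding, unit dilation, and non-empty rectangular inputs.
fn conv_transpose2d_naive(
    input: &[Vec<f32>],
    kernel: &[Vec<f32>],
    stride: (usize, usize),
) -> Vec<Vec<f32>> {
    let (h, w) = (input.len(), input[0].len());
    let (kh, kw) = (kernel.len(), kernel[0].len());

    // Each input pixel stamps a scaled copy of the kernel into the output,
    // so the output grows to (in - 1) * stride + kernel along each axis.
    let out_h = (h - 1) * stride.0 + kh;
    let out_w = (w - 1) * stride.1 + kw;
    let mut out = vec![vec![0.0f32; out_w]; out_h];

    for i in 0..h {
        for j in 0..w {
            for ki in 0..kh {
                for kj in 0..kw {
                    // Overlapping stamps accumulate.
                    out[i * stride.0 + ki][j * stride.1 + kj] += input[i][j] * kernel[ki][kj];
                }
            }
        }
    }
    out
}

fn main() {
    let input = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let kernel = vec![vec![1.0, 1.0], vec![1.0, 1.0]];
    let out = conv_transpose2d_naive(&input, &kernel, (2, 2));
    for row in &out {
        println!("{:?}", row);
    }
}
```

With a stride of (2, 2) and a 2×2 kernel the stamps do not overlap, so each input value is simply replicated into its own 2×2 block; with a stride of (1, 1) neighbouring stamps overlap and add.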

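Per the signature, stride, padding, and dilation each accept an (i32, i32) pair or None, and groups accepts a single i32 or None. Passing None selects the defaults listed below.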

Params

  • array: input array of shape [N, H, W, C_in]
  • weight: weight array of shape [C_out, KH, KW, C_in], where KH and KW are the kernel height and width
  • stride: kernel stride. Defaults to (1, 1) if not specified.
  • padding: input padding. Defaults to (0, 0) if not specified.
  • dilation: kernel dilation. Defaults to (1, 1) if not specified.
  • groups: input feature groups. Defaults to 1 if not specified.
  • stream: stream or device to evaluate on.
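
The parameters above determine the output shape. The following plain-Rust sketch shows how, assuming the common transposed-convolution convention with no output padding: `out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1` along each spatial axis. The helper names `out_dim` and `conv_transpose2d_out_shape` are hypothetical and are not part of mlx_rs; the shape mlx_rs actually returns should be checked against the library.

```rust
/// Output extent along one spatial axis of a transposed convolution,
/// under the assumed convention (no output padding).
fn out_dim(input: i32, kernel: i32, stride: i32, padding: i32, dilation: i32) -> i32 {
    (input - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1
}

/// Expected output shape [N, H_out, W_out, C_out] for an input of shape
/// [N, H, W, C_in] and a weight of shape [C_out, KH, KW, C_in].
/// `None` falls back to the documented defaults.
fn conv_transpose2d_out_shape(
    input: [i32; 4],
    weight: [i32; 4],
    stride: Option<(i32, i32)>,
    padding: Option<(i32, i32)>,
    dilation: Option<(i32, i32)>,
) -> [i32; 4] {
    let (sh, sw) = stride.unwrap_or((1, 1));
    let (ph, pw) = padding.unwrap_or((0, 0));
    let (dh, dw) = dilation.unwrap_or((1, 1));
    [
        input[0],
        out_dim(input[1], weight[1], sh, ph, dh),
        out_dim(input[2], weight[2], sw, pw, dw),
        weight[0],
    ]
}

fn main() {
    let input = [1, 4, 4, 3]; // [N, H, W, C_in]
    let weight = [8, 3, 3, 3]; // [C_out, KH, KW, C_in]

    // All defaults: stride (1, 1), padding (0, 0), dilation (1, 1).
    println!("{:?}", conv_transpose2d_out_shape(input, weight, None, None, None));

    // Stride 2 with padding 1 roughly doubles the spatial size.
    println!(
        "{:?}",
        conv_transpose2d_out_shape(input, weight, Some((2, 2)), Some((1, 1)), None)
    );
}
```

Under these assumptions the first call yields [1, 6, 6, 8] and the second [1, 7, 7, 8].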