mlx_rs

Function stop_gradient_device

pub fn stop_gradient_device(
    a: impl AsRef<Array>,
    stream: impl AsRef<Stream>,
) -> Result<Array>

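Stops gradients from being computed through `a`, running the operation on `stream`.

The operation is the identity on its input: the returned array holds the same values as `a`, but gradients do not flow back through it. When a function is differentiated, anything computed from the result is treated as a constant with respect to `a`.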
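To illustrate what "identity in the forward pass, no gradient in the backward pass" means, here is a toy sketch in plain Rust. It does not use `mlx_rs` at all. It models scalars as forward-mode dual numbers (value plus derivative), which is an assumption made for brevity since MLX's own transforms work differently. The `Dual` type and the `stop_gradient`, `square`, and `square_with_stop` functions are hypothetical and exist only for this example.

```rust
use std::ops::Mul;

/// Hypothetical forward-mode dual number: a value paired with its
/// derivative with respect to one chosen input.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Dual {
    pub val: f64,
    pub grad: f64,
}

impl Dual {
    /// Mark `val` as the input we differentiate against (d x / d x = 1).
    pub fn variable(val: f64) -> Self {
        Dual { val, grad: 1.0 }
    }
}

impl Mul for Dual {
    type Output = Dual;

    // Product rule: (a * b)' = a' * b + a * b'
    fn mul(self, rhs: Dual) -> Dual {
        Dual {
            val: self.val * rhs.val,
            grad: self.grad * rhs.val + self.val * rhs.grad,
        }
    }
}

/// Toy analogue of stop_gradient: the value passes through unchanged,
/// but the derivative is zeroed, so nothing flows back along this path.
pub fn stop_gradient(x: Dual) -> Dual {
    Dual { val: x.val, grad: 0.0 }
}

/// f(x) = x * x, differentiated normally: f'(x) = 2x.
pub fn square(x: Dual) -> Dual {
    x * x
}

/// f(x) = x * stop_gradient(x): same value as x * x, but only the left
/// factor contributes to the derivative, so f'(x) = x.
pub fn square_with_stop(x: Dual) -> Dual {
    x * stop_gradient(x)
}
```

At `x = 3`, both functions evaluate to `9`, but `square` has derivative `6` while `square_with_stop` has derivative `3`: the stopped factor behaves like the constant `3`. In `mlx_rs` the same reasoning applies element-wise to the array returned by `stop_gradient_device`.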