mlx_rs::fast

Function rms_norm_device

pub fn rms_norm_device(
    x: impl AsRef<Array>,
    weight: impl AsRef<Array>,
    eps: f32,
    stream: impl AsRef<Stream>,
) -> Result<Array>

Root Mean Square normalization (RMS norm).

The normalization is with respect to the last axis of the input x.

Params

  • x: input array
  • weight: multiplicative weight that scales the result; it should be one-dimensional, with the same size as the last axis of x
  • eps: small additive constant for numerical stability
  • stream: stream or device to evaluate on
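
To make the parameters concrete, here is a minimal plain-Rust sketch of the computation RMS norm conventionally performs. It is not the mlx-rs implementation and does not call mlx-rs. The function name `rms_norm_reference`, the flat row-major slice layout, and the exact placement of `eps` (added to the mean of squares before the square root) are assumptions made for illustration only.

```rust
/// Reference RMS norm over the last axis of a flat, row-major buffer.
///
/// Assumed formula (illustrative, not taken from the mlx-rs source):
///     y[i] = x[i] / sqrt(mean(x^2) + eps) * weight[i]
/// where the mean runs over the last axis, whose size is `weight.len()`.
fn rms_norm_reference(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let dim = weight.len();
    assert!(dim > 0, "weight must be non-empty");
    assert!(
        x.len() % dim == 0,
        "x length must be a multiple of the last-axis size"
    );

    let mut out = Vec::with_capacity(x.len());
    // Each chunk of `dim` elements is one slice along the last axis.
    for row in x.chunks(dim) {
        let mean_sq = row.iter().map(|v| v * v).sum::<f32>() / dim as f32;
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        // Normalize, then scale element-wise by the weight.
        out.extend(row.iter().zip(weight).map(|(v, w)| v * inv_rms * w));
    }
    out
}
```

Under these assumptions, a row `[3.0, 4.0]` with unit weights and `eps = 0.0` has a mean of squares of 12.5, so each element is divided by roughly 3.5355. Each row is normalized independently, which is what "with respect to the last axis" means here.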