pub fn rms_norm(
    x: impl AsRef<Array>,
    weight: impl AsRef<Array>,
    eps: f32,
) -> Result<Array>
Root Mean Square normalization (RMS norm).
The normalization is with respect to the last axis of the input x.
§Params
- x: input array
- weight: A multiplicative weight to scale the result by. The weight should be one-dimensional with the same size as the last axis of x.
- eps: A small additive constant for numerical stability
- stream: stream or device to evaluate on
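
The parameters above can be illustrated with a plain-Rust reference sketch. This is not the library's implementation and does not use Array. It assumes the commonly used RMS norm definition, y = x / sqrt(mean(x^2) + eps) * weight, applied independently to each slice along the last axis. The names rms_norm_ref and last_dim are hypothetical and exist only for this example.

```rust
/// Reference RMS norm over the last axis of a row-major buffer.
/// Hypothetical helper for illustration; assumes the definition
/// y = x / sqrt(mean(x^2) + eps) * weight.
fn rms_norm_ref(x: &[f32], last_dim: usize, weight: &[f32], eps: f32) -> Vec<f32> {
    // weight must be one-dimensional with the size of the last axis of x.
    assert_eq!(weight.len(), last_dim);
    assert!(last_dim > 0 && x.len() % last_dim == 0);

    let mut out = Vec::with_capacity(x.len());
    for row in x.chunks(last_dim) {
        // Mean of squares along the last axis.
        let mean_sq = row.iter().map(|v| v * v).sum::<f32>() / last_dim as f32;
        // eps keeps the denominator away from zero for all-zero rows.
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        // Normalize each element, then scale by the matching weight entry.
        out.extend(row.iter().zip(weight).map(|(v, w)| v * inv_rms * w));
    }
    out
}
```

Calling rms_norm_ref(&[3.0, 4.0, 1.0, 1.0], 2, &[1.0, 2.0], 0.0) treats the input as two rows of length 2 and normalizes each row separately, so the second row [1.0, 1.0] (whose RMS is already 1) is only scaled by the weight.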