mlx_rs::optimizers

Function clip_grad_norm

pub fn clip_grad_norm(
    gradients: &FlattenedModuleParam,
    max_norm: f32,
) -> Result<(MaybeClippedGrads<'_>, f32)>

Clips the global norm of the gradients.

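This function ensures that the global norm of the gradients does not exceed max_norm. The global norm is the L2 norm taken over every element of every gradient array together. If it is greater than max_norm, all gradients are scaled down by the same factor, so their relative magnitudes are preserved and the resulting global norm is at most max_norm.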
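To make the scaling rule concrete, here is a minimal, std-only sketch of global-norm clipping. It does not use mlx-rs. The `Grads` type is a hypothetical stand-in for `FlattenedModuleParam` that maps parameter names to flat `Vec<f32>` values, and `clip_grad_norm_sketch` is a hypothetical name. The sketch also assumes that the `f32` in the real return type is the global norm measured before clipping, and that the scale factor is exactly `max_norm / global_norm`. The real function returns `MaybeClippedGrads<'_>`, whose lifetime suggests it can borrow unchanged gradients instead of copying them. This sketch simply clones them.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for `FlattenedModuleParam`:
/// parameter name -> flat gradient values.
type Grads = HashMap<String, Vec<f32>>;

/// Global L2 norm: square every element of every gradient,
/// sum them all, then take the square root.
fn global_norm(grads: &Grads) -> f32 {
    grads
        .values()
        .flat_map(|g| g.iter())
        .map(|x| x * x)
        .sum::<f32>()
        .sqrt()
}

/// Sketch of global-norm clipping (assumption: not the mlx-rs implementation).
/// Returns the possibly rescaled gradients and the norm measured before clipping.
fn clip_grad_norm_sketch(grads: &Grads, max_norm: f32) -> (Grads, f32) {
    let total_norm = global_norm(grads);
    if total_norm <= max_norm {
        // Already within the bound: gradients are returned unchanged.
        return (grads.clone(), total_norm);
    }
    // One shared factor for every gradient keeps their relative magnitudes.
    let scale = max_norm / total_norm;
    let clipped: Grads = grads
        .iter()
        .map(|(name, g)| (name.clone(), g.iter().map(|x| x * scale).collect::<Vec<f32>>()))
        .collect();
    (clipped, total_norm)
}

fn main() {
    let mut grads = Grads::new();
    grads.insert("w".to_string(), vec![3.0, 0.0]);
    grads.insert("b".to_string(), vec![4.0]);

    // Global norm here is sqrt(3^2 + 0^2 + 4^2) = 5, so max_norm = 1.0 forces clipping.
    let (clipped, norm) = clip_grad_norm_sketch(&grads, 1.0);
    println!("norm before clipping = {norm}");
    println!("norm after clipping = {}", global_norm(&clipped));
    println!("w = {:?}, b = {:?}", clipped["w"], clipped["b"]);
}
```

Applying one shared factor, rather than clipping each gradient on its own, keeps the update direction the same and only shortens its length.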