pub fn clip_grad_norm(
    gradients: &FlattenedModuleParam,
    max_norm: f32,
) -> Result<(MaybeClippedGrads<'_>, f32)>
Clips the global norm of the gradients.
This function ensures that the global norm of the gradients does not exceed max_norm. It scales down the gradients proportionally if their norm is greater than max_norm.
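
The scaling rule described above can be illustrated with a minimal, library-independent sketch. It is not the crate's implementation: clip_grad_norm_sketch is a hypothetical name, plain Vec<f32> buffers stand in for FlattenedModuleParam, the norm is assumed to be the L2 norm taken over all gradient elements together, and the returned f32 is assumed to be the norm measured before clipping.

```rust
/// Hypothetical stand-in for the documented function. It works on plain
/// vectors rather than `FlattenedModuleParam`, and always returns owned data.
fn clip_grad_norm_sketch(gradients: &[Vec<f32>], max_norm: f32) -> (Vec<Vec<f32>>, f32) {
    // Assumption: "global norm" means the L2 norm over every element of
    // every gradient, i.e. sqrt(sum of squares across all tensors).
    let total_norm = gradients
        .iter()
        .flat_map(|g| g.iter())
        .map(|x| x * x)
        .sum::<f32>()
        .sqrt();

    if total_norm > max_norm {
        // Scale every gradient by the same factor so the direction is kept
        // and the new global norm equals max_norm.
        let scale = max_norm / total_norm;
        let clipped = gradients
            .iter()
            .map(|g| g.iter().map(|x| x * scale).collect::<Vec<f32>>())
            .collect::<Vec<Vec<f32>>>();
        (clipped, total_norm)
    } else {
        // Norm is already within bounds: return the gradients unchanged.
        (gradients.to_vec(), total_norm)
    }
}
```

For example, gradients [3.0] and [4.0] have a global norm of 5.0. With max_norm = 1.0 every element is multiplied by 1.0 / 5.0, giving approximately [0.6] and [0.8]; with max_norm = 10.0 the gradients are returned as they were. The name MaybeClippedGrads suggests the real function can avoid copying in the second case, which this sketch does not attempt.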