mlx_rs/module/
param.rs

1use std::{
2    collections::HashMap,
3    ops::{Deref, DerefMut},
4    rc::Rc,
5};
6
7use crate::{nested::NestedValue, Array};
8
9use super::ModuleParameters;
10
11/// Trait for a module parameter.
12pub trait Parameter {
13    /// Freeze the parameter.
14    fn freeze(&mut self, recursive: bool);
15
16    /// Unfreeze the parameter.
17    fn unfreeze(&mut self, recursive: bool);
18
19    /// Check if the parameter is frozen. Returns `None` if the parameter is a module that has no
20    /// parameters.
21    fn is_frozen(&self) -> Option<bool>;
22
23    /// Get the parameter as a nested value.
24    fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array>;
25
26    /// Get the parameter as a mutable nested value.
27    fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array>;
28
29    /// Get the parameter as a nested value if it is trainable.
30    fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>>;
31}
32
/// Thin wrapper that pairs a parameter value with its frozen flag.
#[derive(Clone, Debug)]
pub struct Param<T> {
    /// The wrapped parameter value.
    pub value: T,

    /// `true` while the parameter is frozen.
    ///
    /// Kept private on purpose: read and change it through the `Parameter` trait.
    is_frozen: bool,
}
44
45impl<T> Param<T> {
46    /// Create a new `Param`
47    pub fn new(value: T) -> Self {
48        Self {
49            value,
50            is_frozen: false,
51        }
52    }
53}
54
55impl<T> From<T> for Param<T> {
56    fn from(inner: T) -> Self {
57        Self::new(inner)
58    }
59}
60
61impl<T> Deref for Param<T> {
62    type Target = T;
63
64    fn deref(&self) -> &Self::Target {
65        &self.value
66    }
67}
68
69impl<T> DerefMut for Param<T> {
70    fn deref_mut(&mut self) -> &mut Self::Target {
71        &mut self.value
72    }
73}
74
75impl<T> AsRef<T> for Param<T> {
76    fn as_ref(&self) -> &T {
77        &self.value
78    }
79}
80
81impl<T> AsMut<T> for Param<T> {
82    fn as_mut(&mut self) -> &mut T {
83        &mut self.value
84    }
85}
86
87impl Parameter for Param<Array> {
88    fn freeze(&mut self, _recursive: bool) {
89        self.is_frozen = true;
90    }
91
92    fn unfreeze(&mut self, _recursive: bool) {
93        self.is_frozen = false;
94    }
95
96    fn is_frozen(&self) -> Option<bool> {
97        Some(self.is_frozen)
98    }
99
100    fn as_nested_value<'a>(&self) -> NestedValue<Rc<str>, &Array> {
101        NestedValue::Value(&self.value)
102    }
103
104    fn as_nested_value_mut<'a>(&mut self) -> NestedValue<Rc<str>, &mut Array> {
105        NestedValue::Value(&mut self.value)
106    }
107
108    fn as_trainable_nested_value<'a>(&self) -> Option<NestedValue<Rc<str>, &Array>> {
109        match self.is_frozen {
110            true => None,
111            false => Some(NestedValue::Value(&self.value)),
112        }
113    }
114}
115
116impl Parameter for Param<Option<Array>> {
117    fn freeze(&mut self, _recursive: bool) {
118        self.is_frozen = true;
119    }
120
121    fn unfreeze(&mut self, _recursive: bool) {
122        self.is_frozen = false;
123    }
124
125    fn is_frozen(&self) -> Option<bool> {
126        Some(self.is_frozen)
127    }
128
129    fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array> {
130        match &self.value {
131            Some(array) => NestedValue::Value(array),
132            // An empty map entry will be ignored during flattening
133            None => NestedValue::Map(HashMap::with_capacity(0)),
134        }
135    }
136
137    fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array> {
138        match &mut self.value {
139            Some(array) => NestedValue::Value(array),
140            // An empty map entry will be ignored during flattening
141            None => NestedValue::Map(HashMap::with_capacity(0)),
142        }
143    }
144
145    fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>> {
146        match self.is_frozen {
147            true => None,
148            false => self.value.as_ref().map(NestedValue::Value),
149        }
150    }
151}
152
153impl<T> Parameter for T
154where
155    T: ModuleParameters,
156{
157    fn freeze(&mut self, recursive: bool) {
158        self.freeze_parameters(recursive);
159    }
160
161    fn unfreeze(&mut self, recursive: bool) {
162        self.unfreeze_parameters(recursive);
163    }
164
165    fn is_frozen(&self) -> Option<bool> {
166        self.all_frozen()
167    }
168
169    fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array> {
170        self.parameters().into()
171    }
172
173    fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array> {
174        self.parameters_mut().into()
175    }
176
177    fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>> {
178        Some(self.trainable_parameters().into())
179    }
180}