mlx_rs/module/param.rs

1use std::{
2    collections::HashMap,
3    ops::{Deref, DerefMut},
4    rc::Rc,
5};
6
7use crate::{nested::NestedValue, Array};
8
9use super::ModuleParameters;
10
11/// Trait for a module parameter.
12pub trait Parameter {
13    /// Total number of parameters in this module/parameter.
14    fn count(&self) -> usize;
15
16    /// Freeze the parameter.
17    fn freeze(&mut self, recursive: bool);
18
19    /// Unfreeze the parameter.
20    fn unfreeze(&mut self, recursive: bool);
21
22    /// Check if the parameter is frozen. Returns `None` if the parameter is a module that has no
23    /// parameters.
24    fn is_frozen(&self) -> Option<bool>;
25
26    /// Get the parameter as a nested value.
27    fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array>;
28
29    /// Get the parameter as a mutable nested value.
30    fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array>;
31
32    /// Get the parameter as a nested value if it is trainable.
33    fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>>;
34}
35
/// Thin wrapper pairing a module parameter value with a frozen flag.
///
/// The wrapper dereferences to the wrapped value, so a `Param<T>` can be used where a `T`
/// is expected.
#[derive(Clone, Debug)]
pub struct Param<T> {
    /// The wrapped parameter value.
    pub value: T,

    /// Frozen flag.
    ///
    /// Private on purpose: code outside this module reads and changes it through the
    /// [`Parameter`] trait (`freeze`, `unfreeze`, `is_frozen`).
    is_frozen: bool,
}
47
48impl<T> Param<T> {
49    /// Create a new `Param`
50    pub fn new(value: T) -> Self {
51        Self {
52            value,
53            is_frozen: false,
54        }
55    }
56}
57
58impl<T> From<T> for Param<T> {
59    fn from(inner: T) -> Self {
60        Self::new(inner)
61    }
62}
63
64impl<T> Deref for Param<T> {
65    type Target = T;
66
67    fn deref(&self) -> &Self::Target {
68        &self.value
69    }
70}
71
72impl<T> DerefMut for Param<T> {
73    fn deref_mut(&mut self) -> &mut Self::Target {
74        &mut self.value
75    }
76}
77
78impl<T> AsRef<T> for Param<T> {
79    fn as_ref(&self) -> &T {
80        &self.value
81    }
82}
83
84impl<T> AsMut<T> for Param<T> {
85    fn as_mut(&mut self) -> &mut T {
86        &mut self.value
87    }
88}
89
90impl Parameter for Param<Array> {
91    fn count(&self) -> usize {
92        1
93    }
94
95    fn freeze(&mut self, _recursive: bool) {
96        self.is_frozen = true;
97    }
98
99    fn unfreeze(&mut self, _recursive: bool) {
100        self.is_frozen = false;
101    }
102
103    fn is_frozen(&self) -> Option<bool> {
104        Some(self.is_frozen)
105    }
106
107    fn as_nested_value<'a>(&self) -> NestedValue<Rc<str>, &Array> {
108        NestedValue::Value(&self.value)
109    }
110
111    fn as_nested_value_mut<'a>(&mut self) -> NestedValue<Rc<str>, &mut Array> {
112        NestedValue::Value(&mut self.value)
113    }
114
115    fn as_trainable_nested_value<'a>(&self) -> Option<NestedValue<Rc<str>, &Array>> {
116        match self.is_frozen {
117            true => None,
118            false => Some(NestedValue::Value(&self.value)),
119        }
120    }
121}
122
123impl Parameter for Param<Option<Array>> {
124    fn count(&self) -> usize {
125        self.value.as_ref().map_or(0, |_| 1)
126    }
127
128    fn freeze(&mut self, _recursive: bool) {
129        self.is_frozen = true;
130    }
131
132    fn unfreeze(&mut self, _recursive: bool) {
133        self.is_frozen = false;
134    }
135
136    fn is_frozen(&self) -> Option<bool> {
137        Some(self.is_frozen)
138    }
139
140    fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array> {
141        match &self.value {
142            Some(array) => NestedValue::Value(array),
143            // An empty map entry will be ignored during flattening
144            None => NestedValue::Map(HashMap::with_capacity(0)),
145        }
146    }
147
148    fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array> {
149        match &mut self.value {
150            Some(array) => NestedValue::Value(array),
151            // An empty map entry will be ignored during flattening
152            None => NestedValue::Map(HashMap::with_capacity(0)),
153        }
154    }
155
156    fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>> {
157        match self.is_frozen {
158            true => None,
159            false => self.value.as_ref().map(NestedValue::Value),
160        }
161    }
162}
163
164impl<T> Parameter for T
165where
166    T: ModuleParameters,
167{
168    fn count(&self) -> usize {
169        self.num_parameters()
170    }
171
172    fn freeze(&mut self, recursive: bool) {
173        self.freeze_parameters(recursive);
174    }
175
176    fn unfreeze(&mut self, recursive: bool) {
177        self.unfreeze_parameters(recursive);
178    }
179
180    fn is_frozen(&self) -> Option<bool> {
181        self.all_frozen()
182    }
183
184    fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array> {
185        self.parameters().into()
186    }
187
188    fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array> {
189        self.parameters_mut().into()
190    }
191
192    fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>> {
193        Some(self.trainable_parameters().into())
194    }
195}