mlx_rs/
nested.rs

1//! Implements a nested hashmap
2
3use std::{collections::HashMap, fmt::Display, rc::Rc};
4
/// Separator placed between path segments when nested keys are flattened.
const DELIMITER: char = '.';

/// A nested value that can be either a value or a map of nested values
///
/// `K` is the key type of the nested maps and `T` is the type of the leaves.
#[derive(Debug, Clone)]
pub enum NestedValue<K, T> {
    /// A value
    Value(T),

    /// A map of nested values
    Map(HashMap<K, NestedValue<K, T>>),
}

impl<K, V> NestedValue<K, V> {
    /// Flattens the nested value into a hashmap
    ///
    /// Every leaf is stored under the path that leads to it: `prefix` followed
    /// by the keys of the enclosing maps, joined with `.`.
    ///
    /// - A [`NestedValue::Value`] is stored under `prefix` itself.
    /// - A [`NestedValue::Map`] contributes one entry per leaf it (transitively)
    ///   contains; a map without leaves contributes nothing.
    /// - An empty `prefix` means "no prefix": the keys of a map are then used
    ///   as-is, without a leading delimiter.
    ///
    /// If two leaves end up with the same flattened key, only one of them is
    /// kept; which one depends on the iteration order of the maps.
    pub fn flatten(self, prefix: &str) -> HashMap<Rc<str>, V>
    where
        K: Display,
    {
        // Collect every leaf into one map instead of building (and merging)
        // a temporary map per leaf and per nesting level.
        let mut flattened = HashMap::new();
        self.flatten_into(prefix, &mut flattened);
        flattened
    }

    /// Recursively inserts every leaf reachable from `self` into `out`,
    /// keyed by its delimiter-joined path starting at `prefix`.
    fn flatten_into(self, prefix: &str, out: &mut HashMap<Rc<str>, V>)
    where
        K: Display,
    {
        match self {
            NestedValue::Value(value) => {
                out.insert(Rc::from(prefix), value);
            }
            NestedValue::Map(entries) => {
                for (key, value) in entries {
                    // Skip the delimiter when there is nothing to join onto,
                    // so an empty prefix does not yield keys such as ".a".
                    let path = if prefix.is_empty() {
                        key.to_string()
                    } else {
                        format!("{}{}{}", prefix, DELIMITER, key)
                    };
                    value.flatten_into(&path, out);
                }
            }
        }
    }
}
36
37/// A nested hashmap
38#[derive(Debug, Clone)]
39pub struct NestedHashMap<K, V> {
40    /// The internal hashmap
41    pub entries: HashMap<K, NestedValue<K, V>>,
42}
43
44impl<K, V> From<NestedHashMap<K, V>> for NestedValue<K, V> {
45    fn from(map: NestedHashMap<K, V>) -> Self {
46        NestedValue::Map(map.entries)
47    }
48}
49
50impl<K, V> Default for NestedHashMap<K, V> {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56impl<K, V> NestedHashMap<K, V> {
57    /// Creates a new nested hashmap
58    pub fn new() -> Self {
59        Self {
60            entries: HashMap::new(),
61        }
62    }
63
64    /// Inserts a new entry into the nested hashmap
65    pub fn insert(&mut self, key: K, value: NestedValue<K, V>)
66    where
67        K: Eq + std::hash::Hash,
68    {
69        self.entries.insert(key, value);
70    }
71
72    /// Flattens the nested hashmap into a hashmap
73    pub fn flatten(self) -> HashMap<Rc<str>, V>
74    where
75        K: AsRef<str> + Display,
76    {
77        self.entries
78            .into_iter()
79            .flat_map(|(key, value)| value.flatten(key.as_ref()))
80            .collect()
81    }
82}
83
#[cfg(test)]
mod tests {
    use crate::array;

    use super::*;

    /// Builds the fixture shared by the flatten tests:
    /// `{ "first": first, "second": { "a": second_a, "b": second_b } }`.
    fn nested_fixture<V>(first: V, second_a: V, second_b: V) -> NestedHashMap<&'static str, V> {
        let second = HashMap::from([
            ("a", NestedValue::Value(second_a)),
            ("b", NestedValue::Value(second_b)),
        ]);

        NestedHashMap {
            entries: HashMap::from([
                ("first", NestedValue::Value(first)),
                ("second", NestedValue::Map(second)),
            ]),
        }
    }

    #[test]
    fn test_flatten_nested_hash_map_of_owned_arrays() {
        // Leaves are owned arrays.
        let flattened =
            nested_fixture(array!([1, 2, 3]), array!([4, 5, 6]), array!([7, 8, 9])).flatten();

        assert_eq!(flattened.len(), 3);
        assert_eq!(flattened["first"], array!([1, 2, 3]));
        assert_eq!(flattened["second.a"], array!([4, 5, 6]));
        assert_eq!(flattened["second.b"], array!([7, 8, 9]));
    }

    #[test]
    fn test_flatten_nested_hash_map_of_borrowed_arrays() {
        // Leaves are shared references to arrays owned by the test.
        let first_entry_content = array!([1, 2, 3]);
        let second_entry_content_a = array!([4, 5, 6]);
        let second_entry_content_b = array!([7, 8, 9]);

        let flattened = nested_fixture(
            &first_entry_content,
            &second_entry_content_a,
            &second_entry_content_b,
        )
        .flatten();

        assert_eq!(flattened.len(), 3);
        assert_eq!(flattened["first"], &first_entry_content);
        assert_eq!(flattened["second.a"], &second_entry_content_a);
        assert_eq!(flattened["second.b"], &second_entry_content_b);
    }

    #[test]
    fn test_flatten_nested_hash_map_of_mut_borrowed_arrays() {
        // Leaves are exclusive references; the expected values are rebuilt
        // for the assertions because the originals stay mutably borrowed.
        let mut first_entry_content = array!([1, 2, 3]);
        let mut second_entry_content_a = array!([4, 5, 6]);
        let mut second_entry_content_b = array!([7, 8, 9]);

        let flattened = nested_fixture(
            &mut first_entry_content,
            &mut second_entry_content_a,
            &mut second_entry_content_b,
        )
        .flatten();

        assert_eq!(flattened.len(), 3);
        assert_eq!(flattened["first"], &mut array!([1, 2, 3]));
        assert_eq!(flattened["second.a"], &mut array!([4, 5, 6]));
        assert_eq!(flattened["second.b"], &mut array!([7, 8, 9]));
    }

    #[test]
    fn test_flatten_empty_nested_hash_map() {
        // A map without entries flattens to nothing.
        let flattened = NestedHashMap::<&str, i32>::new().flatten();
        assert!(flattened.is_empty());

        // Insert another empty map: a nested map without leaves adds no keys.
        let mut map = NestedHashMap::<&str, i32>::new();
        map.insert("empty", NestedValue::Map(HashMap::new()));

        let flattened = map.flatten();
        assert!(flattened.is_empty());
    }
}