// mlx_rs/ops/sort.rs

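//! Implements bindings for the sorting ops: `sort`, `argsort`, `partition` and `argpartition`,
//! each with an `_all` variant that operates on the flattened array.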
2
3use mlx_internal_macros::{default_device, generate_macro};
4
5use crate::{error::Result, utils::guard::Guarded, Array, Stream, StreamOrDevice};
6
7/// Returns a sorted copy of the array. Returns an error if the arguments are invalid.
8///
9/// # Params
10///
11/// - `array`: input array
12/// - `axis`: axis to sort over
13///
14/// # Example
15///
16/// ```rust
17/// use mlx_rs::{Array, ops::*};
18///
19/// let a = Array::from_slice(&[3, 2, 1], &[3]);
20/// let axis = 0;
21/// let result = sort(&a, axis);
22/// ```
23#[generate_macro]
24#[default_device]
25pub fn sort_device(
26    a: impl AsRef<Array>,
27    axis: i32,
28    #[optional] stream: impl AsRef<Stream>,
29) -> Result<Array> {
30    Array::try_from_op(|res| unsafe {
31        mlx_sys::mlx_sort(res, a.as_ref().as_ptr(), axis, stream.as_ref().as_ptr())
32    })
33}
34
35/// Returns a sorted copy of the flattened array. Returns an error if the arguments are invalid.
36///
37/// # Params
38///
39/// - `array`: input array
40///
41/// # Example
42///
43/// ```rust
44/// use mlx_rs::{Array, ops::*};
45///
46/// let a = Array::from_slice(&[3, 2, 1], &[3]);
47/// let result = sort_all(&a);
48/// ```
49#[generate_macro]
50#[default_device]
51pub fn sort_all_device(
52    a: impl AsRef<Array>,
53    #[optional] stream: impl AsRef<Stream>,
54) -> Result<Array> {
55    Array::try_from_op(|res| unsafe {
56        mlx_sys::mlx_sort_all(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
57    })
58}
59
60/// Returns the indices that sort the array. Returns an error if the arguments are invalid.
61///
62/// # Params
63///
64/// - `a`: The array to sort.
65/// - `axis`: axis to sort over
66///
67/// # Example
68///
69/// ```rust
70/// use mlx_rs::{Array, ops::*};
71///
72/// let a = Array::from_slice(&[3, 2, 1], &[3]);
73/// let axis = 0;
74/// let result = argsort(&a, axis);
75/// ```
76#[generate_macro]
77#[default_device]
78pub fn argsort_device(
79    a: impl AsRef<Array>,
80    axis: i32,
81    #[optional] stream: impl AsRef<Stream>,
82) -> Result<Array> {
83    Array::try_from_op(|res| unsafe {
84        mlx_sys::mlx_argsort(res, a.as_ref().as_ptr(), axis, stream.as_ref().as_ptr())
85    })
86}
87
88/// Returns the indices that sort the flattened array. Returns an error if the arguments are
89/// invalid.
90///
91/// # Params
92///
93/// - `a`: The array to sort.
94///
95/// # Example
96///
97/// ```rust
98/// use mlx_rs::{Array, ops::*};
99///
100/// let a = Array::from_slice(&[3, 2, 1], &[3]);
101/// let result = argsort_all(&a);
102/// ```
103#[generate_macro]
104#[default_device]
105pub fn argsort_all_device(
106    a: impl AsRef<Array>,
107    #[optional] stream: impl AsRef<Stream>,
108) -> Result<Array> {
109    Array::try_from_op(|res| unsafe {
110        mlx_sys::mlx_argsort_all(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
111    })
112}
113
114/// Returns a partitioned copy of the array such that the smaller `kth` elements are first.
115/// Returns an error if the arguments are invalid.
116///
117/// The ordering of the elements in partitions is undefined.
118///
119/// # Params
120///
121/// - `array`: input array
122/// - `kth`: Element at the `kth` index will be in its sorted position in the output. All elements
123///   before the kth index will be less or equal to the `kth` element and all elements after will be
124///   greater or equal to the `kth` element in the output.
125/// - `axis`: axis to partition over
126///
127/// # Example
128///
129/// ```rust
130/// use mlx_rs::{Array, ops::*};
131///
132/// let a = Array::from_slice(&[3, 2, 1], &[3]);
133/// let kth = 1;
134/// let axis = 0;
135/// let result = partition(&a, kth, axis);
136/// ```
137#[generate_macro]
138#[default_device]
139pub fn partition_device(
140    a: impl AsRef<Array>,
141    kth: i32,
142    axis: i32,
143    #[optional] stream: impl AsRef<Stream>,
144) -> Result<Array> {
145    Array::try_from_op(|res| unsafe {
146        mlx_sys::mlx_partition(
147            res,
148            a.as_ref().as_ptr(),
149            kth,
150            axis,
151            stream.as_ref().as_ptr(),
152        )
153    })
154}
155
156/// Returns a partitioned copy of the flattened array such that the smaller `kth` elements are
157/// first. Returns an error if the arguments are invalid.
158///
159/// The ordering of the elements in partitions is undefined.
160///
161/// # Params
162///
163/// - `array`: input array
164/// - `kth`: Element at the `kth` index will be in its sorted position in the output. All elements
165///   before the kth index will be less or equal to the `kth` element and all elements after will be
166///   greater or equal to the `kth` element in the output.
167///
168/// # Example
169///
170/// ```rust
171/// use mlx_rs::{Array, ops::*};
172///
173/// let a = Array::from_slice(&[3, 2, 1], &[3]);
174/// let kth = 1;
175/// let result = partition_all(&a, kth);
176/// ```
177#[generate_macro]
178#[default_device]
179pub fn partition_all_device(
180    a: impl AsRef<Array>,
181    kth: i32,
182    #[optional] stream: impl AsRef<Stream>,
183) -> Result<Array> {
184    Array::try_from_op(|res| unsafe {
185        mlx_sys::mlx_partition_all(res, a.as_ref().as_ptr(), kth, stream.as_ref().as_ptr())
186    })
187}
188
189/// Returns the indices that partition the array. Returns an error if the arguments are invalid.
190///
191/// The ordering of the elements within a partition in given by the indices is undefined.
192///
193/// # Params
194///
195/// - `a`: The array to sort.
196/// - `kth`: element index at the `kth` position in the output will give the sorted position.  All
197///   indices before the`kth` position will be of elements less than or equal to the element at the
198///   `kth` index and all indices after will be elemenents greater than or equal to the element at
199///   the `kth` position.
200/// - `axis`: axis to partition over
201///
202/// # Example
203///
204/// ```rust
205/// use mlx_rs::{Array, ops::*};
206///
207/// let a = Array::from_slice(&[3, 2, 1], &[3]);
208/// let kth = 1;
209/// let axis = 0;
210/// let result = argpartition(&a, kth, axis);
211/// ```
212#[generate_macro]
213#[default_device]
214pub fn argpartition_device(
215    a: impl AsRef<Array>,
216    kth: i32,
217    axis: i32,
218    #[optional] stream: impl AsRef<Stream>,
219) -> Result<Array> {
220    Array::try_from_op(|res| unsafe {
221        mlx_sys::mlx_argpartition(
222            res,
223            a.as_ref().as_ptr(),
224            kth,
225            axis,
226            stream.as_ref().as_ptr(),
227        )
228    })
229}
230
231/// Returns the indices that partition the flattened array. Returns an error if the arguments are
232/// invalid.
233///
234/// The ordering of the elements within a partition in given by the indices is undefined.
235///
236/// # Params
237///
238/// - `a`: The array to sort.
239/// - `kth`: element index at the `kth` position in the output will give the sorted position.  All
240///   indices before the`kth` position will be of elements less than or equal to the element at the
241///   `kth` index and all indices after will be elemenents greater than or equal to the element at
242///   the `kth` position.
243///
244/// # Example
245///
246/// ```rust
247/// use mlx_rs::{Array, ops::*};
248///
249/// let a = Array::from_slice(&[3, 2, 1], &[3]);
250/// let kth = 1;
251/// let result = argpartition_all(&a, kth);
252/// ```
253#[generate_macro]
254#[default_device]
255pub fn argpartition_all_device(
256    a: impl AsRef<Array>,
257    kth: i32,
258    #[optional] stream: impl AsRef<Stream>,
259) -> Result<Array> {
260    Array::try_from_op(|res| unsafe {
261        mlx_sys::mlx_argpartition_all(res, a.as_ref().as_ptr(), kth, stream.as_ref().as_ptr())
262    })
263}
264
#[cfg(test)]
mod tests {
    use crate::Array;

    /// Builds the 1-D, five-element `[1, 2, 3, 4, 5]` fixture used by every test in this module.
    fn one_to_five() -> Array {
        Array::from_slice(&[1, 2, 3, 4, 5], &[5])
    }

    #[test]
    fn test_sort_with_invalid_axis() {
        // The fixture is 1-D, so axis 1 does not exist.
        let outcome = super::sort(&one_to_five(), 1);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_partition_with_invalid_axis() {
        // `kth` is valid here; only the axis is out of range.
        let outcome = super::partition(&one_to_five(), 2, 1);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_partition_with_invalid_kth() {
        // Index 5 is one past the end of a length-5 axis.
        let outcome = super::partition(&one_to_five(), 5, 0);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_partition_all_with_invalid_kth() {
        // The flattened fixture still has only 5 elements.
        let outcome = super::partition_all(&one_to_five(), 5);
        assert!(outcome.is_err());
    }
}