mlx_rs/ops/
sort.rs

1//! Implements bindings for the sorting ops.
2
3use mlx_internal_macros::{default_device, generate_macro};
4
5use crate::{error::Result, utils::guard::Guarded, Array, Stream};
6
7/// Returns a sorted copy of the array. Returns an error if the arguments are invalid.
8///
9/// # Params
10///
11/// - `array`: input array
12/// - `axis`: axis to sort over
13///
14/// # Example
15///
16/// ```rust
17/// use mlx_rs::{Array, ops::*};
18///
19/// let a = Array::from_slice(&[3, 2, 1], &[3]);
20/// let axis = 0;
21/// let result = sort_axis(&a, axis);
22/// ```
23#[generate_macro]
24#[default_device]
25pub fn sort_axis_device(
26    a: impl AsRef<Array>,
27    axis: i32,
28    #[optional] stream: impl AsRef<Stream>,
29) -> Result<Array> {
30    Array::try_from_op(|res| unsafe {
31        mlx_sys::mlx_sort_axis(res, a.as_ref().as_ptr(), axis, stream.as_ref().as_ptr())
32    })
33}
34
35/// Returns a sorted copy of the flattened array. Returns an error if the arguments are invalid.
36///
37/// # Params
38///
39/// - `array`: input array
40///
41/// # Example
42///
43/// ```rust
44/// use mlx_rs::{Array, ops::*};
45///
46/// let a = Array::from_slice(&[3, 2, 1], &[3]);
47/// let result = sort(&a);
48/// ```
49#[generate_macro]
50#[default_device]
51pub fn sort_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
52    Array::try_from_op(|res| unsafe {
53        mlx_sys::mlx_sort(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
54    })
55}
56
57/// Returns the indices that sort the array. Returns an error if the arguments are invalid.
58///
59/// # Params
60///
61/// - `a`: The array to sort.
62/// - `axis`: axis to sort over
63///
64/// # Example
65///
66/// ```rust
67/// use mlx_rs::{Array, ops::*};
68///
69/// let a = Array::from_slice(&[3, 2, 1], &[3]);
70/// let axis = 0;
71/// let result = argsort_axis(&a, axis);
72/// ```
73#[generate_macro]
74#[default_device]
75pub fn argsort_axis_device(
76    a: impl AsRef<Array>,
77    axis: i32,
78    #[optional] stream: impl AsRef<Stream>,
79) -> Result<Array> {
80    Array::try_from_op(|res| unsafe {
81        mlx_sys::mlx_argsort_axis(res, a.as_ref().as_ptr(), axis, stream.as_ref().as_ptr())
82    })
83}
84
85/// Returns the indices that sort the flattened array. Returns an error if the arguments are
86/// invalid.
87///
88/// # Params
89///
90/// - `a`: The array to sort.
91///
92/// # Example
93///
94/// ```rust
95/// use mlx_rs::{Array, ops::*};
96///
97/// let a = Array::from_slice(&[3, 2, 1], &[3]);
98/// let result = argsort(&a);
99/// ```
100#[generate_macro]
101#[default_device]
102pub fn argsort_device(
103    a: impl AsRef<Array>,
104    #[optional] stream: impl AsRef<Stream>,
105) -> Result<Array> {
106    Array::try_from_op(|res| unsafe {
107        mlx_sys::mlx_argsort(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
108    })
109}
110
111/// Returns a partitioned copy of the array such that the smaller `kth` elements are first.
112/// Returns an error if the arguments are invalid.
113///
114/// The ordering of the elements in partitions is undefined.
115///
116/// # Params
117///
118/// - `array`: input array
119/// - `kth`: Element at the `kth` index will be in its sorted position in the output. All elements
120///   before the kth index will be less or equal to the `kth` element and all elements after will be
121///   greater or equal to the `kth` element in the output.
122/// - `axis`: axis to partition over
123///
124/// # Example
125///
126/// ```rust
127/// use mlx_rs::{Array, ops::*};
128///
129/// let a = Array::from_slice(&[3, 2, 1], &[3]);
130/// let kth = 1;
131/// let axis = 0;
132/// let result = partition_axis(&a, kth, axis);
133/// ```
134#[generate_macro]
135#[default_device]
136pub fn partition_axis_device(
137    a: impl AsRef<Array>,
138    kth: i32,
139    axis: i32,
140    #[optional] stream: impl AsRef<Stream>,
141) -> Result<Array> {
142    Array::try_from_op(|res| unsafe {
143        mlx_sys::mlx_partition_axis(
144            res,
145            a.as_ref().as_ptr(),
146            kth,
147            axis,
148            stream.as_ref().as_ptr(),
149        )
150    })
151}
152
153/// Returns a partitioned copy of the flattened array such that the smaller `kth` elements are
154/// first. Returns an error if the arguments are invalid.
155///
156/// The ordering of the elements in partitions is undefined.
157///
158/// # Params
159///
160/// - `array`: input array
161/// - `kth`: Element at the `kth` index will be in its sorted position in the output. All elements
162///   before the kth index will be less or equal to the `kth` element and all elements after will be
163///   greater or equal to the `kth` element in the output.
164///
165/// # Example
166///
167/// ```rust
168/// use mlx_rs::{Array, ops::*};
169///
170/// let a = Array::from_slice(&[3, 2, 1], &[3]);
171/// let kth = 1;
172/// let result = partition(&a, kth);
173/// ```
174#[generate_macro]
175#[default_device]
176pub fn partition_device(
177    a: impl AsRef<Array>,
178    kth: i32,
179    #[optional] stream: impl AsRef<Stream>,
180) -> Result<Array> {
181    Array::try_from_op(|res| unsafe {
182        mlx_sys::mlx_partition(res, a.as_ref().as_ptr(), kth, stream.as_ref().as_ptr())
183    })
184}
185
186/// Returns the indices that partition the array. Returns an error if the arguments are invalid.
187///
188/// The ordering of the elements within a partition in given by the indices is undefined.
189///
190/// # Params
191///
192/// - `a`: The array to sort.
193/// - `kth`: element index at the `kth` position in the output will give the sorted position.  All
194///   indices before the`kth` position will be of elements less than or equal to the element at the
195///   `kth` index and all indices after will be elemenents greater than or equal to the element at
196///   the `kth` position.
197/// - `axis`: axis to partition over
198///
199/// # Example
200///
201/// ```rust
202/// use mlx_rs::{Array, ops::*};
203///
204/// let a = Array::from_slice(&[3, 2, 1], &[3]);
205/// let kth = 1;
206/// let axis = 0;
207/// let result = argpartition_axis(&a, kth, axis);
208/// ```
209#[generate_macro]
210#[default_device]
211pub fn argpartition_axis_device(
212    a: impl AsRef<Array>,
213    kth: i32,
214    axis: i32,
215    #[optional] stream: impl AsRef<Stream>,
216) -> Result<Array> {
217    Array::try_from_op(|res| unsafe {
218        mlx_sys::mlx_argpartition_axis(
219            res,
220            a.as_ref().as_ptr(),
221            kth,
222            axis,
223            stream.as_ref().as_ptr(),
224        )
225    })
226}
227
228/// Returns the indices that partition the flattened array. Returns an error if the arguments are
229/// invalid.
230///
231/// The ordering of the elements within a partition in given by the indices is undefined.
232///
233/// # Params
234///
235/// - `a`: The array to sort.
236/// - `kth`: element index at the `kth` position in the output will give the sorted position.  All
237///   indices before the`kth` position will be of elements less than or equal to the element at the
238///   `kth` index and all indices after will be elemenents greater than or equal to the element at
239///   the `kth` position.
240///
241/// # Example
242///
243/// ```rust
244/// use mlx_rs::{Array, ops::*};
245///
246/// let a = Array::from_slice(&[3, 2, 1], &[3]);
247/// let kth = 1;
248/// let result = argpartition(&a, kth);
249/// ```
250#[generate_macro]
251#[default_device]
252pub fn argpartition_device(
253    a: impl AsRef<Array>,
254    kth: i32,
255    #[optional] stream: impl AsRef<Stream>,
256) -> Result<Array> {
257    Array::try_from_op(|res| unsafe {
258        mlx_sys::mlx_argpartition(res, a.as_ref().as_ptr(), kth, stream.as_ref().as_ptr())
259    })
260}
261
#[cfg(test)]
mod tests {
    use crate::Array;

    /// Valid arguments must be accepted, otherwise the `is_err` checks below would pass vacuously.
    #[test]
    fn test_valid_arguments_are_accepted() {
        let a = Array::from_slice(&[1, 2, 3, 4, 5], &[5]);
        assert!(super::sort_axis(&a, 0).is_ok());
        assert!(super::sort(&a).is_ok());
        assert!(super::argsort_axis(&a, 0).is_ok());
        assert!(super::argsort(&a).is_ok());
        assert!(super::partition_axis(&a, 2, 0).is_ok());
        assert!(super::partition(&a, 2).is_ok());
        assert!(super::argpartition_axis(&a, 2, 0).is_ok());
        assert!(super::argpartition(&a, 2).is_ok());
    }

    #[test]
    fn test_sort_with_invalid_axis() {
        let a = Array::from_slice(&[1, 2, 3, 4, 5], &[5]);
        let axis = 1;
        let result = super::sort_axis(&a, axis);
        assert!(result.is_err());
    }

    #[test]
    fn test_argsort_with_invalid_axis() {
        let a = Array::from_slice(&[1, 2, 3, 4, 5], &[5]);
        let axis = 1;
        let result = super::argsort_axis(&a, axis);
        assert!(result.is_err());
    }

    #[test]
    fn test_partition_with_invalid_axis() {
        let a = Array::from_slice(&[1, 2, 3, 4, 5], &[5]);
        let kth = 2;
        let axis = 1;
        let result = super::partition_axis(&a, kth, axis);
        assert!(result.is_err());
    }

    #[test]
    fn test_partition_with_invalid_kth() {
        let a = Array::from_slice(&[1, 2, 3, 4, 5], &[5]);
        let kth = 5;
        let axis = 0;
        let result = super::partition_axis(&a, kth, axis);
        assert!(result.is_err());
    }

    #[test]
    fn test_partition_all_with_invalid_kth() {
        let a = Array::from_slice(&[1, 2, 3, 4, 5], &[5]);
        let kth = 5;
        let result = super::partition(&a, kth);
        assert!(result.is_err());
    }

    #[test]
    fn test_argpartition_with_invalid_axis() {
        let a = Array::from_slice(&[1, 2, 3, 4, 5], &[5]);
        let kth = 2;
        let axis = 1;
        let result = super::argpartition_axis(&a, kth, axis);
        assert!(result.is_err());
    }

    #[test]
    fn test_argpartition_with_invalid_kth() {
        let a = Array::from_slice(&[1, 2, 3, 4, 5], &[5]);
        let kth = 5;
        let axis = 0;
        let result = super::argpartition_axis(&a, kth, axis);
        assert!(result.is_err());
    }

    #[test]
    fn test_argpartition_all_with_invalid_kth() {
        let a = Array::from_slice(&[1, 2, 3, 4, 5], &[5]);
        let kth = 5;
        let result = super::argpartition(&a, kth);
        assert!(result.is_err());
    }
}