mlx_rs/ops/sort.rs
1//! Implements bindings for the sorting ops.
2
3use mlx_internal_macros::{default_device, generate_macro};
4
5use crate::{error::Result, utils::guard::Guarded, Array, Stream};
6
7/// Returns a sorted copy of the array. Returns an error if the arguments are invalid.
8///
9/// # Params
10///
11/// - `array`: input array
12/// - `axis`: axis to sort over
13///
14/// # Example
15///
16/// ```rust
17/// use mlx_rs::{Array, ops::*};
18///
19/// let a = Array::from_slice(&[3, 2, 1], &[3]);
20/// let axis = 0;
21/// let result = sort_axis(&a, axis);
22/// ```
23#[generate_macro]
24#[default_device]
25pub fn sort_axis_device(
26 a: impl AsRef<Array>,
27 axis: i32,
28 #[optional] stream: impl AsRef<Stream>,
29) -> Result<Array> {
30 Array::try_from_op(|res| unsafe {
31 mlx_sys::mlx_sort_axis(res, a.as_ref().as_ptr(), axis, stream.as_ref().as_ptr())
32 })
33}
34
35/// Returns a sorted copy of the flattened array. Returns an error if the arguments are invalid.
36///
37/// # Params
38///
39/// - `array`: input array
40///
41/// # Example
42///
43/// ```rust
44/// use mlx_rs::{Array, ops::*};
45///
46/// let a = Array::from_slice(&[3, 2, 1], &[3]);
47/// let result = sort(&a);
48/// ```
49#[generate_macro]
50#[default_device]
51pub fn sort_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
52 Array::try_from_op(|res| unsafe {
53 mlx_sys::mlx_sort(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
54 })
55}
56
57/// Returns the indices that sort the array. Returns an error if the arguments are invalid.
58///
59/// # Params
60///
61/// - `a`: The array to sort.
62/// - `axis`: axis to sort over
63///
64/// # Example
65///
66/// ```rust
67/// use mlx_rs::{Array, ops::*};
68///
69/// let a = Array::from_slice(&[3, 2, 1], &[3]);
70/// let axis = 0;
71/// let result = argsort_axis(&a, axis);
72/// ```
73#[generate_macro]
74#[default_device]
75pub fn argsort_axis_device(
76 a: impl AsRef<Array>,
77 axis: i32,
78 #[optional] stream: impl AsRef<Stream>,
79) -> Result<Array> {
80 Array::try_from_op(|res| unsafe {
81 mlx_sys::mlx_argsort_axis(res, a.as_ref().as_ptr(), axis, stream.as_ref().as_ptr())
82 })
83}
84
85/// Returns the indices that sort the flattened array. Returns an error if the arguments are
86/// invalid.
87///
88/// # Params
89///
90/// - `a`: The array to sort.
91///
92/// # Example
93///
94/// ```rust
95/// use mlx_rs::{Array, ops::*};
96///
97/// let a = Array::from_slice(&[3, 2, 1], &[3]);
98/// let result = argsort(&a);
99/// ```
100#[generate_macro]
101#[default_device]
102pub fn argsort_device(
103 a: impl AsRef<Array>,
104 #[optional] stream: impl AsRef<Stream>,
105) -> Result<Array> {
106 Array::try_from_op(|res| unsafe {
107 mlx_sys::mlx_argsort(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
108 })
109}
110
111/// Returns a partitioned copy of the array such that the smaller `kth` elements are first.
112/// Returns an error if the arguments are invalid.
113///
114/// The ordering of the elements in partitions is undefined.
115///
116/// # Params
117///
118/// - `array`: input array
119/// - `kth`: Element at the `kth` index will be in its sorted position in the output. All elements
120/// before the kth index will be less or equal to the `kth` element and all elements after will be
121/// greater or equal to the `kth` element in the output.
122/// - `axis`: axis to partition over
123///
124/// # Example
125///
126/// ```rust
127/// use mlx_rs::{Array, ops::*};
128///
129/// let a = Array::from_slice(&[3, 2, 1], &[3]);
130/// let kth = 1;
131/// let axis = 0;
132/// let result = partition_axis(&a, kth, axis);
133/// ```
134#[generate_macro]
135#[default_device]
136pub fn partition_axis_device(
137 a: impl AsRef<Array>,
138 kth: i32,
139 axis: i32,
140 #[optional] stream: impl AsRef<Stream>,
141) -> Result<Array> {
142 Array::try_from_op(|res| unsafe {
143 mlx_sys::mlx_partition_axis(
144 res,
145 a.as_ref().as_ptr(),
146 kth,
147 axis,
148 stream.as_ref().as_ptr(),
149 )
150 })
151}
152
153/// Returns a partitioned copy of the flattened array such that the smaller `kth` elements are
154/// first. Returns an error if the arguments are invalid.
155///
156/// The ordering of the elements in partitions is undefined.
157///
158/// # Params
159///
160/// - `array`: input array
161/// - `kth`: Element at the `kth` index will be in its sorted position in the output. All elements
162/// before the kth index will be less or equal to the `kth` element and all elements after will be
163/// greater or equal to the `kth` element in the output.
164///
165/// # Example
166///
167/// ```rust
168/// use mlx_rs::{Array, ops::*};
169///
170/// let a = Array::from_slice(&[3, 2, 1], &[3]);
171/// let kth = 1;
172/// let result = partition(&a, kth);
173/// ```
174#[generate_macro]
175#[default_device]
176pub fn partition_device(
177 a: impl AsRef<Array>,
178 kth: i32,
179 #[optional] stream: impl AsRef<Stream>,
180) -> Result<Array> {
181 Array::try_from_op(|res| unsafe {
182 mlx_sys::mlx_partition(res, a.as_ref().as_ptr(), kth, stream.as_ref().as_ptr())
183 })
184}
185
186/// Returns the indices that partition the array. Returns an error if the arguments are invalid.
187///
188/// The ordering of the elements within a partition in given by the indices is undefined.
189///
190/// # Params
191///
192/// - `a`: The array to sort.
193/// - `kth`: element index at the `kth` position in the output will give the sorted position. All
194/// indices before the`kth` position will be of elements less than or equal to the element at the
195/// `kth` index and all indices after will be elemenents greater than or equal to the element at
196/// the `kth` position.
197/// - `axis`: axis to partition over
198///
199/// # Example
200///
201/// ```rust
202/// use mlx_rs::{Array, ops::*};
203///
204/// let a = Array::from_slice(&[3, 2, 1], &[3]);
205/// let kth = 1;
206/// let axis = 0;
207/// let result = argpartition_axis(&a, kth, axis);
208/// ```
209#[generate_macro]
210#[default_device]
211pub fn argpartition_axis_device(
212 a: impl AsRef<Array>,
213 kth: i32,
214 axis: i32,
215 #[optional] stream: impl AsRef<Stream>,
216) -> Result<Array> {
217 Array::try_from_op(|res| unsafe {
218 mlx_sys::mlx_argpartition_axis(
219 res,
220 a.as_ref().as_ptr(),
221 kth,
222 axis,
223 stream.as_ref().as_ptr(),
224 )
225 })
226}
227
228/// Returns the indices that partition the flattened array. Returns an error if the arguments are
229/// invalid.
230///
231/// The ordering of the elements within a partition in given by the indices is undefined.
232///
233/// # Params
234///
235/// - `a`: The array to sort.
236/// - `kth`: element index at the `kth` position in the output will give the sorted position. All
237/// indices before the`kth` position will be of elements less than or equal to the element at the
238/// `kth` index and all indices after will be elemenents greater than or equal to the element at
239/// the `kth` position.
240///
241/// # Example
242///
243/// ```rust
244/// use mlx_rs::{Array, ops::*};
245///
246/// let a = Array::from_slice(&[3, 2, 1], &[3]);
247/// let kth = 1;
248/// let result = argpartition(&a, kth);
249/// ```
250#[generate_macro]
251#[default_device]
252pub fn argpartition_device(
253 a: impl AsRef<Array>,
254 kth: i32,
255 #[optional] stream: impl AsRef<Stream>,
256) -> Result<Array> {
257 Array::try_from_op(|res| unsafe {
258 mlx_sys::mlx_argpartition(res, a.as_ref().as_ptr(), kth, stream.as_ref().as_ptr())
259 })
260}
261
#[cfg(test)]
mod tests {
    use crate::Array;

    /// Builds the one-dimensional, five-element array shared by every test in this module.
    fn five_elements() -> Array {
        Array::from_slice(&[1, 2, 3, 4, 5], &[5])
    }

    #[test]
    fn test_sort_with_invalid_axis() {
        let input = five_elements();
        // Axis 1 does not exist on a one-dimensional array.
        assert!(super::sort_axis(&input, 1).is_err());
    }

    #[test]
    fn test_partition_with_invalid_axis() {
        let input = five_elements();
        // `kth` is valid here; only the axis is out of range.
        assert!(super::partition_axis(&input, 2, 1).is_err());
    }

    #[test]
    fn test_partition_with_invalid_kth() {
        let input = five_elements();
        // `kth` = 5 is one past the last index of a five-element axis.
        assert!(super::partition_axis(&input, 5, 0).is_err());
    }

    #[test]
    fn test_partition_all_with_invalid_kth() {
        let input = five_elements();
        // Same out-of-range `kth`, this time against the flattened array.
        assert!(super::partition(&input, 5).is_err());
    }
}
299}