mlx_rs/ops/sort.rs
1//! Implements bindings for the sorting ops.
2
3use mlx_internal_macros::{default_device, generate_macro};
4
5use crate::{error::Result, utils::guard::Guarded, Array, Stream, StreamOrDevice};
6
7/// Returns a sorted copy of the array. Returns an error if the arguments are invalid.
8///
9/// # Params
10///
11/// - `array`: input array
12/// - `axis`: axis to sort over
13///
14/// # Example
15///
16/// ```rust
17/// use mlx_rs::{Array, ops::*};
18///
19/// let a = Array::from_slice(&[3, 2, 1], &[3]);
20/// let axis = 0;
21/// let result = sort(&a, axis);
22/// ```
23#[generate_macro]
24#[default_device]
25pub fn sort_device(
26 a: impl AsRef<Array>,
27 axis: i32,
28 #[optional] stream: impl AsRef<Stream>,
29) -> Result<Array> {
30 Array::try_from_op(|res| unsafe {
31 mlx_sys::mlx_sort(res, a.as_ref().as_ptr(), axis, stream.as_ref().as_ptr())
32 })
33}
34
35/// Returns a sorted copy of the flattened array. Returns an error if the arguments are invalid.
36///
37/// # Params
38///
39/// - `array`: input array
40///
41/// # Example
42///
43/// ```rust
44/// use mlx_rs::{Array, ops::*};
45///
46/// let a = Array::from_slice(&[3, 2, 1], &[3]);
47/// let result = sort_all(&a);
48/// ```
49#[generate_macro]
50#[default_device]
51pub fn sort_all_device(
52 a: impl AsRef<Array>,
53 #[optional] stream: impl AsRef<Stream>,
54) -> Result<Array> {
55 Array::try_from_op(|res| unsafe {
56 mlx_sys::mlx_sort_all(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
57 })
58}
59
60/// Returns the indices that sort the array. Returns an error if the arguments are invalid.
61///
62/// # Params
63///
64/// - `a`: The array to sort.
65/// - `axis`: axis to sort over
66///
67/// # Example
68///
69/// ```rust
70/// use mlx_rs::{Array, ops::*};
71///
72/// let a = Array::from_slice(&[3, 2, 1], &[3]);
73/// let axis = 0;
74/// let result = argsort(&a, axis);
75/// ```
76#[generate_macro]
77#[default_device]
78pub fn argsort_device(
79 a: impl AsRef<Array>,
80 axis: i32,
81 #[optional] stream: impl AsRef<Stream>,
82) -> Result<Array> {
83 Array::try_from_op(|res| unsafe {
84 mlx_sys::mlx_argsort(res, a.as_ref().as_ptr(), axis, stream.as_ref().as_ptr())
85 })
86}
87
88/// Returns the indices that sort the flattened array. Returns an error if the arguments are
89/// invalid.
90///
91/// # Params
92///
93/// - `a`: The array to sort.
94///
95/// # Example
96///
97/// ```rust
98/// use mlx_rs::{Array, ops::*};
99///
100/// let a = Array::from_slice(&[3, 2, 1], &[3]);
101/// let result = argsort_all(&a);
102/// ```
103#[generate_macro]
104#[default_device]
105pub fn argsort_all_device(
106 a: impl AsRef<Array>,
107 #[optional] stream: impl AsRef<Stream>,
108) -> Result<Array> {
109 Array::try_from_op(|res| unsafe {
110 mlx_sys::mlx_argsort_all(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
111 })
112}
113
114/// Returns a partitioned copy of the array such that the smaller `kth` elements are first.
115/// Returns an error if the arguments are invalid.
116///
117/// The ordering of the elements in partitions is undefined.
118///
119/// # Params
120///
121/// - `array`: input array
122/// - `kth`: Element at the `kth` index will be in its sorted position in the output. All elements
123/// before the kth index will be less or equal to the `kth` element and all elements after will be
124/// greater or equal to the `kth` element in the output.
125/// - `axis`: axis to partition over
126///
127/// # Example
128///
129/// ```rust
130/// use mlx_rs::{Array, ops::*};
131///
132/// let a = Array::from_slice(&[3, 2, 1], &[3]);
133/// let kth = 1;
134/// let axis = 0;
135/// let result = partition(&a, kth, axis);
136/// ```
137#[generate_macro]
138#[default_device]
139pub fn partition_device(
140 a: impl AsRef<Array>,
141 kth: i32,
142 axis: i32,
143 #[optional] stream: impl AsRef<Stream>,
144) -> Result<Array> {
145 Array::try_from_op(|res| unsafe {
146 mlx_sys::mlx_partition(
147 res,
148 a.as_ref().as_ptr(),
149 kth,
150 axis,
151 stream.as_ref().as_ptr(),
152 )
153 })
154}
155
156/// Returns a partitioned copy of the flattened array such that the smaller `kth` elements are
157/// first. Returns an error if the arguments are invalid.
158///
159/// The ordering of the elements in partitions is undefined.
160///
161/// # Params
162///
163/// - `array`: input array
164/// - `kth`: Element at the `kth` index will be in its sorted position in the output. All elements
165/// before the kth index will be less or equal to the `kth` element and all elements after will be
166/// greater or equal to the `kth` element in the output.
167///
168/// # Example
169///
170/// ```rust
171/// use mlx_rs::{Array, ops::*};
172///
173/// let a = Array::from_slice(&[3, 2, 1], &[3]);
174/// let kth = 1;
175/// let result = partition_all(&a, kth);
176/// ```
177#[generate_macro]
178#[default_device]
179pub fn partition_all_device(
180 a: impl AsRef<Array>,
181 kth: i32,
182 #[optional] stream: impl AsRef<Stream>,
183) -> Result<Array> {
184 Array::try_from_op(|res| unsafe {
185 mlx_sys::mlx_partition_all(res, a.as_ref().as_ptr(), kth, stream.as_ref().as_ptr())
186 })
187}
188
189/// Returns the indices that partition the array. Returns an error if the arguments are invalid.
190///
191/// The ordering of the elements within a partition in given by the indices is undefined.
192///
193/// # Params
194///
195/// - `a`: The array to sort.
196/// - `kth`: element index at the `kth` position in the output will give the sorted position. All
197/// indices before the`kth` position will be of elements less than or equal to the element at the
198/// `kth` index and all indices after will be elemenents greater than or equal to the element at
199/// the `kth` position.
200/// - `axis`: axis to partition over
201///
202/// # Example
203///
204/// ```rust
205/// use mlx_rs::{Array, ops::*};
206///
207/// let a = Array::from_slice(&[3, 2, 1], &[3]);
208/// let kth = 1;
209/// let axis = 0;
210/// let result = argpartition(&a, kth, axis);
211/// ```
212#[generate_macro]
213#[default_device]
214pub fn argpartition_device(
215 a: impl AsRef<Array>,
216 kth: i32,
217 axis: i32,
218 #[optional] stream: impl AsRef<Stream>,
219) -> Result<Array> {
220 Array::try_from_op(|res| unsafe {
221 mlx_sys::mlx_argpartition(
222 res,
223 a.as_ref().as_ptr(),
224 kth,
225 axis,
226 stream.as_ref().as_ptr(),
227 )
228 })
229}
230
231/// Returns the indices that partition the flattened array. Returns an error if the arguments are
232/// invalid.
233///
234/// The ordering of the elements within a partition in given by the indices is undefined.
235///
236/// # Params
237///
238/// - `a`: The array to sort.
239/// - `kth`: element index at the `kth` position in the output will give the sorted position. All
240/// indices before the`kth` position will be of elements less than or equal to the element at the
241/// `kth` index and all indices after will be elemenents greater than or equal to the element at
242/// the `kth` position.
243///
244/// # Example
245///
246/// ```rust
247/// use mlx_rs::{Array, ops::*};
248///
249/// let a = Array::from_slice(&[3, 2, 1], &[3]);
250/// let kth = 1;
251/// let result = argpartition_all(&a, kth);
252/// ```
253#[generate_macro]
254#[default_device]
255pub fn argpartition_all_device(
256 a: impl AsRef<Array>,
257 kth: i32,
258 #[optional] stream: impl AsRef<Stream>,
259) -> Result<Array> {
260 Array::try_from_op(|res| unsafe {
261 mlx_sys::mlx_argpartition_all(res, a.as_ref().as_ptr(), kth, stream.as_ref().as_ptr())
262 })
263}
264
#[cfg(test)]
mod tests {
    use crate::Array;

    /// Builds the one-dimensional, five-element array used by every test in this module.
    fn five_elements() -> Array {
        Array::from_slice(&[1, 2, 3, 4, 5], &[5])
    }

    #[test]
    fn test_sort_with_invalid_axis() {
        // The fixture is 1-D, so axis 1 does not exist.
        let input = five_elements();
        assert!(super::sort(&input, 1).is_err());
    }

    #[test]
    fn test_partition_with_invalid_axis() {
        // Valid `kth`, but axis 1 does not exist on the 1-D fixture.
        let input = five_elements();
        assert!(super::partition(&input, 2, 1).is_err());
    }

    #[test]
    fn test_partition_with_invalid_kth() {
        // `kth` equals the length of the partitioned axis.
        let input = five_elements();
        assert!(super::partition(&input, 5, 0).is_err());
    }

    #[test]
    fn test_partition_all_with_invalid_kth() {
        // `kth` equals the total number of elements in the fixture.
        let input = five_elements();
        assert!(super::partition_all(&input, 5).is_err());
    }
}