use guard::Guarded;
use mlx_sys::mlx_vector_array;
use crate::error::set_closure_error;
use crate::module::ModuleParameters;
use crate::{complex64, error::Exception, Array, FromNested};
use std::collections::HashMap;
use std::{marker::PhantomData, rc::Rc};
pub(crate) const SUCCESS: i32 = 0;
pub(crate) const FAILURE: i32 = 1;
pub(crate) mod guard;
pub(crate) mod io;
pub(crate) fn resolve_index_signed_unchecked(index: i32, len: i32) -> i32 {
if index < 0 {
len.saturating_add(index)
} else {
index
}
}
pub(crate) fn resolve_index_unchecked(index: i32, len: usize) -> usize {
if index.is_negative() {
(len as i32 + index) as usize
} else {
index as usize
}
}
pub(crate) fn axes_or_default_to_all<'a>(axes: impl IntoOption<&'a [i32]>, ndim: i32) -> Vec<i32> {
match axes.into_option() {
Some(axes) => axes.to_vec(),
None => {
let axes: Vec<i32> = (0..ndim).collect();
axes
}
}
}
pub(crate) struct VectorArray {
c_vec: mlx_sys::mlx_vector_array,
}
impl VectorArray {
pub(crate) fn as_ptr(&self) -> mlx_sys::mlx_vector_array {
self.c_vec
}
pub(crate) fn try_from_iter(
iter: impl Iterator<Item = impl AsRef<Array>>,
) -> Result<Self, Exception> {
VectorArray::try_from_op(|res| unsafe {
let mut status = SUCCESS;
for arr in iter {
status = mlx_sys::mlx_vector_array_append_value(*res, arr.as_ref().as_ptr());
if status != SUCCESS {
return status;
}
}
status
})
}
pub(crate) fn try_into_values<T>(self) -> Result<T, Exception>
where
T: FromIterator<Array>,
{
unsafe {
let size = mlx_sys::mlx_vector_array_size(self.c_vec);
(0..size)
.map(|i| {
Array::try_from_op(|res| mlx_sys::mlx_vector_array_get(res, self.c_vec, i))
})
.collect::<Result<T, Exception>>()
}
}
}
impl Drop for VectorArray {
fn drop(&mut self) {
let status = unsafe { mlx_sys::mlx_vector_array_free(self.c_vec) };
debug_assert_eq!(status, SUCCESS);
}
}
pub trait IntoOption<T> {
fn into_option(self) -> Option<T>;
}
impl<T> IntoOption<T> for Option<T> {
fn into_option(self) -> Option<T> {
self
}
}
impl<T> IntoOption<T> for T {
fn into_option(self) -> Option<T> {
Some(self)
}
}
impl<'a, T, const N: usize> IntoOption<&'a [T]> for &'a [T; N] {
fn into_option(self) -> Option<&'a [T]> {
Some(self)
}
}
impl<'a, T> IntoOption<&'a [T]> for &'a Vec<T> {
fn into_option(self) -> Option<&'a [T]> {
Some(self)
}
}
pub trait ScalarOrArray<'a> {
type Array: AsRef<Array> + 'a;
fn into_owned_or_ref_array(self) -> Self::Array;
}
impl ScalarOrArray<'_> for Array {
type Array = Array;
fn into_owned_or_ref_array(self) -> Array {
self
}
}
impl<'a> ScalarOrArray<'a> for &'a Array {
type Array = &'a Array;
fn into_owned_or_ref_array(self) -> &'a Array {
self
}
}
impl ScalarOrArray<'static> for bool {
type Array = Array;
fn into_owned_or_ref_array(self) -> Array {
Array::from_bool(self)
}
}
impl ScalarOrArray<'static> for i32 {
type Array = Array;
fn into_owned_or_ref_array(self) -> Array {
Array::from_int(self)
}
}
impl ScalarOrArray<'static> for f32 {
type Array = Array;
fn into_owned_or_ref_array(self) -> Array {
Array::from_float(self)
}
}
impl ScalarOrArray<'static> for complex64 {
type Array = Array;
fn into_owned_or_ref_array(self) -> Array {
Array::from_complex(self)
}
}
impl<T> ScalarOrArray<'static> for T
where
Array: FromNested<T>,
{
type Array = Array;
fn into_owned_or_ref_array(self) -> Array {
Array::from_nested(self)
}
}
#[derive(Debug)]
pub(crate) struct Closure<'a> {
c_closure: mlx_sys::mlx_closure,
lt_marker: PhantomData<&'a ()>,
}
impl<'a> Closure<'a> {
pub(crate) fn as_ptr(&self) -> mlx_sys::mlx_closure {
self.c_closure
}
pub(crate) fn new<F>(closure: F) -> Self
where
F: FnMut(&[Array]) -> Vec<Array> + 'a,
{
let c_closure = new_mlx_closure(closure);
Self {
c_closure,
lt_marker: PhantomData,
}
}
pub(crate) fn new_fallible<F>(closure: F) -> Self
where
F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
{
let c_closure = new_mlx_fallible_closure(closure);
Self {
c_closure,
lt_marker: PhantomData,
}
}
}
impl Drop for Closure<'_> {
fn drop(&mut self) {
let status = unsafe { mlx_sys::mlx_closure_free(self.c_closure) };
debug_assert_eq!(status, SUCCESS);
}
}
fn new_mlx_closure<'a, F>(closure: F) -> mlx_sys::mlx_closure
where
F: FnMut(&[Array]) -> Vec<Array> + 'a,
{
let boxed = Box::new(closure);
let raw = Box::into_raw(boxed);
let payload = raw as *mut std::ffi::c_void;
unsafe {
mlx_sys::mlx_closure_new_func_payload(Some(trampoline::<F>), payload, Some(noop_dtor))
}
}
fn new_mlx_fallible_closure<'a, F>(closure: F) -> mlx_sys::mlx_closure
where
F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
{
let boxed = Box::new(closure);
let raw = Box::into_raw(boxed);
let payload = raw as *mut std::ffi::c_void;
unsafe {
mlx_sys::mlx_closure_new_func_payload(
Some(trampoline_fallible::<F>),
payload,
Some(noop_dtor),
)
}
}
fn new_mlx_vector_array(arrays: Vec<Array>) -> mlx_sys::mlx_vector_array {
unsafe {
let result = mlx_sys::mlx_vector_array_new();
let ctx_ptrs: Vec<mlx_sys::mlx_array> = arrays.iter().map(|array| array.as_ptr()).collect();
mlx_sys::mlx_vector_array_append_data(result, ctx_ptrs.as_ptr(), arrays.len());
result
}
}
fn mlx_vector_array_values(
vector_array: mlx_sys::mlx_vector_array,
) -> Result<Vec<Array>, Exception> {
unsafe {
let size = mlx_sys::mlx_vector_array_size(vector_array);
(0..size)
.map(|index| {
Array::try_from_op(|res| mlx_sys::mlx_vector_array_get(res, vector_array, index))
})
.collect()
}
}
extern "C" fn trampoline<'a, F>(
ret: *mut mlx_vector_array,
vector_array: mlx_vector_array,
payload: *mut std::ffi::c_void,
) -> i32
where
F: FnMut(&[Array]) -> Vec<Array> + 'a,
{
unsafe {
let raw_closure: *mut F = payload as *mut _;
let mut closure = Box::from_raw(raw_closure);
let arrays = match mlx_vector_array_values(vector_array) {
Ok(arrays) => arrays,
Err(_) => {
return FAILURE;
}
};
let result = closure(&arrays);
*ret = new_mlx_vector_array(result);
SUCCESS
}
}
extern "C" fn trampoline_fallible<'a, F>(
ret: *mut mlx_vector_array,
vector_array: mlx_vector_array,
payload: *mut std::ffi::c_void,
) -> i32
where
F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
{
unsafe {
let raw_closure: *mut F = payload as *mut _;
let mut closure = Box::from_raw(raw_closure);
let arrays = match mlx_vector_array_values(vector_array) {
Ok(arrays) => arrays,
Err(e) => {
set_closure_error(e);
return FAILURE;
}
};
let result = closure(&arrays);
match result {
Ok(result) => {
*ret = new_mlx_vector_array(result);
SUCCESS
}
Err(err) => {
set_closure_error(err);
FAILURE
}
}
}
}
extern "C" fn noop_dtor(_data: *mut std::ffi::c_void) {}
pub(crate) fn get_mut_or_insert_with<'a, T>(
map: &'a mut HashMap<Rc<str>, T>,
key: &Rc<str>,
f: impl FnOnce() -> T,
) -> &'a mut T {
if !map.contains_key(key) {
map.insert(key.clone(), f());
}
map.get_mut(key).unwrap()
}
pub trait Updatable {
fn updatable_states(&self) -> impl IntoIterator<Item = &Array>;
fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array>;
}
impl<T> Updatable for T
where
T: ModuleParameters,
{
fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
use itertools::Itertools;
self.parameters()
.flatten()
.into_iter()
.sorted_by(|a, b| a.0.cmp(&b.0))
.map(|(_, v)| v)
}
fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
use itertools::Itertools;
self.parameters_mut()
.flatten()
.into_iter()
.sorted_by(|a, b| a.0.cmp(&b.0))
.map(|(_, v)| v)
}
}
impl<T1, T2> Updatable for (T1, T2)
where
T1: Updatable,
T2: Updatable,
{
fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
let (a, b) = self;
let params = a.updatable_states();
params.into_iter().chain(b.updatable_states())
}
fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
let (a, b) = self;
let params = a.updatable_states_mut();
params.into_iter().chain(b.updatable_states_mut())
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SingleOrPair<T = i32> {
Single(T),
Pair(T, T),
}
impl<T: Clone> SingleOrPair<T> {
pub fn first(&self) -> T {
match self {
SingleOrPair::Single(v) => v.clone(),
SingleOrPair::Pair(v1, _) => v1.clone(),
}
}
pub fn second(&self) -> T {
match self {
SingleOrPair::Single(v) => v.clone(),
SingleOrPair::Pair(_, v2) => v2.clone(),
}
}
}
impl<T> From<T> for SingleOrPair<T> {
fn from(value: T) -> Self {
SingleOrPair::Single(value)
}
}
impl<T> From<(T, T)> for SingleOrPair<T> {
fn from(value: (T, T)) -> Self {
SingleOrPair::Pair(value.0, value.1)
}
}
impl<T: Clone> From<SingleOrPair<T>> for (T, T) {
fn from(value: SingleOrPair<T>) -> Self {
match value {
SingleOrPair::Single(v) => (v.clone(), v),
SingleOrPair::Pair(v1, v2) => (v1, v2),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SingleOrTriple<T = i32> {
Single(T),
Triple(T, T, T),
}
impl<T: Clone> SingleOrTriple<T> {
pub fn first(&self) -> T {
match self {
SingleOrTriple::Single(v) => v.clone(),
SingleOrTriple::Triple(v1, _, _) => v1.clone(),
}
}
pub fn second(&self) -> T {
match self {
SingleOrTriple::Single(v) => v.clone(),
SingleOrTriple::Triple(_, v2, _) => v2.clone(),
}
}
pub fn third(&self) -> T {
match self {
SingleOrTriple::Single(v) => v.clone(),
SingleOrTriple::Triple(_, _, v3) => v3.clone(),
}
}
}
impl<T> From<T> for SingleOrTriple<T> {
fn from(value: T) -> Self {
SingleOrTriple::Single(value)
}
}
impl<T> From<(T, T, T)> for SingleOrTriple<T> {
fn from(value: (T, T, T)) -> Self {
SingleOrTriple::Triple(value.0, value.1, value.2)
}
}
impl<T: Clone> From<SingleOrTriple<T>> for (T, T, T) {
fn from(value: SingleOrTriple<T>) -> Self {
match value {
SingleOrTriple::Single(v) => (v.clone(), v.clone(), v),
SingleOrTriple::Triple(v1, v2, v3) => (v1, v2, v3),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SingleOrVec<T> {
Single(T),
Vec(Vec<T>),
}
impl<T> From<T> for SingleOrVec<T> {
fn from(value: T) -> Self {
SingleOrVec::Single(value)
}
}
impl<T> From<Vec<T>> for SingleOrVec<T> {
fn from(value: Vec<T>) -> Self {
SingleOrVec::Vec(value)
}
}