// mnist/data.rs

// TODO

3use mlx_rs::{error::Exception, ops::stack, Array};
4use mnist::{Mnist, MnistBuilder};
5
/// Width and height, in pixels, of one (square) MNIST digit image.
const IMAGE_SIDE: usize = 28;

/// Number of pixels in one flattened MNIST image (`IMAGE_SIDE` x `IMAGE_SIDE`).
const IMAGE_SIZE: usize = IMAGE_SIDE * IMAGE_SIDE;

8pub fn read_data() -> (Vec<Array>, Vec<u8>, Array, Array) {
9    let Mnist {
10        trn_img,
11        trn_lbl,
12        val_img: _,
13        val_lbl: _,
14        tst_img,
15        tst_lbl,
16    } = MnistBuilder::new()
17        .label_format_digit()
18        .base_path("data")
19        .training_images_filename("train-images.idx3-ubyte")
20        .training_labels_filename("train-labels.idx1-ubyte")
21        .test_images_filename("t10k-images.idx3-ubyte")
22        .test_labels_filename("t10k-labels.idx1-ubyte")
23        .finalize();
24
25    // Check size
26    assert_eq!(trn_img.len(), trn_lbl.len() * IMAGE_SIZE);
27    assert_eq!(tst_img.len(), tst_lbl.len() * IMAGE_SIZE);
28
29    // Convert to Array
30    let train_images = trn_img
31        .chunks_exact(IMAGE_SIZE)
32        .map(|chunk| Array::from_slice(chunk, &[IMAGE_SIZE as i32]))
33        .collect();
34
35    let test_images = tst_img
36        .chunks_exact(IMAGE_SIZE)
37        .map(|chunk| Array::from_slice(chunk, &[IMAGE_SIZE as i32]))
38        .collect::<Vec<_>>();
39    let test_images = stack(&test_images, 0).unwrap();
40
41    let test_labels = Array::from_slice(&tst_lbl, &[tst_lbl.len() as i32]);
42
43    (train_images, trn_lbl, test_images, test_labels)
44}
45
46/// The iterator is collected to avoid repeated calls to `stack` in the training loop.
47pub fn iterate_data<'a>(
48    images: &'a [Array],
49    labels: &'a [u8],
50    batch_size: usize,
51) -> Result<Vec<(Array, Array)>, Exception> {
52    images
53        .chunks_exact(batch_size)
54        .zip(labels.chunks_exact(batch_size))
55        .map(move |(images, labels)| {
56            let images = stack(images, 0)?;
57            let labels = Array::from_slice(labels, &[batch_size as i32]);
58            Ok((images, labels))
59        })
60        .collect()
61}