1use mlx_rs::{error::Exception, ops::stack, Array};
4use mnist::{Mnist, MnistBuilder};
5
/// Number of bytes in one flattened image (28 x 28); used as the chunk length
/// when splitting the raw image buffers into individual samples.
const IMAGE_SIZE: usize = 28 * 28;
7
8pub fn read_data() -> (Vec<Array>, Vec<u8>, Array, Array) {
9 let Mnist {
10 trn_img,
11 trn_lbl,
12 val_img: _,
13 val_lbl: _,
14 tst_img,
15 tst_lbl,
16 } = MnistBuilder::new()
17 .label_format_digit()
18 .base_path("data")
19 .training_images_filename("train-images.idx3-ubyte")
20 .training_labels_filename("train-labels.idx1-ubyte")
21 .test_images_filename("t10k-images.idx3-ubyte")
22 .test_labels_filename("t10k-labels.idx1-ubyte")
23 .finalize();
24
25 assert_eq!(trn_img.len(), trn_lbl.len() * IMAGE_SIZE);
27 assert_eq!(tst_img.len(), tst_lbl.len() * IMAGE_SIZE);
28
29 let train_images = trn_img
31 .chunks_exact(IMAGE_SIZE)
32 .map(|chunk| Array::from_slice(chunk, &[IMAGE_SIZE as i32]))
33 .collect();
34
35 let test_images = tst_img
36 .chunks_exact(IMAGE_SIZE)
37 .map(|chunk| Array::from_slice(chunk, &[IMAGE_SIZE as i32]))
38 .collect::<Vec<_>>();
39 let test_images = stack(&test_images, 0).unwrap();
40
41 let test_labels = Array::from_slice(&tst_lbl, &[tst_lbl.len() as i32]);
42
43 (train_images, trn_lbl, test_images, test_labels)
44}
45
46pub fn iterate_data<'a>(
48 images: &'a [Array],
49 labels: &'a [u8],
50 batch_size: usize,
51) -> Result<Vec<(Array, Array)>, Exception> {
52 images
53 .chunks_exact(batch_size)
54 .zip(labels.chunks_exact(batch_size))
55 .map(move |(images, labels)| {
56 let images = stack(images, 0)?;
57 let labels = Array::from_slice(labels, &[batch_size as i32]);
58 Ok((images, labels))
59 })
60 .collect()
61}