mnist/
main.rs

1use mlx_rs::{
2    builder::Builder,
3    error::Exception,
4    losses::{CrossEntropyBuilder, LossReduction},
5    module::{Module, ModuleParameters},
6    nn,
7    ops::{eq, indexing::argmax, mean},
8    optimizers::{Optimizer, Sgd},
9    transforms::eval_params,
10    Array,
11};
12
13/// MLP model
14mod mlp;
15
16/// Retrieves MNIST dataset
17mod data;
18
19fn eval_fn(model: &mut mlp::Mlp, (x, y): (&Array, &Array)) -> Result<Array, Exception> {
20    let y_pred = model.forward(x)?;
21    let accuracy = mean(&eq(&argmax(&y_pred, 1, None)?, y)?, None, None)?;
22    Ok(accuracy)
23}
24
25fn main() -> Result<(), Box<dyn std::error::Error>> {
26    let num_layers = 2;
27    let hidden_dim = 32;
28    let num_classes = 10;
29    let batch_size = 256;
30    let num_epochs = 10;
31    let learning_rate = 1e-2;
32
33    let (train_images, train_labels, test_images, test_labels) = data::read_data();
34    let loader = data::iterate_data(&train_images, &train_labels, batch_size)?;
35
36    let input_dim = train_images[0].shape()[0];
37    let mut model = mlp::Mlp::new(num_layers, input_dim, hidden_dim, num_classes)?;
38
39    let cross_entropy = CrossEntropyBuilder::new()
40        .reduction(LossReduction::Mean)
41        .build()?;
42    let loss_fn = |model: &mut mlp::Mlp, (x, y): (&Array, &Array)| -> Result<Array, Exception> {
43        let y_pred = model.forward(x)?;
44        cross_entropy.apply(y_pred, y)
45    };
46    let mut loss_and_grad_fn = nn::value_and_grad(loss_fn);
47
48    let mut optimizer = Sgd::new(learning_rate);
49
50    for e in 0..num_epochs {
51        let now = std::time::Instant::now();
52        for (x, y) in &loader {
53            let (_loss, grad) = loss_and_grad_fn(&mut model, (x, y))?;
54            optimizer.update(&mut model, grad).unwrap();
55            eval_params(model.parameters())?;
56        }
57
58        // Evaluate on test set
59        let accuracy = eval_fn(&mut model, (&test_images, &test_labels))?;
60        let elapsed = now.elapsed();
61        println!(
62            "Epoch: {}, Test accuracy: {:.2}, Time: {:.2} s",
63            e,
64            accuracy.item::<f32>(),
65            elapsed.as_secs_f32()
66        );
67    }
68
69    Ok(())
70}