1use mlx_rs::{
2 builder::Builder,
3 error::Exception,
4 losses::{CrossEntropyBuilder, LossReduction},
5 module::{Module, ModuleParameters},
6 nn,
7 ops::{eq, indexing::argmax, mean},
8 optimizers::{Optimizer, Sgd},
9 transforms::eval_params,
10 Array,
11};
12
13mod mlp;
15
16mod data;
18
19fn eval_fn(model: &mut mlp::Mlp, (x, y): (&Array, &Array)) -> Result<Array, Exception> {
20 let y_pred = model.forward(x)?;
21 let accuracy = mean(&eq(&argmax(&y_pred, 1, None)?, y)?, None, None)?;
22 Ok(accuracy)
23}
24
25fn main() -> Result<(), Box<dyn std::error::Error>> {
26 let num_layers = 2;
27 let hidden_dim = 32;
28 let num_classes = 10;
29 let batch_size = 256;
30 let num_epochs = 10;
31 let learning_rate = 1e-2;
32
33 let (train_images, train_labels, test_images, test_labels) = data::read_data();
34 let loader = data::iterate_data(&train_images, &train_labels, batch_size)?;
35
36 let input_dim = train_images[0].shape()[0];
37 let mut model = mlp::Mlp::new(num_layers, input_dim, hidden_dim, num_classes)?;
38
39 let cross_entropy = CrossEntropyBuilder::new()
40 .reduction(LossReduction::Mean)
41 .build()?;
42 let loss_fn = |model: &mut mlp::Mlp, (x, y): (&Array, &Array)| -> Result<Array, Exception> {
43 let y_pred = model.forward(x)?;
44 cross_entropy.apply(y_pred, y)
45 };
46 let mut loss_and_grad_fn = nn::value_and_grad(loss_fn);
47
48 let mut optimizer = Sgd::new(learning_rate);
49
50 for e in 0..num_epochs {
51 let now = std::time::Instant::now();
52 for (x, y) in &loader {
53 let (_loss, grad) = loss_and_grad_fn(&mut model, (x, y))?;
54 optimizer.update(&mut model, grad).unwrap();
55 eval_params(model.parameters())?;
56 }
57
58 let accuracy = eval_fn(&mut model, (&test_images, &test_labels))?;
60 let elapsed = now.elapsed();
61 println!(
62 "Epoch: {}, Test accuracy: {:.2}, Time: {:.2} s",
63 e,
64 accuracy.item::<f32>(),
65 elapsed.as_secs_f32()
66 );
67 }
68
69 Ok(())
70}