mistral/
main.rs

use hf_hub::{
    api::sync::{Api, ApiBuilder, ApiRepo},
    Repo,
};
use mlx_rs::{
    array,
    module::{Module, ModuleParametersExt},
    ops::indexing::{argmax, IndexOp, NewAxis},
    random::categorical,
    transforms::eval,
    Array,
};
use tokenizers::Tokenizer;

mod model;

use model::{Mistral, MistralInput, MistralOutput, ModelArgs};

type Error = Box<dyn std::error::Error + Send + Sync>;
type Result<T, E = Error> = std::result::Result<T, E>;

use clap::Parser;

#[derive(Parser)]
#[command(about = "Mistral inference example")]
pub struct Cli {
    /// The message to be processed by the model
    #[clap(long, default_value = "In the beginning the Universe was created.")]
    prompt: String,

    /// Maximum number of tokens to generate
    #[clap(long, default_value = "100")]
    max_tokens: usize,

    /// The sampling temperature
    #[clap(long, default_value = "0.0")]
    temp: f32,

    /// Number of tokens to generate between evaluations and prints
    #[clap(long, default_value = "10")]
    tokens_per_eval: usize,

    /// The PRNG seed
    #[clap(long, default_value = "0")]
    seed: u64,
}

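/// Builds the Hugging Face Hub API client, using `HF_CACHE_DIR` as the cache
/// directory when that environment variable is set.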
fn build_hf_api() -> Result<Api> {
    let cache_dir = std::env::var("HF_CACHE_DIR").ok();

    let mut builder = ApiBuilder::new();
    if let Some(cache_dir) = cache_dir {
        builder = builder.with_cache_dir(cache_dir.into());
    }
    builder.build().map_err(Into::into)
}

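/// Fetches `tokenizer.json` from the repo and loads the tokenizer from it.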
fn get_tokenizer(repo: &ApiRepo) -> Result<Tokenizer> {
    let tokenizer_filename = repo.get("tokenizer.json")?;
    let t = Tokenizer::from_file(tokenizer_filename)?;

    Ok(t)
}

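/// Fetches `params.json` from the repo and deserializes it into `ModelArgs`.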
fn get_model_args(repo: &ApiRepo) -> Result<ModelArgs> {
    let model_args_filename = repo.get("params.json")?;
    let file = std::fs::File::open(model_args_filename)?;
    let model_args: ModelArgs = serde_json::from_reader(file)?;

    Ok(model_args)
}

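/// Builds the model from its arguments and loads the weights from
/// `weights.safetensors` in the repo.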
fn load_model(repo: &ApiRepo) -> Result<Mistral> {
    let model_args = get_model_args(repo)?;
    let mut model = Mistral::new(&model_args)?;
    let weights_filename = repo.get("weights.safetensors")?;
    model.load_safetensors(weights_filename)?;

    Ok(model)
}

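/// Samples the next token from the logits: greedy argmax when `temp` is zero,
/// otherwise a draw from the categorical distribution over the
/// temperature-scaled logits.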
fn sample(logits: &Array, temp: f32) -> Result<Array> {
    match temp {
        0.0 => argmax(logits, -1, None).map_err(Into::into),
        _ => {
            let logits = logits.multiply(array!(1.0 / temp))?;
            categorical(logits, None, None, None).map_err(Into::into)
        }
    }
}

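// Early-return helper for use inside `Iterator::next`, which returns
// `Option<Result<_>>`: behaves like `?` but wraps the error in `Some(Err(..))`.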
macro_rules! tri {
    ($expr:expr) => {
        match $expr {
            Ok(val) => val,
            Err(e) => return Some(Err(e.into())),
        }
    };
}

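/// Iterator that yields one sampled token per step, carrying the model's
/// per-layer cache between steps so each new token only requires a forward
/// pass over the latest input.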
struct Generate<'a> {
    model: &'a mut Mistral,
    temp: f32,
    state: GenerateState<'a>,
}

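/// `Start` holds the encoded prompt; `Continue` holds the last sampled token
/// and the cache returned by the previous forward pass.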
enum GenerateState<'a> {
    Start {
        prompt_token: &'a Array,
    },
    Continue {
        y: Array,
        cache: Vec<Option<(Array, Array)>>,
    },
}

impl<'a> Generate<'a> {
    pub fn new(model: &'a mut Mistral, prompt_token: &'a Array, temp: f32) -> Self {
        Self {
            model,
            temp,
            state: GenerateState::Start { prompt_token },
        }
    }
}

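// The first call runs the whole prompt through the model and samples from the
// logits at the last position; every later call feeds back only the previously
// sampled token together with the cache.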
impl Iterator for Generate<'_> {
    type Item = Result<Array>;

    fn next(&mut self) -> Option<Self::Item> {
        match &self.state {
            GenerateState::Start { prompt_token } => {
                let initial_cache = Vec::with_capacity(0); // This won't allocate
                let input = MistralInput {
                    inputs: prompt_token,
                    cache: &initial_cache,
                };
                let MistralOutput { logits, cache } = tri!(self.model.forward(input));
                let y = tri!(sample(&logits.index((.., -1, ..)), self.temp));

                self.state = GenerateState::Continue {
                    y: y.clone(),
                    cache,
                };

                Some(Ok(y))
            }
            GenerateState::Continue { y, cache } => {
                let next_token = y.index((.., NewAxis));
                let input = MistralInput {
                    inputs: &next_token,
                    cache: cache.as_slice(),
                };
                let MistralOutput {
                    logits,
                    cache: new_cache,
                } = tri!(self.model.forward(input));

                let logits = tri!(logits.squeeze(&[1]));
                let y = tri!(sample(&logits, self.temp));

                self.state = GenerateState::Continue {
                    y: y.clone(),
                    cache: new_cache,
                };

                Some(Ok(y))
            }
        }
    }
}

fn main() -> Result<()> {
    // If you want to manually set the cache directory, you can set the HF_CACHE_DIR
    // environment variable or put it in a .env file located at the root of this example
    // (i.e. examples/mistral/.env)
    let _ = dotenv::dotenv();
    let api = build_hf_api()?;

    // Parse args
    let cli = Cli::parse();

    mlx_rs::random::seed(cli.seed)?;

    // The model used in the original example is converted to safetensors and
    // uploaded to the huggingface hub
    let model_id = "minghuaw/Mistral-7B-v0.1".to_string();
    let repo = api.repo(Repo::new(model_id, hf_hub::RepoType::Model));
    println!("[INFO] Loading model... ");
    let tokenizer = get_tokenizer(&repo)?;
    let mut model = load_model(&repo)?;

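    // Quantize the model's linear layers with the default settings to reduce
    // memory use.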
    model = mlx_rs::nn::quantize(model, None, None)?;

    let encoding = tokenizer.encode(&cli.prompt[..], true)?;
    let prompt_tokens = Array::from(encoding.get_ids()).index(NewAxis);
    print!("{}", cli.prompt);

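    // Stream tokens from the model, forcing evaluation and printing the decoded
    // text every `tokens_per_eval` tokens.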
    let generate = Generate::new(&mut model, &prompt_tokens, cli.temp);
    let mut tokens = Vec::with_capacity(cli.max_tokens);
    for (token, ntoks) in generate.zip(0..cli.max_tokens) {
        let token = token?;
        tokens.push(token);

        if ntoks == 0 {
            eval(&tokens)?;
        }

        if tokens.len() % cli.tokens_per_eval == 0 {
            eval(&tokens)?;
            let slice: Vec<u32> = tokens.drain(..).map(|t| t.item::<u32>()).collect();
            let s = tokenizer.decode(&slice, true)?;
            print!("{}", s);
        }
    }

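    // Evaluate and decode any remaining tokens that did not fill a complete
    // batch.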
    eval(&tokens)?;
    let slice: Vec<u32> = tokens.drain(..).map(|t| t.item::<u32>()).collect();
    let s = tokenizer.decode(&slice, true)?;
    println!("{}", s);

    println!("------");

    Ok(())
}