use hf_hub::{
    api::sync::{Api, ApiBuilder, ApiRepo},
    Repo,
};
use mlx_rs::{
    array,
    module::{Module, ModuleParametersExt},
    ops::indexing::{argmax, IndexOp, NewAxis},
    random::categorical,
    transforms::eval,
    Array,
};
use tokenizers::Tokenizer;

mod model;

use model::{Mistral, MistralInput, MistralOutput, ModelArgs};

type Error = Box<dyn std::error::Error + Send + Sync>;
type Result<T, E = Error> = std::result::Result<T, E>;

use clap::Parser;

#[derive(Parser)]
#[command(about = "Mistral inference example")]
pub struct Cli {
    /// The message to be processed by the model
    #[clap(long, default_value = "In the beginning the Universe was created.")]
    prompt: String,

    /// Maximum number of tokens to generate
    #[clap(long, default_value = "100")]
    max_tokens: usize,

    /// The sampling temperature
    #[clap(long, default_value = "0.0")]
    temp: f32,

    /// Number of tokens to generate between forced evaluations
    #[clap(long, default_value = "10")]
    tokens_per_eval: usize,

    /// The PRNG seed
    #[clap(long, default_value = "0")]
    seed: u64,
}

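/// Builds a synchronous Hugging Face Hub API client, honoring the optional
/// `HF_CACHE_DIR` environment variable for the download cache location.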
fn build_hf_api() -> Result<Api> {
    let cache_dir = std::env::var("HF_CACHE_DIR").ok();

    let mut builder = ApiBuilder::new();
    if let Some(cache_dir) = cache_dir {
        builder = builder.with_cache_dir(cache_dir.into());
    }
    builder.build().map_err(Into::into)
}

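/// Downloads (or finds in the cache) and loads `tokenizer.json` from the repo.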
fn get_tokenizer(repo: &ApiRepo) -> Result<Tokenizer> {
    let tokenizer_filename = repo.get("tokenizer.json")?;
    let t = Tokenizer::from_file(tokenizer_filename)?;

    Ok(t)
}

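/// Loads the model hyperparameters from `params.json` in the repo.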
fn get_model_args(repo: &ApiRepo) -> Result<ModelArgs> {
    let model_args_filename = repo.get("params.json")?;
    let file = std::fs::File::open(model_args_filename)?;
    let model_args: ModelArgs = serde_json::from_reader(file)?;

    Ok(model_args)
}

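/// Constructs the model from its hyperparameters and loads the safetensors weights.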
fn load_model(repo: &ApiRepo) -> Result<Mistral> {
    let model_args = get_model_args(repo)?;
    let mut model = Mistral::new(&model_args)?;
    let weights_filename = repo.get("weights.safetensors")?;
    model.load_safetensors(weights_filename)?;

    Ok(model)
}

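/// Samples a token id from the logits. A temperature of zero selects the most
/// likely token (greedy argmax); otherwise the logits are scaled by `1 / temp`
/// and a token is drawn from the resulting categorical distribution.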
fn sample(logits: &Array, temp: f32) -> Result<Array> {
    if temp == 0.0 {
        argmax(logits, -1, None).map_err(Into::into)
    } else {
        let logits = logits.multiply(array!(1.0 / temp))?;
        categorical(logits, None, None, None).map_err(Into::into)
    }
}

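/// Like the `?` operator, but for use inside `Iterator::next`: on error it
/// returns `Some(Err(...))` instead of propagating with `return Err(...)`.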
macro_rules! tri {
    ($expr:expr) => {
        match $expr {
            Ok(val) => val,
            Err(e) => return Some(Err(e.into())),
        }
    };
}

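/// An iterator that yields one sampled token per call to `next`.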
struct Generate<'a> {
    model: &'a mut Mistral,
    temp: f32,
    state: GenerateState<'a>,
}

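/// Generation proceeds in two phases: first the whole prompt is processed,
/// then one sampled token at a time is fed back along with the KV cache.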
enum GenerateState<'a> {
    Start {
        prompt_token: &'a Array,
    },
    Continue {
        y: Array,
        cache: Vec<Option<(Array, Array)>>,
    },
}

impl<'a> Generate<'a> {
    pub fn new(model: &'a mut Mistral, prompt_token: &'a Array, temp: f32) -> Self {
        Self {
            model,
            temp,
            state: GenerateState::Start { prompt_token },
        }
    }
}

impl Iterator for Generate<'_> {
    type Item = Result<Array>;

    fn next(&mut self) -> Option<Self::Item> {
        match &self.state {
            GenerateState::Start { prompt_token } => {
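                // First call: run the full prompt through the model to prime
                // the KV cache and sample the first generated token.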
                let initial_cache = Vec::with_capacity(0);
                let input = MistralInput {
                    inputs: prompt_token,
                    cache: &initial_cache,
                };
                let MistralOutput { logits, cache } = tri!(self.model.forward(input));
                let y = tri!(sample(&logits.index((.., -1, ..)), self.temp));

                self.state = GenerateState::Continue {
                    y: y.clone(),
                    cache,
                };

                Some(Ok(y))
            }
            GenerateState::Continue { y, cache } => {
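                // Subsequent calls: feed only the previously sampled token,
                // reusing the cached keys/values from earlier steps.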
                let next_token = y.index((.., NewAxis));
                let input = MistralInput {
                    inputs: &next_token,
                    cache: cache.as_slice(),
                };
                let MistralOutput {
                    logits,
                    cache: new_cache,
                } = tri!(self.model.forward(input));

                let logits = tri!(logits.squeeze(&[1]));
                let y = tri!(sample(&logits, self.temp));

                self.state = GenerateState::Continue {
                    y: y.clone(),
                    cache: new_cache,
                };

                Some(Ok(y))
            }
        }
    }
}

fn main() -> Result<()> {
    let _ = dotenv::dotenv();
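    // Load variables from a `.env` file if present, so `HF_CACHE_DIR` can be
    // configured there as well as in the environment.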
    let api = build_hf_api()?;

    let cli = Cli::parse();

    mlx_rs::random::seed(cli.seed)?;

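    // The Hugging Face repo that hosts the tokenizer, params, and weights
    // used by this example.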
    let model_id = "minghuaw/Mistral-7B-v0.1".to_string();
    let repo = api.repo(Repo::new(model_id, hf_hub::RepoType::Model));
    println!("[INFO] Loading model... ");
    let tokenizer = get_tokenizer(&repo)?;
    let mut model = load_model(&repo)?;

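    // Quantize the model weights with the default group size and bit width
    // to reduce memory use.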
    model = mlx_rs::nn::quantize(model, None, None)?;

    let encoding = tokenizer.encode(&cli.prompt[..], true)?;
    let prompt_tokens = Array::from(encoding.get_ids()).index(NewAxis);
    print!("{}", cli.prompt);

    let generate = Generate::new(&mut model, &prompt_tokens, cli.temp);
    let mut tokens = Vec::with_capacity(cli.max_tokens);
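    // mlx arrays are lazy: force evaluation every `tokens_per_eval` tokens,
    // then decode and print the finished batch.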
    for (token, ntoks) in generate.zip(0..cli.max_tokens) {
        let token = token?;
        tokens.push(token);

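        // Evaluate the first token eagerly so the prompt is processed up front.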
        if ntoks == 0 {
            eval(&tokens)?;
        }

        if tokens.len() % cli.tokens_per_eval == 0 {
            eval(&tokens)?;
            let slice: Vec<u32> = tokens.drain(..).map(|t| t.item::<u32>()).collect();
            let s = tokenizer.decode(&slice, true)?;
            print!("{}", s);
        }
    }

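    // Flush any remaining tokens that didn't fill a full batch.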
    eval(&tokens)?;
    let slice: Vec<u32> = tokens.drain(..).map(|t| t.item::<u32>()).collect();
    let s = tokenizer.decode(&slice, true)?;
    println!("{}", s);

    println!("------");

    Ok(())
}