1use anyhow::{Context as _, Result};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::{sync::Arc, time::Duration};
7
8pub const OLLAMA_API_URL: &str = "http://localhost:11434";
9
10#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
11#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
12#[serde(untagged)]
13pub enum KeepAlive {
14 /// Keep model alive for N seconds
15 Seconds(isize),
16 /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
17 Duration(String),
18}
19
20impl KeepAlive {
21 /// Keep model alive until a new model is loaded or until Ollama shuts down
22 fn indefinite() -> Self {
23 Self::Seconds(-1)
24 }
25}
26
27impl Default for KeepAlive {
28 fn default() -> Self {
29 Self::indefinite()
30 }
31}
32
33#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
34#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
35pub struct Model {
36 pub name: String,
37 pub display_name: Option<String>,
38 pub max_tokens: usize,
39 pub keep_alive: Option<KeepAlive>,
40 pub supports_tools: Option<bool>,
41}
42
43fn get_max_tokens(name: &str) -> usize {
44 /// Default context length for unknown models.
45 const DEFAULT_TOKENS: usize = 4096;
46 /// Magic number. Lets many Ollama models work with ~16GB of ram.
47 const MAXIMUM_TOKENS: usize = 16384;
48
49 match name.split(':').next().unwrap() {
50 "phi" | "tinyllama" | "granite-code" => 2048,
51 "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
52 "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
53 "codellama" | "starcoder2" => 16384,
54 "mistral" | "codestral" | "mixstral" | "llava" | "qwen2" | "qwen2.5-coder"
55 | "dolphin-mixtral" => 32768,
56 "llama3.1" | "llama3.2" | "llama3.3" | "phi3" | "phi3.5" | "phi4" | "command-r"
57 | "qwen3" | "gemma3" | "deepseek-coder-v2" | "deepseek-v3" | "deepseek-r1" | "yi-coder"
58 | "devstral" => 128000,
59 _ => DEFAULT_TOKENS,
60 }
61 .clamp(1, MAXIMUM_TOKENS)
62}
63
64impl Model {
65 pub fn new(
66 name: &str,
67 display_name: Option<&str>,
68 max_tokens: Option<usize>,
69 supports_tools: Option<bool>,
70 ) -> Self {
71 Self {
72 name: name.to_owned(),
73 display_name: display_name
74 .map(ToString::to_string)
75 .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
76 max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
77 keep_alive: Some(KeepAlive::indefinite()),
78 supports_tools,
79 }
80 }
81
82 pub fn id(&self) -> &str {
83 &self.name
84 }
85
86 pub fn display_name(&self) -> &str {
87 self.display_name.as_ref().unwrap_or(&self.name)
88 }
89
90 pub fn max_token_count(&self) -> usize {
91 self.max_tokens
92 }
93}
94
95#[derive(Serialize, Deserialize, Debug)]
96#[serde(tag = "role", rename_all = "lowercase")]
97pub enum ChatMessage {
98 Assistant {
99 content: String,
100 tool_calls: Option<Vec<OllamaToolCall>>,
101 },
102 User {
103 content: String,
104 },
105 System {
106 content: String,
107 },
108}
109
110#[derive(Serialize, Deserialize, Debug)]
111#[serde(rename_all = "lowercase")]
112pub enum OllamaToolCall {
113 Function(OllamaFunctionCall),
114}
115
116#[derive(Serialize, Deserialize, Debug)]
117pub struct OllamaFunctionCall {
118 pub name: String,
119 pub arguments: Value,
120}
121
122#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
123pub struct OllamaFunctionTool {
124 pub name: String,
125 pub description: Option<String>,
126 pub parameters: Option<Value>,
127}
128
129#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
130#[serde(tag = "type", rename_all = "lowercase")]
131pub enum OllamaTool {
132 Function { function: OllamaFunctionTool },
133}
134
135#[derive(Serialize, Debug)]
136pub struct ChatRequest {
137 pub model: String,
138 pub messages: Vec<ChatMessage>,
139 pub stream: bool,
140 pub keep_alive: KeepAlive,
141 pub options: Option<ChatOptions>,
142 pub tools: Vec<OllamaTool>,
143}
144
145impl ChatRequest {
146 pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
147 self.stream = false;
148 self.tools = tools;
149 self
150 }
151}
152
153// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
154#[derive(Serialize, Default, Debug)]
155pub struct ChatOptions {
156 pub num_ctx: Option<usize>,
157 pub num_predict: Option<isize>,
158 pub stop: Option<Vec<String>>,
159 pub temperature: Option<f32>,
160 pub top_p: Option<f32>,
161}
162
163#[derive(Deserialize, Debug)]
164pub struct ChatResponseDelta {
165 #[allow(unused)]
166 pub model: String,
167 #[allow(unused)]
168 pub created_at: String,
169 pub message: ChatMessage,
170 #[allow(unused)]
171 pub done_reason: Option<String>,
172 #[allow(unused)]
173 pub done: bool,
174}
175
176#[derive(Serialize, Deserialize)]
177pub struct LocalModelsResponse {
178 pub models: Vec<LocalModelListing>,
179}
180
181#[derive(Serialize, Deserialize)]
182pub struct LocalModelListing {
183 pub name: String,
184 pub modified_at: String,
185 pub size: u64,
186 pub digest: String,
187 pub details: ModelDetails,
188}
189
190#[derive(Serialize, Deserialize)]
191pub struct LocalModel {
192 pub modelfile: String,
193 pub parameters: String,
194 pub template: String,
195 pub details: ModelDetails,
196}
197
198#[derive(Serialize, Deserialize)]
199pub struct ModelDetails {
200 pub format: String,
201 pub family: String,
202 pub families: Option<Vec<String>>,
203 pub parameter_size: String,
204 pub quantization_level: String,
205}
206
207#[derive(Deserialize, Debug)]
208pub struct ModelShow {
209 #[serde(default)]
210 pub capabilities: Vec<String>,
211}
212
213impl ModelShow {
214 pub fn supports_tools(&self) -> bool {
215 // .contains expects &String, which would require an additional allocation
216 self.capabilities.iter().any(|v| v == "tools")
217 }
218}
219
220pub async fn complete(
221 client: &dyn HttpClient,
222 api_url: &str,
223 request: ChatRequest,
224) -> Result<ChatResponseDelta> {
225 let uri = format!("{api_url}/api/chat");
226 let request_builder = HttpRequest::builder()
227 .method(Method::POST)
228 .uri(uri)
229 .header("Content-Type", "application/json");
230
231 let serialized_request = serde_json::to_string(&request)?;
232 let request = request_builder.body(AsyncBody::from(serialized_request))?;
233
234 let mut response = client.send(request).await?;
235
236 let mut body = Vec::new();
237 response.body_mut().read_to_end(&mut body).await?;
238
239 if response.status().is_success() {
240 let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
241 Ok(response_message)
242 } else {
243 let body_str = std::str::from_utf8(&body)?;
244 anyhow::bail!(
245 "Failed to connect to API: {} {}",
246 response.status(),
247 body_str
248 );
249 }
250}
251
252pub async fn stream_chat_completion(
253 client: &dyn HttpClient,
254 api_url: &str,
255 request: ChatRequest,
256) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
257 let uri = format!("{api_url}/api/chat");
258 let request_builder = http::Request::builder()
259 .method(Method::POST)
260 .uri(uri)
261 .header("Content-Type", "application/json");
262
263 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
264 let mut response = client.send(request).await?;
265 if response.status().is_success() {
266 let reader = BufReader::new(response.into_body());
267
268 Ok(reader
269 .lines()
270 .map(|line| match line {
271 Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
272 Err(e) => Err(e.into()),
273 })
274 .boxed())
275 } else {
276 let mut body = String::new();
277 response.body_mut().read_to_string(&mut body).await?;
278 anyhow::bail!(
279 "Failed to connect to Ollama API: {} {}",
280 response.status(),
281 body,
282 );
283 }
284}
285
286pub async fn get_models(
287 client: &dyn HttpClient,
288 api_url: &str,
289 _: Option<Duration>,
290) -> Result<Vec<LocalModelListing>> {
291 let uri = format!("{api_url}/api/tags");
292 let request_builder = HttpRequest::builder()
293 .method(Method::GET)
294 .uri(uri)
295 .header("Accept", "application/json");
296
297 let request = request_builder.body(AsyncBody::default())?;
298
299 let mut response = client.send(request).await?;
300
301 let mut body = String::new();
302 response.body_mut().read_to_string(&mut body).await?;
303
304 anyhow::ensure!(
305 response.status().is_success(),
306 "Failed to connect to Ollama API: {} {}",
307 response.status(),
308 body,
309 );
310 let response: LocalModelsResponse =
311 serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
312 Ok(response.models)
313}
314
315/// Fetch details of a model, used to determine model capabilities
316pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<ModelShow> {
317 let uri = format!("{api_url}/api/show");
318 let request = HttpRequest::builder()
319 .method(Method::POST)
320 .uri(uri)
321 .header("Content-Type", "application/json")
322 .body(AsyncBody::from(
323 serde_json::json!({ "model": model }).to_string(),
324 ))?;
325
326 let mut response = client.send(request).await?;
327 let mut body = String::new();
328 response.body_mut().read_to_string(&mut body).await?;
329
330 anyhow::ensure!(
331 response.status().is_success(),
332 "Failed to connect to Ollama API: {} {}",
333 response.status(),
334 body,
335 );
336 let details: ModelShow = serde_json::from_str(body.as_str())?;
337 Ok(details)
338}
339
340/// Sends an empty request to Ollama to trigger loading the model
341pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
342 let uri = format!("{api_url}/api/generate");
343 let request = HttpRequest::builder()
344 .method(Method::POST)
345 .uri(uri)
346 .header("Content-Type", "application/json")
347 .body(AsyncBody::from(
348 serde_json::json!({
349 "model": model,
350 "keep_alive": "15m",
351 })
352 .to_string(),
353 ))?;
354
355 let mut response = client.send(request).await?;
356
357 if response.status().is_success() {
358 Ok(())
359 } else {
360 let mut body = String::new();
361 response.body_mut().read_to_string(&mut body).await?;
362 anyhow::bail!(
363 "Failed to connect to Ollama API: {} {}",
364 response.status(),
365 body,
366 );
367 }
368}
369
370#[cfg(test)]
371mod tests {
372 use super::*;
373
374 #[test]
375 fn parse_completion() {
376 let response = serde_json::json!({
377 "model": "llama3.2",
378 "created_at": "2023-12-12T14:13:43.416799Z",
379 "message": {
380 "role": "assistant",
381 "content": "Hello! How are you today?"
382 },
383 "done": true,
384 "total_duration": 5191566416u64,
385 "load_duration": 2154458,
386 "prompt_eval_count": 26,
387 "prompt_eval_duration": 383809000,
388 "eval_count": 298,
389 "eval_duration": 4799921000u64
390 });
391 let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
392 }
393
394 #[test]
395 fn parse_streaming_completion() {
396 let partial = serde_json::json!({
397 "model": "llama3.2",
398 "created_at": "2023-08-04T08:52:19.385406455-07:00",
399 "message": {
400 "role": "assistant",
401 "content": "The",
402 "images": null
403 },
404 "done": false
405 });
406
407 let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();
408
409 let last = serde_json::json!({
410 "model": "llama3.2",
411 "created_at": "2023-08-04T19:22:45.499127Z",
412 "message": {
413 "role": "assistant",
414 "content": ""
415 },
416 "done": true,
417 "total_duration": 4883583458u64,
418 "load_duration": 1334875,
419 "prompt_eval_count": 26,
420 "prompt_eval_duration": 342546000,
421 "eval_count": 282,
422 "eval_duration": 4535599000u64
423 });
424
425 let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
426 }
427
428 #[test]
429 fn parse_tool_call() {
430 let response = serde_json::json!({
431 "model": "llama3.2:3b",
432 "created_at": "2025-04-28T20:02:02.140489Z",
433 "message": {
434 "role": "assistant",
435 "content": "",
436 "tool_calls": [
437 {
438 "function": {
439 "name": "weather",
440 "arguments": {
441 "city": "london",
442 }
443 }
444 }
445 ]
446 },
447 "done_reason": "stop",
448 "done": true,
449 "total_duration": 2758629166u64,
450 "load_duration": 1770059875,
451 "prompt_eval_count": 147,
452 "prompt_eval_duration": 684637583,
453 "eval_count": 16,
454 "eval_duration": 302561917,
455 });
456
457 let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
458 match result.message {
459 ChatMessage::Assistant {
460 content,
461 tool_calls,
462 } => {
463 assert!(content.is_empty());
464 assert!(tool_calls.is_some_and(|v| !v.is_empty()));
465 }
466 _ => panic!("Deserialized wrong role"),
467 }
468 }
469
470 #[test]
471 fn parse_show_model() {
472 let response = serde_json::json!({
473 "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
474 "details": {
475 "parent_model": "",
476 "format": "gguf",
477 "family": "llama",
478 "families": ["llama"],
479 "parameter_size": "3.2B",
480 "quantization_level": "Q4_K_M"
481 },
482 "model_info": {
483 "general.architecture": "llama",
484 "general.basename": "Llama-3.2",
485 "general.file_type": 15,
486 "general.finetune": "Instruct",
487 "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
488 "general.parameter_count": 3212749888u64,
489 "general.quantization_version": 2,
490 "general.size_label": "3B",
491 "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
492 "general.type": "model",
493 "llama.attention.head_count": 24,
494 "llama.attention.head_count_kv": 8,
495 "llama.attention.key_length": 128,
496 "llama.attention.layer_norm_rms_epsilon": 0.00001,
497 "llama.attention.value_length": 128,
498 "llama.block_count": 28,
499 "llama.context_length": 131072,
500 "llama.embedding_length": 3072,
501 "llama.feed_forward_length": 8192,
502 "llama.rope.dimension_count": 128,
503 "llama.rope.freq_base": 500000,
504 "llama.vocab_size": 128256,
505 "tokenizer.ggml.bos_token_id": 128000,
506 "tokenizer.ggml.eos_token_id": 128009,
507 "tokenizer.ggml.merges": null,
508 "tokenizer.ggml.model": "gpt2",
509 "tokenizer.ggml.pre": "llama-bpe",
510 "tokenizer.ggml.token_type": null,
511 "tokenizer.ggml.tokens": null
512 },
513 "tensors": [
514 { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
515 { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
516 ],
517 "capabilities": ["completion", "tools"],
518 "modified_at": "2025-04-29T21:24:41.445877632+03:00"
519 });
520
521 let result: ModelShow = serde_json::from_value(response).unwrap();
522 assert!(result.supports_tools());
523 assert!(result.capabilities.contains(&"tools".to_string()));
524 assert!(result.capabilities.contains(&"completion".to_string()));
525 }
526}