use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{sync::Arc, time::Duration};

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}

fn get_max_tokens(name: &str) -> usize {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: usize = 4096;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    const MAXIMUM_TOKENS: usize = 16384;

    match name.split(':').next().unwrap() {
        "phi" | "tinyllama" | "granite-code" => 2048,
        "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
        "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
        "codellama" | "starcoder2" => 16384,
        "mistral" | "codestral" | "mixtral" | "llava" | "qwen2" | "qwen2.5-coder"
        | "dolphin-mixtral" => 32768,
        "llama3.1" | "llama3.2" | "llama3.3" | "phi3" | "phi3.5" | "phi4" | "command-r"
        | "qwen3" | "gemma3" | "deepseek-coder-v2" | "deepseek-v3" | "deepseek-r1" | "yi-coder"
        | "devstral" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<usize>,
        supports_tools: Option<bool>,
        supports_vision: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_vision,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> usize {
        self.max_tokens
    }
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        thinking: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

impl ChatRequest {
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        // Tool calls are handled on complete (non-streamed) responses here,
        // so attaching tools forces streaming off.
        self.stream = false;
        self.tools = tools;
        self
    }
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<usize>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

#[derive(Deserialize, Debug)]
pub struct ModelShow {
    #[serde(default)]
    pub capabilities: Vec<String>,
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.iter().any(|v| v == "vision")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}

pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;

    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        );
    }
}

pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetch details of a model, used to determine model capabilities
pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}

/// Sends an empty request to Ollama to trigger loading the model
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
    let uri = format!("{api_url}/api/generate");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(
            serde_json::json!({
                "model": model,
                "keep_alive": "15m",
            })
            .to_string(),
        ))?;

    let mut response = client.send(request).await?;

    if response.status().is_success() {
        Ok(())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;

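    // Sketch of the `KeepAlive` wire format (a check added here for illustration):
    // the enum is `#[serde(untagged)]`, so integer seconds serialize as a bare JSON
    // number and duration strings as a plain JSON string, the two shapes Ollama's
    // `keep_alive` field accepts.
    #[test]
    fn serialize_keep_alive() {
        assert_eq!(serde_json::to_string(&KeepAlive::indefinite()).unwrap(), "-1");
        assert_eq!(
            serde_json::to_string(&KeepAlive::Duration("10m".to_string())).unwrap(),
            "\"10m\""
        );
        let parsed: KeepAlive = serde_json::from_str("300").unwrap();
        assert_eq!(parsed, KeepAlive::Seconds(300));
    }
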
    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

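    // Illustrative checks for `get_max_tokens`: the base name before ':' selects
    // the context length, unknown models fall back to the 4096 default, and
    // everything is clamped to the 16384-token ceiling.
    #[test]
    fn context_length_defaults() {
        assert_eq!(get_max_tokens("phi:latest"), 2048);
        assert_eq!(get_max_tokens("some-unknown-model"), 4096);
        // llama3.1 is listed at 128k but gets clamped to the ceiling.
        assert_eq!(get_max_tokens("llama3.1:8b"), 16384);
    }
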
    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

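    // Illustrative check for `Model::new`: a ":latest" suffix is stripped when
    // deriving a display name, and the context length falls back to
    // `get_max_tokens` when none is given.
    #[test]
    fn model_defaults() {
        let model = Model::new("qwen2.5-coder:latest", None, None, None, None, None);
        assert_eq!(model.display_name(), "qwen2.5-coder");
        assert_eq!(model.max_token_count(), 16384);

        let model = Model::new("llama3.2", Some("Llama 3.2"), Some(8192), None, None, None);
        assert_eq!(model.display_name(), "Llama 3.2");
        assert_eq!(model.max_token_count(), 8192);
    }
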
    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

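    // Illustrative check for the builder behavior documented above:
    // `with_tools` attaches the tools and forces `stream` off.
    #[test]
    fn with_tools_disables_streaming() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![],
            stream: true,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        }
        .with_tools(vec![OllamaTool::Function {
            function: OllamaFunctionTool {
                name: "weather".to_string(),
                description: None,
                parameters: None,
            },
        }]);

        assert!(!request.stream);
        assert_eq!(request.tools.len(), 1);
    }
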
    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));
    }

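    // `capabilities` is `#[serde(default)]`, so an `/api/show` payload that omits
    // the field still deserializes, with every capability probe reporting false.
    #[test]
    fn parse_show_model_without_capabilities() {
        let result: ModelShow = serde_json::from_value(serde_json::json!({})).unwrap();
        assert!(!result.supports_tools());
        assert!(!result.supports_vision());
        assert!(!result.supports_thinking());
    }
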
    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

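    // Illustrative check of the message wire format: `ChatMessage` is internally
    // tagged on "role", so each variant carries the lowercase role name Ollama
    // expects.
    #[test]
    fn serialize_message_roles() {
        let message = serde_json::to_value(ChatMessage::System {
            content: "You are a helpful assistant.".to_string(),
        })
        .unwrap();
        assert_eq!(message["role"], "system");

        let message = serde_json::to_value(ChatMessage::User {
            content: "Hello!".to_string(),
            images: None,
        })
        .unwrap();
        assert_eq!(message["role"], "user");
    }
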
    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
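
    // Sketch of a reasoning-model response, assuming the "thinking" text arrives
    // inside the assistant message (which is what the `thinking` field on
    // `ChatMessage::Assistant` models): the optional field should deserialize
    // when present.
    #[test]
    fn parse_thinking_response() {
        let response = serde_json::json!({
            "model": "qwen3",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "Hello!",
                "thinking": "The user greeted me, so I should greet them back."
            },
            "done": true
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant { thinking, .. } => {
                assert!(thinking.is_some());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }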
}