use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::time::Duration;

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}
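
// Because `KeepAlive` is `#[serde(untagged)]`, each variant serializes to the
// bare value Ollama expects. A quick sketch of the two JSON shapes:
//
//   KeepAlive::Seconds(-1)             // serializes to: -1
//   KeepAlive::Duration("5m".into())   // serializes to: "5m"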

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}

fn get_max_tokens(name: &str) -> u64 {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: u64 = 4096;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    /// Models that support context beyond 16k, such as codestral (32k) or
    /// devstral (128k), are clamped down to 16k.
    const MAXIMUM_TOKENS: u64 = 16384;

    match name.split(':').next().unwrap() {
        "granite-code" | "phi" | "tinyllama" => 2048,
        "llama2" | "stablelm2" | "vicuna" | "yi" => 4096,
        "aya" | "codegemma" | "gemma" | "gemma2" | "llama3" | "starcoder" => 8192,
        "codellama" | "starcoder2" => 16384,
        "codestral" | "dolphin-mixtral" | "llava" | "magistral" | "mistral" | "mixtral"
        | "qwen2" | "qwen2.5-coder" => 32768,
        "cogito" | "command-r" | "deepseek-coder-v2" | "deepseek-r1" | "deepseek-v3"
        | "devstral" | "gemma3" | "gpt-oss" | "granite3.3" | "llama3.1" | "llama3.2"
        | "llama3.3" | "mistral-nemo" | "phi3" | "phi3.5" | "phi4" | "qwen3" | "yi-coder" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}
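
// For example, `get_max_tokens("codestral:22b")` matches the 32768 arm but is
// clamped to MAXIMUM_TOKENS, returning 16384, while an unrecognized name such
// as "some-new-model" (hypothetical) falls through to DEFAULT_TOKENS (4096).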

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tools: Option<bool>,
        supports_vision: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_vision,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }
}
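
// Construction sketch: `Model::new("llama3.2:latest", None, None, None, None, None)`
// produces a display name of "llama3.2" (the ":latest" suffix is stripped) and
// `max_tokens == 16384` (the table's 128000 entry, clamped by `get_max_tokens`).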

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        thinking: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
    Tool {
        tool_name: String,
        content: String,
    },
}
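
// With `#[serde(tag = "role", rename_all = "lowercase")]`, each message
// serializes into Ollama's wire shape; e.g. a user message without images
// (the `images` field is skipped when `None`):
//
//   {"role":"user","content":"Hello!"}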

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}
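
// A tool serializes to the nested shape Ollama expects; e.g. for a
// hypothetical `get_weather` function:
//
//   {"type":"function","function":{"name":"get_weather","description":"...","parameters":{...}}}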

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

impl ChatRequest {
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        // Attaching tools forces a non-streaming request.
        self.stream = false;
        self.tools = tools;
        self
    }
}
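
// A minimal request sketch (model name and message content are illustrative):
//
//   let request = ChatRequest {
//       model: "llama3.2".to_string(),
//       messages: vec![ChatMessage::User {
//           content: "Hello!".to_string(),
//           images: None,
//       }],
//       stream: true,
//       keep_alive: KeepAlive::default(),
//       options: None,
//       tools: Vec::new(),
//       think: None,
//   };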

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<u64>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}
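
// `ChatOptions` derives `Default`, so callers can set just the fields they
// need, e.g. pinning the context window and lowering the sampling temperature:
//
//   ChatOptions { num_ctx: Some(16384), temperature: Some(0.2), ..Default::default() }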

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
    pub prompt_eval_count: Option<u64>,
    pub eval_count: Option<u64>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

#[derive(Deserialize, Debug)]
pub struct ModelShow {
    #[serde(default)]
    pub capabilities: Vec<String>,
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.iter().any(|v| v == "vision")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}

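/// Send a non-streaming chat request and wait for the complete response.
///
/// Usage sketch (`client` is any `HttpClient` implementation; error handling
/// is elided):
///
/// ```ignore
/// let delta = complete(client, OLLAMA_API_URL, request).await?;
/// if let ChatMessage::Assistant { content, .. } = delta.message {
///     println!("{content}");
/// }
/// ```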
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;

    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        );
    }
}

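/// Send a chat request and stream the response as newline-delimited JSON,
/// yielding one `ChatResponseDelta` per chunk.
///
/// Usage sketch (error handling elided):
///
/// ```ignore
/// let mut deltas = stream_chat_completion(client, OLLAMA_API_URL, None, request).await?;
/// while let Some(delta) = deltas.next().await {
///     let delta = delta?;
///     // Accumulate `delta.message` content until `delta.done` is true.
/// }
/// ```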
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let mut request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    if let Some(api_key) = api_key {
        request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
    }

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

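/// List the models available locally via Ollama's `/api/tags` endpoint.
///
/// Usage sketch:
///
/// ```ignore
/// let models = get_models(client, OLLAMA_API_URL, None, None).await?;
/// for model in &models {
///     println!("{} ({})", model.name, model.details.parameter_size);
/// }
/// ```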
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let mut request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    if let Some(api_key) = api_key {
        request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
    }

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetch details of a model, used to determine model capabilities.
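///
/// Usage sketch, feeding the reported capabilities into a `Model`:
///
/// ```ignore
/// let show = show_model(client, OLLAMA_API_URL, None, "llama3.2").await?;
/// let model = Model::new(
///     "llama3.2",
///     None,
///     None,
///     Some(show.supports_tools()),
///     Some(show.supports_vision()),
///     Some(show.supports_thinking()),
/// );
/// ```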
pub async fn show_model(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    model: &str,
) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    if let Some(api_key) = api_key {
        request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
    }

    let request = request_builder.body(AsyncBody::from(
        serde_json::json!({ "model": model }).to_string(),
    ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));
    }

    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
}