use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub use settings::KeepAlive;

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}

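/// Infer a context window from the model's base name (the part before the
/// first `:` tag), falling back to `DEFAULT_TOKENS` for unknown models and
/// clamping everything to `MAXIMUM_TOKENS`.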
fn get_max_tokens(name: &str) -> u64 {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: u64 = 4096;
    /// Magic number that lets many Ollama models work within ~16 GB of RAM.
    /// Models that support contexts beyond 16k, such as codestral (32k) or
    /// devstral (128k), are clamped down to 16k.
    const MAXIMUM_TOKENS: u64 = 16384;

    match name.split(':').next().unwrap() {
        "granite-code" | "phi" | "tinyllama" => 2048,
        "llama2" | "stablelm2" | "vicuna" | "yi" => 4096,
        "aya" | "codegemma" | "gemma" | "gemma2" | "llama3" | "starcoder" => 8192,
        "codellama" | "starcoder2" => 16384,
        "codestral" | "dolphin-mixtral" | "llava" | "magistral" | "mistral" | "mixtral"
        | "qwen2" | "qwen2.5-coder" => 32768,
        "cogito" | "command-r" | "deepseek-coder-v2" | "deepseek-r1" | "deepseek-v3"
        | "devstral" | "gemma3" | "gpt-oss" | "granite3.3" | "llama3.1" | "llama3.2"
        | "llama3.3" | "mistral-nemo" | "phi3" | "phi3.5" | "phi4" | "qwen3" | "yi-coder" => 128000,
        "qwen3-coder" => 256000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}

impl Model {
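    /// Build a model entry. When no display name is given, a trailing
    /// `:latest` tag is stripped from `name` for display purposes; when no
    /// explicit `max_tokens` is given, the context window is inferred from
    /// the model name via `get_max_tokens`.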
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tools: Option<bool>,
        supports_vision: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_vision,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }
}

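// Messages follow Ollama's chat schema: `role` is the serde tag, so an
// assistant message serializes as `{"role": "assistant", "content": ...}`.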
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        thinking: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
    Tool {
        tool_name: String,
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaToolCall {
    // TODO: Remove `Option` after most users have updated to Ollama v0.12.10,
    // which was released on the 4th of November 2025
    pub id: Option<String>,
    pub function: OllamaFunctionCall,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

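// Request body for Ollama's chat endpoint:
// https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion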
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<u64>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

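// A single NDJSON line streamed back from `/api/chat`. `prompt_eval_count`
// and `eval_count` carry prompt/completion token usage; in practice they
// appear on the final (`done: true`) message.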
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    pub model: String,
    pub created_at: String,
    pub message: ChatMessage,
    pub done_reason: Option<String>,
    pub done: bool,
    pub prompt_eval_count: Option<u64>,
    pub eval_count: Option<u64>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

#[derive(Debug)]
pub struct ModelShow {
    pub capabilities: Vec<String>,
    pub context_length: Option<u64>,
    pub architecture: Option<String>,
}

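// `context_length` lives under a model-specific key inside `model_info`
// (e.g. `llama.context_length`), so `ModelShow` needs a hand-written
// `Deserialize` impl rather than a derive.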
impl<'de> Deserialize<'de> for ModelShow {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};
        use std::fmt;

        struct ModelShowVisitor;

        impl<'de> Visitor<'de> for ModelShowVisitor {
            type Value = ModelShow;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a ModelShow object")
            }

            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                let mut capabilities: Vec<String> = Vec::new();
                let mut architecture: Option<String> = None;
                let mut context_length: Option<u64> = None;

                while let Some(key) = map.next_key::<String>()? {
                    match key.as_str() {
                        "capabilities" => {
                            capabilities = map.next_value()?;
                        }
                        "model_info" => {
                            let model_info: Value = map.next_value()?;
                            if let Value::Object(obj) = model_info {
                                architecture = obj
                                    .get("general.architecture")
                                    .and_then(|v| v.as_str())
                                    .map(String::from);

                                if let Some(arch) = &architecture {
                                    context_length = obj
                                        .get(&format!("{}.context_length", arch))
                                        .and_then(|v| v.as_u64());
                                }
                            }
                        }
                        _ => {
                            let _: de::IgnoredAny = map.next_value()?;
                        }
                    }
                }

                Ok(ModelShow {
                    capabilities,
                    context_length,
                    architecture,
                })
            }
        }

        deserializer.deserialize_map(ModelShowVisitor)
    }
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // `.contains` expects a `&String`, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.iter().any(|v| v == "vision")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}

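/// Stream a chat completion from `/api/chat`. The response body is NDJSON,
/// so each line is parsed independently as a [`ChatResponseDelta`] and
/// yielded on the returned stream.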
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(serde_json::to_string(&request)?))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

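/// List the models available locally via `/api/tags`.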
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetch details of a model via `/api/show`, used to determine its capabilities.
pub async fn show_model(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    model: &str,
) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "id": "call_llama3.2:3b_145155",
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    // Backwards compatibility with Ollama versions prior to v0.12.10 (November 2025).
    // This test is a copy of `parse_tool_call()` with the `id` field omitted.
    #[test]
    fn parse_tool_call_pre_0_12_10() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls: Some(tool_calls),
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(thinking.is_none());

                // When the `Option` around `id` is removed, this test will fail to
                // compile and should then be deleted in favor of `parse_tool_call()`.
                assert!(tool_calls.first().is_some_and(|call| call.id.is_none()))
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));

        assert_eq!(result.architecture, Some("llama".to_string()));
        assert_eq!(result.context_length, Some(131072));
    }

    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
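
    // Sanity checks: these assertions follow directly from the match arms in
    // `get_max_tokens` and the `:latest` handling in `Model::new` above.
    #[test]
    fn default_max_tokens_are_inferred_and_clamped() {
        // "llama3.1" maps to 128000 but is clamped to MAXIMUM_TOKENS (16384).
        assert_eq!(get_max_tokens("llama3.1:8b"), 16384);
        // Unknown models fall back to DEFAULT_TOKENS.
        assert_eq!(get_max_tokens("some-unknown-model"), 4096);
        assert_eq!(get_max_tokens("phi"), 2048);
    }

    #[test]
    fn display_name_strips_latest_tag() {
        let model = Model::new("llama3.2:latest", None, None, None, None, None);
        assert_eq!(model.display_name(), "llama3.2");

        // Other tags are kept verbatim.
        let model = Model::new("llama3.2:3b", None, None, None, None, None);
        assert_eq!(model.display_name(), "llama3.2:3b");
    }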
}