ollama.rs

use anyhow::{Context, Result};

use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub use settings::KeepAlive;

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

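/// An available Ollama model, along with its context window and optional
/// capability flags.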
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}

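/// Infers a default context length from the model's base name (the part
/// before the first `:` tag), falling back to `DEFAULT_TOKENS` for unknown
/// models and clamping everything to `MAXIMUM_TOKENS`.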
fn get_max_tokens(name: &str) -> u64 {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: u64 = 4096;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    /// Models that support a context beyond 16k, such as codestral (32k) or
    /// devstral (128k), will be clamped down to 16k.
    const MAXIMUM_TOKENS: u64 = 16384;

    match name.split(':').next().unwrap() {
        "granite-code" | "phi" | "tinyllama" => 2048,
        "llama2" | "stablelm2" | "vicuna" | "yi" => 4096,
        "aya" | "codegemma" | "gemma" | "gemma2" | "llama3" | "starcoder" => 8192,
        "codellama" | "starcoder2" => 16384,
        "codestral" | "dolphin-mixtral" | "llava" | "magistral" | "mistral" | "mixtral"
        | "qwen2" | "qwen2.5-coder" => 32768,
        "cogito" | "command-r" | "deepseek-coder-v2" | "deepseek-r1" | "deepseek-v3"
        | "devstral" | "gemma3" | "gpt-oss" | "granite3.3" | "llama3.1" | "llama3.2"
        | "llama3.3" | "mistral-nemo" | "phi3" | "phi3.5" | "phi4" | "qwen3" | "yi-coder" => 128000,
        "qwen3-coder" => 256000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tools: Option<bool>,
        supports_vision: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_vision,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }
}

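/// A message in an Ollama chat conversation, tagged by `role` on the wire.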
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        thinking: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
    Tool {
        tool_name: String,
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaToolCall {
    // TODO: Remove `Option` after most users have updated to Ollama v0.12.10,
    // which was released on the 4th of November 2025
    pub id: Option<String>,
    pub function: OllamaFunctionCall,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

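/// A function definition advertised to the model as an available tool;
/// `parameters` is typically a JSON Schema object describing the arguments.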
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

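/// The request body for Ollama's `/api/chat` endpoint.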
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<u64>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

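/// One newline-delimited JSON chunk of a streaming chat response; the final
/// chunk has `done: true` and carries the token counts.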
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    pub model: String,
    pub created_at: String,
    pub message: ChatMessage,
    pub done_reason: Option<String>,
    pub done: bool,
    pub prompt_eval_count: Option<u64>,
    pub eval_count: Option<u64>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

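/// Capability and context-window information from `/api/show`. The custom
/// `Deserialize` impl below digs `general.architecture` and the
/// architecture-specific `<arch>.context_length` key out of `model_info`.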
#[derive(Debug)]
pub struct ModelShow {
    pub capabilities: Vec<String>,
    pub context_length: Option<u64>,
    pub architecture: Option<String>,
}

impl<'de> Deserialize<'de> for ModelShow {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};
        use std::fmt;

        struct ModelShowVisitor;

        impl<'de> Visitor<'de> for ModelShowVisitor {
            type Value = ModelShow;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a ModelShow object")
            }

            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                let mut capabilities: Vec<String> = Vec::new();
                let mut architecture: Option<String> = None;
                let mut context_length: Option<u64> = None;

                while let Some(key) = map.next_key::<String>()? {
                    match key.as_str() {
                        "capabilities" => {
                            capabilities = map.next_value()?;
                        }
                        "model_info" => {
                            let model_info: Value = map.next_value()?;
                            if let Value::Object(obj) = model_info {
                                architecture = obj
                                    .get("general.architecture")
                                    .and_then(|v| v.as_str())
                                    .map(String::from);

                                if let Some(arch) = &architecture {
                                    context_length = obj
                                        .get(&format!("{}.context_length", arch))
                                        .and_then(|v| v.as_u64());
                                }
                            }
                        }
                        _ => {
                            let _: de::IgnoredAny = map.next_value()?;
                        }
                    }
                }

                Ok(ModelShow {
                    capabilities,
                    context_length,
                    architecture,
                })
            }
        }

        deserializer.deserialize_map(ModelShowVisitor)
    }
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.iter().any(|v| v == "vision")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}

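/// Sends a request to Ollama's `/api/chat` endpoint and streams the response.
///
/// Ollama replies with newline-delimited JSON: each line parses into one
/// `ChatResponseDelta`, and the final line has `done: true`. A minimal sketch
/// of consuming the stream, assuming `client` is whatever `HttpClient`
/// implementation the caller already has:
///
/// ```ignore
/// let mut deltas = stream_chat_completion(&client, OLLAMA_API_URL, None, request).await?;
/// while let Some(delta) = deltas.next().await {
///     if let ChatMessage::Assistant { content, .. } = delta?.message {
///         print!("{content}");
///     }
/// }
/// ```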
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(serde_json::to_string(&request)?))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

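/// Lists the locally installed models via Ollama's `/api/tags` endpoint.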
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetches the details of a model via `/api/show`; used to determine model capabilities.
pub async fn show_model(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    model: &str,
) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}

#[cfg(test)]
mod tests {
    use super::*;

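    // A sketch of the `get_max_tokens` contract (not part of the original
    // suite): the tag after ':' is ignored, known large-context models are
    // clamped to MAXIMUM_TOKENS, and unknown names fall back to DEFAULT_TOKENS.
    #[test]
    fn clamps_max_tokens() {
        assert_eq!(get_max_tokens("phi:latest"), 2048);
        assert_eq!(get_max_tokens("devstral"), 16384);
        assert_eq!(get_max_tokens("some-unknown-model"), 4096);
    }
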
    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
        "model": "llama3.2",
        "created_at": "2023-12-12T14:13:43.416799Z",
        "message": {
            "role": "assistant",
            "content": "Hello! How are you today?"
        },
        "done": true,
        "total_duration": 5191566416u64,
        "load_duration": 2154458,
        "prompt_eval_count": 26,
        "prompt_eval_duration": 383809000,
        "eval_count": 298,
        "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
        "model": "llama3.2",
        "created_at": "2023-08-04T08:52:19.385406455-07:00",
        "message": {
            "role": "assistant",
            "content": "The",
            "images": null
        },
        "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
        "model": "llama3.2",
        "created_at": "2023-08-04T19:22:45.499127Z",
        "message": {
            "role": "assistant",
            "content": ""
        },
        "done": true,
        "total_duration": 4883583458u64,
        "load_duration": 1334875,
        "prompt_eval_count": 26,
        "prompt_eval_duration": 342546000,
        "eval_count": 282,
        "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "id": "call_llama3.2:3b_145155",
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    // Backwards compatibility with Ollama versions prior to v0.12.10 (November 2025).
    // This test is a copy of `parse_tool_call()` with the `id` field omitted.
    #[test]
    fn parse_tool_call_pre_0_12_10() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls: Some(tool_calls),
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(thinking.is_none());

                // When the `Option` around `id` is removed, this test should complain
                // and be subsequently deleted in favor of `parse_tool_call()`
                assert!(tool_calls.first().is_some_and(|call| call.id.is_none()))
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));

        assert_eq!(result.architecture, Some("llama".to_string()));
        assert_eq!(result.context_length, Some(131072));
    }

    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
}