ollama.rs

mod ollama_edit_prediction_delegate;

pub use ollama_edit_prediction_delegate::OllamaEditPredictionDelegate;

use anyhow::{Context, Result};

use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub use settings::KeepAlive;

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}

fn get_max_tokens(name: &str) -> u64 {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: u64 = 4096;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    /// Models that support a context beyond 16k, such as codestral (32k) or
    /// devstral (128k), are clamped down to 16k.
    const MAXIMUM_TOKENS: u64 = 16384;

    match name.split(':').next().unwrap() {
        "granite-code" | "phi" | "tinyllama" => 2048,
        "llama2" | "stablelm2" | "vicuna" | "yi" => 4096,
        "aya" | "codegemma" | "gemma" | "gemma2" | "llama3" | "starcoder" => 8192,
        "codellama" | "starcoder2" => 16384,
        "codestral" | "dolphin-mixtral" | "llava" | "magistral" | "mistral" | "mixtral"
        | "qwen2" | "qwen2.5-coder" => 32768,
        "cogito" | "command-r" | "deepseek-coder-v2" | "deepseek-r1" | "deepseek-v3"
        | "devstral" | "gemma3" | "gpt-oss" | "granite3.3" | "llama3.1" | "llama3.2"
        | "llama3.3" | "mistral-nemo" | "phi3" | "phi3.5" | "phi4" | "qwen3" | "yi-coder" => 128000,
        "qwen3-coder" => 256000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}
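
// For example, `get_max_tokens("codestral:22b")` resolves the family
// "codestral" to 32768 and then clamps to `MAXIMUM_TOKENS`, yielding 16384,
// while an unrecognized name falls back to `DEFAULT_TOKENS` (4096).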

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tools: Option<bool>,
        supports_vision: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_vision,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }
}
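
// A worked example of the defaults above: `Model::new("codestral:latest",
// None, None, None, None, None)` yields `display_name: Some("codestral")`
// (the ":latest" suffix is stripped) and `max_tokens: 16384` via
// `get_max_tokens`.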

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        thinking: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
    Tool {
        tool_name: String,
        content: String,
    },
}
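
// With the internally tagged representation above, a user turn serializes as
// `{"role":"user","content":"..."}` (the `images` field is omitted when
// `None`), while assistant turns carry `tool_calls` and `thinking`
// explicitly, serialized as `null` when absent.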

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaToolCall {
    // TODO: Remove `Option` after most users have updated to Ollama v0.12.10,
    // which was released on the 4th of November 2025
    pub id: Option<String>,
    pub function: OllamaFunctionCall,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}
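
// The tagged representation above serializes a tool definition as
// `{"type":"function","function":{"name":...,"description":...,"parameters":...}}`,
// the shape Ollama's chat API documents for entries in the `tools` array.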

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<u64>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    pub model: String,
    pub created_at: String,
    pub message: ChatMessage,
    pub done_reason: Option<String>,
    pub done: bool,
    pub prompt_eval_count: Option<u64>,
    pub eval_count: Option<u64>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

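/// Parsed subset of Ollama's `/api/show` response: the model's capability
/// list, plus the architecture and its context length from `model_info`.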
#[derive(Debug)]
pub struct ModelShow {
    pub capabilities: Vec<String>,
    pub context_length: Option<u64>,
    pub architecture: Option<String>,
}

impl<'de> Deserialize<'de> for ModelShow {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};
        use std::fmt;

        struct ModelShowVisitor;

        impl<'de> Visitor<'de> for ModelShowVisitor {
            type Value = ModelShow;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a ModelShow object")
            }

            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                let mut capabilities: Vec<String> = Vec::new();
                let mut architecture: Option<String> = None;
                let mut context_length: Option<u64> = None;

                while let Some(key) = map.next_key::<String>()? {
                    match key.as_str() {
                        "capabilities" => {
                            capabilities = map.next_value()?;
                        }
                        "model_info" => {
                            let model_info: Value = map.next_value()?;
                            if let Value::Object(obj) = model_info {
                                architecture = obj
                                    .get("general.architecture")
                                    .and_then(|v| v.as_str())
                                    .map(String::from);

                                if let Some(arch) = &architecture {
                                    context_length = obj
                                        .get(&format!("{}.context_length", arch))
                                        .and_then(|v| v.as_u64());
                                }
                            }
                        }
                        _ => {
                            let _: de::IgnoredAny = map.next_value()?;
                        }
                    }
                }

                Ok(ModelShow {
                    capabilities,
                    context_length,
                    architecture,
                })
            }
        }

        deserializer.deserialize_map(ModelShowVisitor)
    }
}
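
// A minimal sketch of the `/api/show` fields this visitor reads; everything
// else is skipped via `IgnoredAny`:
//
//     {
//         "capabilities": ["completion", "tools"],
//         "model_info": {
//             "general.architecture": "llama",
//             "llama.context_length": 131072
//         }
//     }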

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.iter().any(|v| v == "vision")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}

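/// Stream chat completions from Ollama's `/api/chat` endpoint. Each line of
/// the response body is a JSON-encoded [`ChatResponseDelta`]; the stream ends
/// with a delta whose `done` flag is set.
///
/// A minimal usage sketch (assumes an `HttpClient` implementation is in scope
/// as `client`; everything else comes from this module):
///
/// ```ignore
/// let request = ChatRequest {
///     model: "llama3.2".to_string(),
///     messages: vec![ChatMessage::User {
///         content: "Hello!".to_string(),
///         images: None,
///     }],
///     stream: true,
///     keep_alive: KeepAlive::default(),
///     options: None,
///     tools: vec![],
///     think: None,
/// };
/// let mut deltas = stream_chat_completion(client, OLLAMA_API_URL, None, request).await?;
/// while let Some(delta) = deltas.next().await {
///     let delta = delta?;
///     if delta.done {
///         break;
///     }
/// }
/// ```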
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(serde_json::to_string(&request)?))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

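/// List the models available locally via Ollama's `/api/tags` endpoint.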
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetch details of a model via Ollama's `/api/show` endpoint, used to
/// determine model capabilities.
pub async fn show_model(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    model: &str,
) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}
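
// A typical capability probe (a sketch; `client` is any `HttpClient`
// implementation):
//
//     let show = show_model(client, OLLAMA_API_URL, None, "llama3.2").await?;
//     let model = Model::new(
//         "llama3.2",
//         None,
//         show.context_length,
//         Some(show.supports_tools()),
//         Some(show.supports_vision()),
//         Some(show.supports_thinking()),
//     );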

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "id": "call_llama3.2:3b_145155",
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    // Backwards compatibility with Ollama versions prior to v0.12.10 (November 2025).
    // This test is a copy of `parse_tool_call()` with the `id` field omitted.
    #[test]
    fn parse_tool_call_pre_0_12_10() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls: Some(tool_calls),
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(thinking.is_none());

                // When the `Option` around `id` is removed, this test will fail to
                // compile and should be deleted in favor of `parse_tool_call()`.
                assert!(tool_calls.first().is_some_and(|call| call.id.is_none()))
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));

        assert_eq!(result.architecture, Some("llama".to_string()));
        assert_eq!(result.context_length, Some(131072));
    }

    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
}