ollama.rs

  1use anyhow::{Context as _, Result};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6pub use settings::KeepAlive;
  7
  8pub const OLLAMA_API_URL: &str = "http://localhost:11434";
  9
 10#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 11#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 12pub struct Model {
 13    pub name: String,
 14    pub display_name: Option<String>,
 15    pub max_tokens: u64,
 16    pub keep_alive: Option<KeepAlive>,
 17    pub supports_tools: Option<bool>,
 18    pub supports_vision: Option<bool>,
 19    pub supports_thinking: Option<bool>,
 20}
 21
 22fn get_max_tokens(name: &str) -> u64 {
 23    /// Default context length for unknown models.
 24    const DEFAULT_TOKENS: u64 = 4096;
 25    /// Magic number. Lets many Ollama models work with ~16GB of ram.
 26    /// Models that support context beyond 16k such as codestral (32k) or devstral (128k) will be clamped down to 16k
 27    const MAXIMUM_TOKENS: u64 = 16384;
 28
 29    match name.split(':').next().unwrap() {
 30        "granite-code" | "phi" | "tinyllama" => 2048,
 31        "llama2" | "stablelm2" | "vicuna" | "yi" => 4096,
 32        "aya" | "codegemma" | "gemma" | "gemma2" | "llama3" | "starcoder" => 8192,
 33        "codellama" | "starcoder2" => 16384,
 34        "codestral" | "dolphin-mixtral" | "llava" | "magistral" | "mistral" | "mixstral"
 35        | "qwen2" | "qwen2.5-coder" => 32768,
 36        "cogito" | "command-r" | "deepseek-coder-v2" | "deepseek-r1" | "deepseek-v3"
 37        | "devstral" | "gemma3" | "gpt-oss" | "granite3.3" | "llama3.1" | "llama3.2"
 38        | "llama3.3" | "mistral-nemo" | "phi3" | "phi3.5" | "phi4" | "qwen3" | "yi-coder" => 128000,
 39        "qwen3-coder" => 256000,
 40        _ => DEFAULT_TOKENS,
 41    }
 42    .clamp(1, MAXIMUM_TOKENS)
 43}
 44
 45impl Model {
 46    pub fn new(
 47        name: &str,
 48        display_name: Option<&str>,
 49        max_tokens: Option<u64>,
 50        supports_tools: Option<bool>,
 51        supports_vision: Option<bool>,
 52        supports_thinking: Option<bool>,
 53    ) -> Self {
 54        Self {
 55            name: name.to_owned(),
 56            display_name: display_name
 57                .map(ToString::to_string)
 58                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
 59            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
 60            keep_alive: Some(KeepAlive::indefinite()),
 61            supports_tools,
 62            supports_vision,
 63            supports_thinking,
 64        }
 65    }
 66
 67    pub fn id(&self) -> &str {
 68        &self.name
 69    }
 70
 71    pub fn display_name(&self) -> &str {
 72        self.display_name.as_ref().unwrap_or(&self.name)
 73    }
 74
 75    pub const fn max_token_count(&self) -> u64 {
 76        self.max_tokens
 77    }
 78}
 79
 80#[derive(Serialize, Deserialize, Debug)]
 81#[serde(tag = "role", rename_all = "lowercase")]
 82pub enum ChatMessage {
 83    Assistant {
 84        content: String,
 85        tool_calls: Option<Vec<OllamaToolCall>>,
 86        #[serde(skip_serializing_if = "Option::is_none")]
 87        images: Option<Vec<String>>,
 88        thinking: Option<String>,
 89    },
 90    User {
 91        content: String,
 92        #[serde(skip_serializing_if = "Option::is_none")]
 93        images: Option<Vec<String>>,
 94    },
 95    System {
 96        content: String,
 97    },
 98    Tool {
 99        tool_name: String,
100        content: String,
101    },
102}
103
104#[derive(Serialize, Deserialize, Debug)]
105#[serde(rename_all = "lowercase")]
106pub enum OllamaToolCall {
107    Function(OllamaFunctionCall),
108}
109
110#[derive(Serialize, Deserialize, Debug)]
111pub struct OllamaFunctionCall {
112    pub name: String,
113    pub arguments: Value,
114}
115
116#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
117pub struct OllamaFunctionTool {
118    pub name: String,
119    pub description: Option<String>,
120    pub parameters: Option<Value>,
121}
122
123#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
124#[serde(tag = "type", rename_all = "lowercase")]
125pub enum OllamaTool {
126    Function { function: OllamaFunctionTool },
127}
128
129#[derive(Serialize, Debug)]
130pub struct ChatRequest {
131    pub model: String,
132    pub messages: Vec<ChatMessage>,
133    pub stream: bool,
134    pub keep_alive: KeepAlive,
135    pub options: Option<ChatOptions>,
136    pub tools: Vec<OllamaTool>,
137    pub think: Option<bool>,
138}
139
140// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
141#[derive(Serialize, Default, Debug)]
142pub struct ChatOptions {
143    pub num_ctx: Option<u64>,
144    pub num_predict: Option<isize>,
145    pub stop: Option<Vec<String>>,
146    pub temperature: Option<f32>,
147    pub top_p: Option<f32>,
148}
149
150#[derive(Deserialize, Debug)]
151pub struct ChatResponseDelta {
152    pub model: String,
153    pub created_at: String,
154    pub message: ChatMessage,
155    pub done_reason: Option<String>,
156    pub done: bool,
157    pub prompt_eval_count: Option<u64>,
158    pub eval_count: Option<u64>,
159}
160
161#[derive(Serialize, Deserialize)]
162pub struct LocalModelsResponse {
163    pub models: Vec<LocalModelListing>,
164}
165
166#[derive(Serialize, Deserialize)]
167pub struct LocalModelListing {
168    pub name: String,
169    pub modified_at: String,
170    pub size: u64,
171    pub digest: String,
172    pub details: ModelDetails,
173}
174
175#[derive(Serialize, Deserialize)]
176pub struct LocalModel {
177    pub modelfile: String,
178    pub parameters: String,
179    pub template: String,
180    pub details: ModelDetails,
181}
182
183#[derive(Serialize, Deserialize)]
184pub struct ModelDetails {
185    pub format: String,
186    pub family: String,
187    pub families: Option<Vec<String>>,
188    pub parameter_size: String,
189    pub quantization_level: String,
190}
191
192#[derive(Debug)]
193pub struct ModelShow {
194    pub capabilities: Vec<String>,
195    pub context_length: Option<u64>,
196    pub architecture: Option<String>,
197}
198
199impl<'de> Deserialize<'de> for ModelShow {
200    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
201    where
202        D: serde::Deserializer<'de>,
203    {
204        use serde::de::{self, MapAccess, Visitor};
205        use std::fmt;
206
207        struct ModelShowVisitor;
208
209        impl<'de> Visitor<'de> for ModelShowVisitor {
210            type Value = ModelShow;
211
212            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
213                formatter.write_str("a ModelShow object")
214            }
215
216            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
217            where
218                A: MapAccess<'de>,
219            {
220                let mut capabilities: Vec<String> = Vec::new();
221                let mut architecture: Option<String> = None;
222                let mut context_length: Option<u64> = None;
223
224                while let Some(key) = map.next_key::<String>()? {
225                    match key.as_str() {
226                        "capabilities" => {
227                            capabilities = map.next_value()?;
228                        }
229                        "model_info" => {
230                            let model_info: Value = map.next_value()?;
231                            if let Value::Object(obj) = model_info {
232                                architecture = obj
233                                    .get("general.architecture")
234                                    .and_then(|v| v.as_str())
235                                    .map(String::from);
236
237                                if let Some(arch) = &architecture {
238                                    context_length = obj
239                                        .get(&format!("{}.context_length", arch))
240                                        .and_then(|v| v.as_u64());
241                                }
242                            }
243                        }
244                        _ => {
245                            let _: de::IgnoredAny = map.next_value()?;
246                        }
247                    }
248                }
249
250                Ok(ModelShow {
251                    capabilities,
252                    context_length,
253                    architecture,
254                })
255            }
256        }
257
258        deserializer.deserialize_map(ModelShowVisitor)
259    }
260}
261
262impl ModelShow {
263    pub fn supports_tools(&self) -> bool {
264        // .contains expects &String, which would require an additional allocation
265        self.capabilities.iter().any(|v| v == "tools")
266    }
267
268    pub fn supports_vision(&self) -> bool {
269        self.capabilities.iter().any(|v| v == "vision")
270    }
271
272    pub fn supports_thinking(&self) -> bool {
273        self.capabilities.iter().any(|v| v == "thinking")
274    }
275}
276
277pub async fn stream_chat_completion(
278    client: &dyn HttpClient,
279    api_url: &str,
280    api_key: Option<&str>,
281    request: ChatRequest,
282) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
283    let uri = format!("{api_url}/api/chat");
284    let request = HttpRequest::builder()
285        .method(Method::POST)
286        .uri(uri)
287        .header("Content-Type", "application/json")
288        .when_some(api_key, |builder, api_key| {
289            builder.header("Authorization", format!("Bearer {api_key}"))
290        })
291        .body(AsyncBody::from(serde_json::to_string(&request)?))?;
292
293    let mut response = client.send(request).await?;
294    if response.status().is_success() {
295        let reader = BufReader::new(response.into_body());
296
297        Ok(reader
298            .lines()
299            .map(|line| match line {
300                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
301                Err(e) => Err(e.into()),
302            })
303            .boxed())
304    } else {
305        let mut body = String::new();
306        response.body_mut().read_to_string(&mut body).await?;
307        anyhow::bail!(
308            "Failed to connect to Ollama API: {} {}",
309            response.status(),
310            body,
311        );
312    }
313}
314
315pub async fn get_models(
316    client: &dyn HttpClient,
317    api_url: &str,
318    api_key: Option<&str>,
319) -> Result<Vec<LocalModelListing>> {
320    let uri = format!("{api_url}/api/tags");
321    let request = HttpRequest::builder()
322        .method(Method::GET)
323        .uri(uri)
324        .header("Accept", "application/json")
325        .when_some(api_key, |builder, api_key| {
326            builder.header("Authorization", format!("Bearer {api_key}"))
327        })
328        .body(AsyncBody::default())?;
329
330    let mut response = client.send(request).await?;
331
332    let mut body = String::new();
333    response.body_mut().read_to_string(&mut body).await?;
334
335    anyhow::ensure!(
336        response.status().is_success(),
337        "Failed to connect to Ollama API: {} {}",
338        response.status(),
339        body,
340    );
341    let response: LocalModelsResponse =
342        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
343    Ok(response.models)
344}
345
346/// Fetch details of a model, used to determine model capabilities
347pub async fn show_model(
348    client: &dyn HttpClient,
349    api_url: &str,
350    api_key: Option<&str>,
351    model: &str,
352) -> Result<ModelShow> {
353    let uri = format!("{api_url}/api/show");
354    let request = HttpRequest::builder()
355        .method(Method::POST)
356        .uri(uri)
357        .header("Content-Type", "application/json")
358        .when_some(api_key, |builder, api_key| {
359            builder.header("Authorization", format!("Bearer {api_key}"))
360        })
361        .body(AsyncBody::from(
362            serde_json::json!({ "model": model }).to_string(),
363        ))?;
364
365    let mut response = client.send(request).await?;
366    let mut body = String::new();
367    response.body_mut().read_to_string(&mut body).await?;
368
369    anyhow::ensure!(
370        response.status().is_success(),
371        "Failed to connect to Ollama API: {} {}",
372        response.status(),
373        body,
374    );
375    let details: ModelShow = serde_json::from_str(body.as_str())?;
376    Ok(details)
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
        "model": "llama3.2",
        "created_at": "2023-12-12T14:13:43.416799Z",
        "message": {
            "role": "assistant",
            "content": "Hello! How are you today?"
        },
        "done": true,
        "total_duration": 5191566416u64,
        "load_duration": 2154458,
        "prompt_eval_count": 26,
        "prompt_eval_duration": 383809000,
        "eval_count": 298,
        "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
        "model": "llama3.2",
        "created_at": "2023-08-04T08:52:19.385406455-07:00",
        "message": {
            "role": "assistant",
            "content": "The",
            "images": null
        },
        "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
        "model": "llama3.2",
        "created_at": "2023-08-04T19:22:45.499127Z",
        "message": {
            "role": "assistant",
            "content": ""
        },
        "done": true,
        "total_duration": 4883583458u64,
        "load_duration": 1334875,
        "prompt_eval_count": 26,
        "prompt_eval_duration": 342546000,
        "eval_count": 282,
        "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));

        assert_eq!(result.architecture, Some("llama".to_string()));
        assert_eq!(result.context_length, Some(131072));
    }

    /// `model_info` is optional: its absence must leave architecture/context length unset.
    #[test]
    fn parse_show_model_without_model_info() {
        let response = serde_json::json!({ "capabilities": ["completion"] });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(!result.supports_tools());
        assert!(!result.supports_vision());
        assert!(!result.supports_thinking());
        assert!(result.architecture.is_none());
        assert!(result.context_length.is_none());
    }

    /// The context-length table keys on the family name before `:` and clamps to 16k.
    #[test]
    fn max_tokens_from_model_family() {
        assert_eq!(get_max_tokens("phi:latest"), 2048);
        assert_eq!(get_max_tokens("llama3"), 8192);
        assert_eq!(get_max_tokens("llama3.2:3b"), 16384);
        assert_eq!(get_max_tokens("qwen3-coder:30b"), 16384);
        assert_eq!(get_max_tokens("some-unknown-model"), 4096);
    }

    #[test]
    fn model_new_display_name_and_token_defaults() {
        // `:latest` is stripped for display; the family default supplies max tokens.
        let latest = Model::new("llama3.2:latest", None, None, None, None, None);
        assert_eq!(latest.id(), "llama3.2:latest");
        assert_eq!(latest.display_name(), "llama3.2");
        assert_eq!(latest.max_token_count(), 16384);

        // Other tags are kept, and an explicit token limit wins over the default.
        let tagged = Model::new("phi:2.7b", None, Some(1024), Some(true), None, None);
        assert_eq!(tagged.display_name(), "phi:2.7b");
        assert_eq!(tagged.max_token_count(), 1024);
        assert_eq!(tagged.supports_tools, Some(true));

        // An explicit display name always wins.
        let named = Model::new("phi:2.7b", Some("Phi"), None, None, None, None);
        assert_eq!(named.display_name(), "Phi");
    }

    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
}