// ollama.rs

  1use anyhow::{Context, Result};
  2
  3use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  4use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
  5use serde::{Deserialize, Serialize};
  6use serde_json::Value;
  7pub use settings::KeepAlive;
  8
/// Default base URL of a locally running Ollama server.
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
 10
/// Configuration for a single Ollama model as used by this client.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    /// Model identifier as reported by Ollama (e.g. "llama3.2:latest").
    pub name: String,
    /// Optional human-friendly name; `display_name()` falls back to `name`.
    pub display_name: Option<String>,
    /// Maximum token count for this model (defaults to 4096 in `Model::new`).
    pub max_tokens: u64,
    /// How long the model stays loaded after a request.
    pub keep_alive: Option<KeepAlive>,
    /// Whether the model supports tool calling; `None` means unknown.
    pub supports_tools: Option<bool>,
    /// Whether the model accepts image inputs; `None` means unknown.
    pub supports_vision: Option<bool>,
    /// Whether the model can emit "thinking" output; `None` means unknown.
    pub supports_thinking: Option<bool>,
}
 22
 23fn get_max_tokens(_name: &str) -> u64 {
 24    const DEFAULT_TOKENS: u64 = 4096;
 25    DEFAULT_TOKENS
 26}
 27
 28impl Model {
 29    pub fn new(
 30        name: &str,
 31        display_name: Option<&str>,
 32        max_tokens: Option<u64>,
 33        supports_tools: Option<bool>,
 34        supports_vision: Option<bool>,
 35        supports_thinking: Option<bool>,
 36    ) -> Self {
 37        Self {
 38            name: name.to_owned(),
 39            display_name: display_name
 40                .map(ToString::to_string)
 41                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
 42            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
 43            keep_alive: Some(KeepAlive::indefinite()),
 44            supports_tools,
 45            supports_vision,
 46            supports_thinking,
 47        }
 48    }
 49
 50    pub fn id(&self) -> &str {
 51        &self.name
 52    }
 53
 54    pub fn display_name(&self) -> &str {
 55        self.display_name.as_ref().unwrap_or(&self.name)
 56    }
 57
 58    pub fn max_token_count(&self) -> u64 {
 59        self.max_tokens
 60    }
 61}
 62
/// A single message in a chat exchange, tagged by `role` in JSON
/// ("assistant", "user", "system", or "tool").
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        /// Base64-encoded images; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        /// "Thinking" text, when thinking output is enabled.
        thinking: Option<String>,
    },
    User {
        content: String,
        /// Base64-encoded images; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
    /// Output of a tool call (JSON role "tool").
    Tool {
        tool_name: String,
        content: String,
    },
}
 86
/// A tool invocation requested by the model.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaToolCall {
    // TODO: Remove `Option` after most users have updated to Ollama v0.12.10,
    // which was released on the 4th of November 2025
    pub id: Option<String>,
    pub function: OllamaFunctionCall,
}
 94
/// The function name and arguments of a single tool call.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    /// Arbitrary JSON arguments supplied by the model.
    pub arguments: Value,
}
100
/// Describes a callable function offered to the model.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    /// Parameter specification as raw JSON (typically a JSON schema —
    /// not validated here).
    pub parameters: Option<Value>,
}
107
/// A tool made available to the model; currently only function tools,
/// serialized with `"type": "function"`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}
113
/// Request body for the `/api/chat` endpoint.
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    /// When true, the server replies with newline-delimited JSON deltas.
    pub stream: bool,
    /// How long the model should stay loaded after this request.
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    /// Requests "thinking" output when `Some(true)`, model support permitting.
    pub think: Option<bool>,
}
124
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
/// Per-request sampling/runtime options, mirroring Ollama modelfile
/// parameters (see link above). All fields are optional; `None` leaves the
/// server-side default in effect.
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<u64>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}
134
/// One chunk of a (possibly streaming) chat response.
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    pub model: String,
    /// Timestamp string as sent by the server
    /// (e.g. "2023-12-12T14:13:43.416799Z").
    pub created_at: String,
    pub message: ChatMessage,
    /// Why generation stopped (e.g. "stop"); seen on final deltas.
    pub done_reason: Option<String>,
    /// True on the final delta of a stream.
    pub done: bool,
    /// Prompt tokens evaluated; absent on intermediate deltas.
    pub prompt_eval_count: Option<u64>,
    /// Tokens generated; absent on intermediate deltas.
    pub eval_count: Option<u64>,
}
145
/// Response shape of `/api/tags`: the locally installed models.
#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}
150
/// Summary entry for one installed model, as returned by `/api/tags`.
#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    /// Size on disk (presumably bytes — confirm against Ollama API docs).
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}
159
/// Full record for a single model: its modelfile, parameters, and template.
#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}
167
/// Format/family/quantization metadata shared by model listings.
#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    /// Model file format (e.g. "gguf").
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    /// Human-readable parameter count (e.g. "3.2B").
    pub parameter_size: String,
    /// Quantization label (e.g. "Q4_K_M").
    pub quantization_level: String,
}
176
/// Capability and context information extracted from `/api/show`.
#[derive(Debug)]
pub struct ModelShow {
    /// Capability strings such as "completion", "tools", "vision", "thinking".
    pub capabilities: Vec<String>,
    /// `<architecture>.context_length` from `model_info`, when present.
    pub context_length: Option<u64>,
    /// `general.architecture` from `model_info`, when present.
    pub architecture: Option<String>,
}
183
/// Manual `Deserialize` impl: only the fields `ModelShow` needs are
/// extracted from the `/api/show` response; every other key is skipped
/// without building an intermediate value.
impl<'de> Deserialize<'de> for ModelShow {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};
        use std::fmt;

        struct ModelShowVisitor;

        impl<'de> Visitor<'de> for ModelShowVisitor {
            type Value = ModelShow;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a ModelShow object")
            }

            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                let mut capabilities: Vec<String> = Vec::new();
                let mut architecture: Option<String> = None;
                let mut context_length: Option<u64> = None;

                while let Some(key) = map.next_key::<String>()? {
                    match key.as_str() {
                        "capabilities" => {
                            capabilities = map.next_value()?;
                        }
                        "model_info" => {
                            // Keys in `model_info` are namespaced by the
                            // architecture (e.g. "llama.context_length"),
                            // so the architecture is read first to build
                            // the context-length lookup key.
                            let model_info: Value = map.next_value()?;
                            if let Value::Object(obj) = model_info {
                                architecture = obj
                                    .get("general.architecture")
                                    .and_then(|v| v.as_str())
                                    .map(String::from);

                                if let Some(arch) = &architecture {
                                    context_length = obj
                                        .get(&format!("{}.context_length", arch))
                                        .and_then(|v| v.as_u64());
                                }
                            }
                        }
                        _ => {
                            // Consume and discard any other key's value.
                            let _: de::IgnoredAny = map.next_value()?;
                        }
                    }
                }

                Ok(ModelShow {
                    capabilities,
                    context_length,
                    architecture,
                })
            }
        }

        deserializer.deserialize_map(ModelShowVisitor)
    }
}
246
247impl ModelShow {
248    pub fn supports_tools(&self) -> bool {
249        // .contains expects &String, which would require an additional allocation
250        self.capabilities.iter().any(|v| v == "tools")
251    }
252
253    pub fn supports_vision(&self) -> bool {
254        self.capabilities.iter().any(|v| v == "vision")
255    }
256
257    pub fn supports_thinking(&self) -> bool {
258        self.capabilities.iter().any(|v| v == "thinking")
259    }
260}
261
/// Starts a streaming chat completion against `{api_url}/api/chat`.
///
/// An `Authorization: Bearer` header is added only when `api_key` is given.
/// On success, returns a stream that parses each response line as one
/// [`ChatResponseDelta`]; on a non-success status, the body is read and
/// returned inside the error.
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(serde_json::to_string(&request)?))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        // The streaming API sends one JSON object per line.
        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}
299
/// Lists the models installed on the Ollama server via `{api_url}/api/tags`.
///
/// An `Authorization: Bearer` header is added only when `api_key` is given.
/// Non-success responses are turned into errors carrying the status and body.
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}
330
331/// Fetch details of a model, used to determine model capabilities
332pub async fn show_model(
333    client: &dyn HttpClient,
334    api_url: &str,
335    api_key: Option<&str>,
336    model: &str,
337) -> Result<ModelShow> {
338    let uri = format!("{api_url}/api/show");
339    let request = HttpRequest::builder()
340        .method(Method::POST)
341        .uri(uri)
342        .header("Content-Type", "application/json")
343        .when_some(api_key, |builder, api_key| {
344            builder.header("Authorization", format!("Bearer {api_key}"))
345        })
346        .body(AsyncBody::from(
347            serde_json::json!({ "model": model }).to_string(),
348        ))?;
349
350    let mut response = client.send(request).await?;
351    let mut body = String::new();
352    response.body_mut().read_to_string(&mut body).await?;
353
354    anyhow::ensure!(
355        response.status().is_success(),
356        "Failed to connect to Ollama API: {} {}",
357        response.status(),
358        body,
359    );
360    let details: ModelShow = serde_json::from_str(body.as_str())?;
361    Ok(details)
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    // Non-streaming completion: a single object with `done: true`; timing
    // fields in the payload have no matching struct fields and are ignored.
    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
        "model": "llama3.2",
        "created_at": "2023-12-12T14:13:43.416799Z",
        "message": {
            "role": "assistant",
            "content": "Hello! How are you today?"
        },
        "done": true,
        "total_duration": 5191566416u64,
        "load_duration": 2154458,
        "prompt_eval_count": 26,
        "prompt_eval_duration": 383809000,
        "eval_count": 298,
        "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    // Streaming: an intermediate delta (`done: false`, no eval counts)
    // followed by the terminal delta (`done: true` with counts).
    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
        "model": "llama3.2",
        "created_at": "2023-08-04T08:52:19.385406455-07:00",
        "message": {
            "role": "assistant",
            "content": "The",
            "images": null
        },
        "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
        "model": "llama3.2",
        "created_at": "2023-08-04T19:22:45.499127Z",
        "message": {
            "role": "assistant",
            "content": ""
        },
        "done": true,
        "total_duration": 4883583458u64,
        "load_duration": 1334875,
        "prompt_eval_count": 26,
        "prompt_eval_duration": 342546000,
        "eval_count": 282,
        "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    // Assistant message carrying a tool call with an `id` (Ollama >= 0.12.10).
    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "id": "call_llama3.2:3b_145155",
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    // Backwards compatibility with Ollama versions prior to v0.12.10 November 2025
    // This test is a copy of `parse_tool_call()` with the `id` field omitted.
    #[test]
    fn parse_tool_call_pre_0_12_10() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls: Some(tool_calls),
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(thinking.is_none());

                // When the `Option` around `id` is removed, this test should complain
                // and be subsequently deleted in favor of `parse_tool_call()`
                assert!(tool_calls.first().is_some_and(|call| call.id.is_none()))
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    // Exercises the custom `Deserialize` for `ModelShow`: capabilities,
    // `general.architecture`, and the architecture-namespaced
    // `llama.context_length` are extracted; the rest of the large
    // `/api/show` payload is ignored.
    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));

        assert_eq!(result.architecture, Some("llama".to_string()));
        assert_eq!(result.context_length, Some(131072));
    }

    // `images` is serialized when present...
    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    // ...and omitted entirely when `None` (skip_serializing_if).
    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    // Round-trips the serialized request through serde_json to check the
    // exact JSON placement of the images array inside the message.
    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
}