// ollama.rs — client for the Ollama local-inference HTTP API.

  1use anyhow::{Context, Result};
  2
  3use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  4use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
  5use serde::{Deserialize, Serialize};
  6use serde_json::Value;
  7pub use settings::KeepAlive;
  8
/// Default base URL of a locally running Ollama server.
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
 10
/// Configuration describing a single Ollama model as used by this crate.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    /// Model identifier as known to the Ollama server (e.g. "llama3.2:3b").
    pub name: String,
    /// Human-readable name; `display_name()` falls back to `name` when `None`.
    pub display_name: Option<String>,
    /// Context window size in tokens (see `max_token_count`).
    pub max_tokens: u64,
    /// How long the server should keep the model loaded after a request.
    pub keep_alive: Option<KeepAlive>,
    // Capability flags are tri-state: `None` means "not yet determined"
    // (they can be filled in from a `/api/show` probe).
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}
 22
 23fn get_max_tokens(_name: &str) -> u64 {
 24    const DEFAULT_TOKENS: u64 = 4096;
 25    DEFAULT_TOKENS
 26}
 27
 28impl Model {
 29    pub fn new(
 30        name: &str,
 31        display_name: Option<&str>,
 32        max_tokens: Option<u64>,
 33        supports_tools: Option<bool>,
 34        supports_vision: Option<bool>,
 35        supports_thinking: Option<bool>,
 36    ) -> Self {
 37        Self {
 38            name: name.to_owned(),
 39            display_name: display_name
 40                .map(ToString::to_string)
 41                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
 42            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
 43            keep_alive: Some(KeepAlive::indefinite()),
 44            supports_tools,
 45            supports_vision,
 46            supports_thinking,
 47        }
 48    }
 49
 50    pub fn id(&self) -> &str {
 51        &self.name
 52    }
 53
 54    pub fn display_name(&self) -> &str {
 55        self.display_name.as_ref().unwrap_or(&self.name)
 56    }
 57
 58    pub fn max_token_count(&self) -> u64 {
 59        self.max_tokens
 60    }
 61}
 62
/// One message in an Ollama `/api/chat` conversation.
///
/// Serialized with an internal `role` tag ("assistant" / "user" /
/// "system" / "tool") selecting the variant.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    /// Message produced by the model.
    Assistant {
        content: String,
        /// Tool invocations requested by the model, if any.
        tool_calls: Option<Vec<OllamaToolCall>>,
        /// Base64-encoded images; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        /// "Thinking" text for models with thinking enabled.
        thinking: Option<String>,
    },
    /// Message supplied by the end user, optionally with images.
    User {
        content: String,
        /// Base64-encoded images; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    /// System prompt.
    System {
        content: String,
    },
    /// Result of a tool call being fed back to the model.
    Tool {
        tool_name: String,
        content: String,
    },
}
 86
/// A tool invocation requested by the model in an assistant message.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaToolCall {
    /// Identifier correlating this call with its eventual tool result.
    pub id: String,
    pub function: OllamaFunctionCall,
}
 92
/// The function name and arguments of a single tool call.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    /// Arguments as arbitrary JSON (typically an object).
    pub arguments: Value,
}
 98
/// Declaration of a function the model is allowed to call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    /// JSON-schema description of the function's parameters.
    pub parameters: Option<Value>,
}
105
/// A tool made available to the model; currently only functions exist.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}
111
/// Request body for Ollama's `/api/chat` endpoint.
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    /// When true, the server replies with newline-delimited JSON deltas.
    pub stream: bool,
    pub keep_alive: KeepAlive,
    /// Generation options; serialized as `null` when `None`.
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    /// Enables "thinking" output for models that support it.
    pub think: Option<bool>,
}
122
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
/// Per-request generation options. Every field is skipped when `None`
/// so the model's own defaults (from its Modelfile) apply.
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    /// Context window size in tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_ctx: Option<u64>,
    /// Maximum number of tokens to generate — presumably negative values
    /// carry Ollama's special meanings; see the linked docs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_predict: Option<isize>,
    /// Stop sequences; omitted so the model's defaults are used when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}
137
/// One response object from `/api/chat`: a single line of the stream,
/// or the entire reply when `stream` is false.
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    pub model: String,
    pub created_at: String,
    pub message: ChatMessage,
    /// Why generation finished (e.g. "stop"); absent on intermediate chunks.
    pub done_reason: Option<String>,
    /// True on the final chunk.
    pub done: bool,
    // Token-count stats reported alongside the final chunk; presumably
    // prompt tokens vs. generated tokens — confirm against Ollama docs.
    pub prompt_eval_count: Option<u64>,
    pub eval_count: Option<u64>,
}
148
/// Response envelope for `/api/tags` (list of locally installed models).
#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}
153
/// One entry in the `/api/tags` model listing.
#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    /// On-disk size in bytes.
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}
162
/// Full description of a locally installed model.
#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    /// The model's Modelfile contents.
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}
170
/// Format/family/quantization metadata attached to model listings.
#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    /// Weight file format (e.g. "gguf").
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    /// Human-readable parameter count (e.g. "3.2B").
    pub parameter_size: String,
    /// Quantization label (e.g. "Q4_K_M").
    pub quantization_level: String,
}
179
/// Capability and metadata summary for a model, parsed from `/api/show`
/// via the hand-written `Deserialize` impl below.
#[derive(Debug)]
pub struct ModelShow {
    /// Capability tags such as "completion", "tools", "vision", "thinking".
    pub capabilities: Vec<String>,
    /// Context window size read from `model_info`, when available.
    pub context_length: Option<u64>,
    /// Model architecture (e.g. "llama"), from `general.architecture`.
    pub architecture: Option<String>,
}
186
// Hand-written deserializer: `context_length` lives in `model_info` under
// an architecture-dependent key (e.g. "llama.context_length"), which a
// plain `#[derive(Deserialize)]` cannot express.
impl<'de> Deserialize<'de> for ModelShow {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};
        use std::fmt;

        struct ModelShowVisitor;

        impl<'de> Visitor<'de> for ModelShowVisitor {
            type Value = ModelShow;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a ModelShow object")
            }

            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                let mut capabilities: Vec<String> = Vec::new();
                let mut architecture: Option<String> = None;
                let mut context_length: Option<u64> = None;

                while let Some(key) = map.next_key::<String>()? {
                    match key.as_str() {
                        "capabilities" => {
                            capabilities = map.next_value()?;
                        }
                        "model_info" => {
                            // `model_info` is a flat map of dotted keys; read
                            // it as a generic JSON value and pick fields out.
                            let model_info: Value = map.next_value()?;
                            if let Value::Object(obj) = model_info {
                                architecture = obj
                                    .get("general.architecture")
                                    .and_then(|v| v.as_str())
                                    .map(String::from);

                                // The context-length key is namespaced by the
                                // architecture, e.g. "llama.context_length".
                                if let Some(arch) = &architecture {
                                    context_length = obj
                                        .get(&format!("{}.context_length", arch))
                                        .and_then(|v| v.as_u64());
                                }
                            }
                        }
                        _ => {
                            // Drain and discard fields we don't model.
                            let _: de::IgnoredAny = map.next_value()?;
                        }
                    }
                }

                Ok(ModelShow {
                    capabilities,
                    context_length,
                    architecture,
                })
            }
        }

        deserializer.deserialize_map(ModelShowVisitor)
    }
}
249
250impl ModelShow {
251    pub fn supports_tools(&self) -> bool {
252        // .contains expects &String, which would require an additional allocation
253        self.capabilities.iter().any(|v| v == "tools")
254    }
255
256    pub fn supports_vision(&self) -> bool {
257        self.capabilities.iter().any(|v| v == "vision")
258    }
259
260    pub fn supports_thinking(&self) -> bool {
261        self.capabilities.iter().any(|v| v == "thinking")
262    }
263}
264
/// Starts a streaming chat completion against `{api_url}/api/chat`.
///
/// On success, returns a stream of `ChatResponseDelta`s parsed from the
/// server's newline-delimited JSON body; on a non-success status, the
/// whole body is read and included in the returned error.
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        // Bearer auth is only attached when an API key is configured.
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(serde_json::to_string(&request)?))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        // Each line of the body is one standalone JSON delta object.
        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}
302
/// Lists locally installed models via `{api_url}/api/tags`.
///
/// Fails with the status and body text when the server responds with a
/// non-success status, or when the listing cannot be parsed.
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json")
        // Bearer auth is only attached when an API key is configured.
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    // Read the body before checking the status so error responses can be
    // reported verbatim.
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}
333
334/// Fetch details of a model, used to determine model capabilities
335pub async fn show_model(
336    client: &dyn HttpClient,
337    api_url: &str,
338    api_key: Option<&str>,
339    model: &str,
340) -> Result<ModelShow> {
341    let uri = format!("{api_url}/api/show");
342    let request = HttpRequest::builder()
343        .method(Method::POST)
344        .uri(uri)
345        .header("Content-Type", "application/json")
346        .when_some(api_key, |builder, api_key| {
347            builder.header("Authorization", format!("Bearer {api_key}"))
348        })
349        .body(AsyncBody::from(
350            serde_json::json!({ "model": model }).to_string(),
351        ))?;
352
353    let mut response = client.send(request).await?;
354    let mut body = String::new();
355    response.body_mut().read_to_string(&mut body).await?;
356
357    anyhow::ensure!(
358        response.status().is_success(),
359        "Failed to connect to Ollama API: {} {}",
360        response.status(),
361        body,
362    );
363    let details: ModelShow = serde_json::from_str(body.as_str())?;
364    Ok(details)
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    // A complete (non-streaming) /api/chat response deserializes.
    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
        "model": "llama3.2",
        "created_at": "2023-12-12T14:13:43.416799Z",
        "message": {
            "role": "assistant",
            "content": "Hello! How are you today?"
        },
        "done": true,
        "total_duration": 5191566416u64,
        "load_duration": 2154458,
        "prompt_eval_count": 26,
        "prompt_eval_duration": 383809000,
        "eval_count": 298,
        "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    // Both an intermediate chunk (done: false, no stats) and the final
    // chunk (done: true, with stats) of a stream deserialize.
    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
        "model": "llama3.2",
        "created_at": "2023-08-04T08:52:19.385406455-07:00",
        "message": {
            "role": "assistant",
            "content": "The",
            "images": null
        },
        "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
        "model": "llama3.2",
        "created_at": "2023-08-04T19:22:45.499127Z",
        "message": {
            "role": "assistant",
            "content": ""
        },
        "done": true,
        "total_duration": 4883583458u64,
        "load_duration": 1334875,
        "prompt_eval_count": 26,
        "prompt_eval_duration": 342546000,
        "eval_count": 282,
        "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    // An assistant message carrying tool_calls round-trips into the
    // Assistant variant with a non-empty call list.
    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "id": "call_llama3.2:3b_145155",
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    // The custom ModelShow deserializer extracts capabilities, the
    // architecture, and the architecture-namespaced context length.
    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));

        assert_eq!(result.architecture, Some("llama".to_string()));
        assert_eq!(result.context_length, Some(131072));
    }

    // The `images` field is serialized when present on a user message.
    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    // `skip_serializing_if` drops the `images` key entirely when None.
    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    // The serialized JSON nests images as an array under the message.
    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }

    #[test]
    fn test_chat_options_serialization() {
        // When stop is None, it should not appear in JSON at all
        // This allows Ollama to use the model's default stop tokens
        let options_no_stop = ChatOptions {
            num_ctx: Some(4096),
            stop: None,
            temperature: Some(0.7),
            ..Default::default()
        };
        let serialized = serde_json::to_string(&options_no_stop).unwrap();
        assert!(
            !serialized.contains("stop"),
            "stop should not be in JSON when None"
        );
        assert!(serialized.contains("num_ctx"));
        assert!(serialized.contains("temperature"));

        // When stop has values, they should be serialized
        let options_with_stop = ChatOptions {
            stop: Some(vec!["<|eot_id|>".to_string()]),
            ..Default::default()
        };
        let serialized = serde_json::to_string(&options_with_stop).unwrap();
        assert!(serialized.contains("stop"));
        assert!(serialized.contains("<|eot_id|>"));

        // All None options should result in empty object
        let options_all_none = ChatOptions::default();
        let serialized = serde_json::to_string(&options_all_none).unwrap();
        assert_eq!(serialized, "{}");
    }

    // Stop tokens supplied in ChatOptions appear in the nested
    // `options.stop` array of the serialized request.
    #[test]
    fn test_chat_request_with_stop_tokens() {
        let request = ChatRequest {
            model: "rnj-1:8b".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello".to_string(),
                images: None,
            }],
            stream: true,
            keep_alive: KeepAlive::default(),
            options: Some(ChatOptions {
                stop: Some(vec!["<|eot_id|>".to_string(), "<|end|>".to_string()]),
                ..Default::default()
            }),
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();

        let stop = parsed["options"]["stop"].as_array().unwrap();
        assert_eq!(stop.len(), 2);
        assert_eq!(stop[0].as_str().unwrap(), "<|eot_id|>");
        assert_eq!(stop[1].as_str().unwrap(), "<|end|>");
    }

    #[test]
    fn test_chat_request_without_stop_tokens_omits_field() {
        // This tests the fix for issue #47798
        // When no stop tokens are provided, the field should be omitted
        // so Ollama uses the model's default stop tokens from Modelfile
        let request = ChatRequest {
            model: "rnj-1:8b".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello".to_string(),
                images: None,
            }],
            stream: true,
            keep_alive: KeepAlive::default(),
            options: Some(ChatOptions {
                num_ctx: Some(4096),
                stop: None, // No stop tokens - should be omitted from JSON
                ..Default::default()
            }),
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        // The key check: "stop" should not appear in the serialized JSON
        assert!(
            !serialized.contains("\"stop\""),
            "stop field should be omitted when None, got: {}",
            serialized
        );
    }
}