ollama.rs

  1use anyhow::{Context, Result};
  2
  3use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  4use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
  5use serde::{Deserialize, Serialize};
  6use serde_json::Value;
  7pub use settings::KeepAlive;
  8
  9pub const OLLAMA_API_URL: &str = "http://localhost:11434";
 10
/// A locally available Ollama model together with the capability flags
/// and configuration overrides tracked for it.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    /// Raw Ollama model name, e.g. "llama3.2:latest".
    pub name: String,
    /// Optional human-friendly name to show instead of `name`.
    pub display_name: Option<String>,
    /// Maximum context length in tokens (see `max_token_count`).
    pub max_tokens: u64,
    /// How long the Ollama server should keep the model loaded.
    pub keep_alive: Option<KeepAlive>,
    /// Whether the model supports tool/function calling (`None` = unknown).
    pub supports_tools: Option<bool>,
    /// Whether the model accepts image inputs (`None` = unknown).
    pub supports_vision: Option<bool>,
    /// Whether the model can emit "thinking" output (`None` = unknown).
    pub supports_thinking: Option<bool>,
}
 22
 23fn get_max_tokens(_name: &str) -> u64 {
 24    const DEFAULT_TOKENS: u64 = 4096;
 25    DEFAULT_TOKENS
 26}
 27
 28impl Model {
 29    pub fn new(
 30        name: &str,
 31        display_name: Option<&str>,
 32        max_tokens: Option<u64>,
 33        supports_tools: Option<bool>,
 34        supports_vision: Option<bool>,
 35        supports_thinking: Option<bool>,
 36    ) -> Self {
 37        Self {
 38            name: name.to_owned(),
 39            display_name: display_name
 40                .map(ToString::to_string)
 41                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
 42            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
 43            keep_alive: Some(KeepAlive::indefinite()),
 44            supports_tools,
 45            supports_vision,
 46            supports_thinking,
 47        }
 48    }
 49
 50    pub fn id(&self) -> &str {
 51        &self.name
 52    }
 53
 54    pub fn display_name(&self) -> &str {
 55        self.display_name.as_ref().unwrap_or(&self.name)
 56    }
 57
 58    pub fn max_token_count(&self) -> u64 {
 59        self.max_tokens
 60    }
 61}
 62
/// One message in an Ollama chat exchange, tagged by `role`
/// ("assistant" | "user" | "system" | "tool") on the wire.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    /// A model-generated message.
    Assistant {
        content: String,
        /// Tool invocations requested by the model, if any.
        tool_calls: Option<Vec<OllamaToolCall>>,
        /// Base64-encoded images; omitted from JSON entirely when absent.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        /// Reasoning trace from models with thinking support.
        thinking: Option<String>,
    },
    /// A message authored by the end user.
    User {
        content: String,
        /// Base64-encoded images; omitted from JSON entirely when absent.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    /// A system prompt.
    System {
        content: String,
    },
    /// The result of a tool call, fed back to the model.
    Tool {
        tool_name: String,
        content: String,
    },
}
 86
/// A tool invocation emitted by the model in an assistant message.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaToolCall {
    /// Identifier used to correlate this call with its result.
    pub id: String,
    pub function: OllamaFunctionCall,
}
 92
/// The function name and arguments of a tool call.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    /// Arbitrary JSON arguments as produced by the model.
    pub arguments: Value,
}
 98
/// Definition of a function the model is allowed to call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    /// JSON schema describing the function's parameters.
    pub parameters: Option<Value>,
}
105
/// A tool exposed to the model; currently only functions, serialized as
/// `{"type": "function", "function": {...}}`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}
111
/// Request body for Ollama's `/api/chat` endpoint.
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    /// When true, the server streams newline-delimited response deltas.
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    /// Enables "thinking" output on models that support it.
    pub think: Option<bool>,
}
122
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
/// Sampling and context options forwarded to Ollama; the link above is
/// the full parameter reference.
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    /// Context window size in tokens.
    pub num_ctx: Option<u64>,
    /// Maximum number of tokens to generate.
    pub num_predict: Option<isize>,
    /// Stop sequences that end generation.
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}
132
/// One streamed chunk (or the final summary) of a `/api/chat` response.
/// Extra fields sent by the server (timings, etc.) are ignored.
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    pub model: String,
    /// Timestamp string for when the chunk was created.
    pub created_at: String,
    pub message: ChatMessage,
    /// Why generation finished (e.g. "stop"); set on the final chunk.
    pub done_reason: Option<String>,
    /// True on the last chunk of the stream.
    pub done: bool,
    /// Prompt tokens evaluated; present on the final chunk.
    pub prompt_eval_count: Option<u64>,
    /// Tokens generated; present on the final chunk.
    pub eval_count: Option<u64>,
}
143
/// Response body of `/api/tags`: the models installed locally.
#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}
148
/// Summary of one locally installed model, as listed by `/api/tags`.
#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    /// On-disk size; presumably bytes — per Ollama API docs.
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}
157
/// Full record for a single model, including its modelfile source.
#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}
165
/// Format/family metadata shared by model listings and lookups.
#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    /// Model file format, e.g. "gguf".
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    /// Human-readable parameter count, e.g. "3.2B".
    pub parameter_size: String,
    /// Quantization scheme, e.g. "Q4_K_M".
    pub quantization_level: String,
}
174
/// Capability and architecture details extracted from `/api/show`,
/// populated by a custom `Deserialize` impl.
#[derive(Debug)]
pub struct ModelShow {
    /// Capability strings such as "completion", "tools", "vision".
    pub capabilities: Vec<String>,
    /// Context window length from `model_info["<arch>.context_length"]`.
    pub context_length: Option<u64>,
    /// Value of `model_info["general.architecture"]`, e.g. "llama".
    pub architecture: Option<String>,
}
181
182impl<'de> Deserialize<'de> for ModelShow {
183    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
184    where
185        D: serde::Deserializer<'de>,
186    {
187        use serde::de::{self, MapAccess, Visitor};
188        use std::fmt;
189
190        struct ModelShowVisitor;
191
192        impl<'de> Visitor<'de> for ModelShowVisitor {
193            type Value = ModelShow;
194
195            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
196                formatter.write_str("a ModelShow object")
197            }
198
199            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
200            where
201                A: MapAccess<'de>,
202            {
203                let mut capabilities: Vec<String> = Vec::new();
204                let mut architecture: Option<String> = None;
205                let mut context_length: Option<u64> = None;
206
207                while let Some(key) = map.next_key::<String>()? {
208                    match key.as_str() {
209                        "capabilities" => {
210                            capabilities = map.next_value()?;
211                        }
212                        "model_info" => {
213                            let model_info: Value = map.next_value()?;
214                            if let Value::Object(obj) = model_info {
215                                architecture = obj
216                                    .get("general.architecture")
217                                    .and_then(|v| v.as_str())
218                                    .map(String::from);
219
220                                if let Some(arch) = &architecture {
221                                    context_length = obj
222                                        .get(&format!("{}.context_length", arch))
223                                        .and_then(|v| v.as_u64());
224                                }
225                            }
226                        }
227                        _ => {
228                            let _: de::IgnoredAny = map.next_value()?;
229                        }
230                    }
231                }
232
233                Ok(ModelShow {
234                    capabilities,
235                    context_length,
236                    architecture,
237                })
238            }
239        }
240
241        deserializer.deserialize_map(ModelShowVisitor)
242    }
243}
244
245impl ModelShow {
246    pub fn supports_tools(&self) -> bool {
247        // .contains expects &String, which would require an additional allocation
248        self.capabilities.iter().any(|v| v == "tools")
249    }
250
251    pub fn supports_vision(&self) -> bool {
252        self.capabilities.iter().any(|v| v == "vision")
253    }
254
255    pub fn supports_thinking(&self) -> bool {
256        self.capabilities.iter().any(|v| v == "thinking")
257    }
258}
259
260pub async fn stream_chat_completion(
261    client: &dyn HttpClient,
262    api_url: &str,
263    api_key: Option<&str>,
264    request: ChatRequest,
265) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
266    let uri = format!("{api_url}/api/chat");
267    let request = HttpRequest::builder()
268        .method(Method::POST)
269        .uri(uri)
270        .header("Content-Type", "application/json")
271        .when_some(api_key, |builder, api_key| {
272            builder.header("Authorization", format!("Bearer {api_key}"))
273        })
274        .body(AsyncBody::from(serde_json::to_string(&request)?))?;
275
276    let mut response = client.send(request).await?;
277    if response.status().is_success() {
278        let reader = BufReader::new(response.into_body());
279
280        Ok(reader
281            .lines()
282            .map(|line| match line {
283                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
284                Err(e) => Err(e.into()),
285            })
286            .boxed())
287    } else {
288        let mut body = String::new();
289        response.body_mut().read_to_string(&mut body).await?;
290        anyhow::bail!(
291            "Failed to connect to Ollama API: {} {}",
292            response.status(),
293            body,
294        );
295    }
296}
297
298pub async fn get_models(
299    client: &dyn HttpClient,
300    api_url: &str,
301    api_key: Option<&str>,
302) -> Result<Vec<LocalModelListing>> {
303    let uri = format!("{api_url}/api/tags");
304    let request = HttpRequest::builder()
305        .method(Method::GET)
306        .uri(uri)
307        .header("Accept", "application/json")
308        .when_some(api_key, |builder, api_key| {
309            builder.header("Authorization", format!("Bearer {api_key}"))
310        })
311        .body(AsyncBody::default())?;
312
313    let mut response = client.send(request).await?;
314
315    let mut body = String::new();
316    response.body_mut().read_to_string(&mut body).await?;
317
318    anyhow::ensure!(
319        response.status().is_success(),
320        "Failed to connect to Ollama API: {} {}",
321        response.status(),
322        body,
323    );
324    let response: LocalModelsResponse =
325        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
326    Ok(response.models)
327}
328
329/// Fetch details of a model, used to determine model capabilities
330pub async fn show_model(
331    client: &dyn HttpClient,
332    api_url: &str,
333    api_key: Option<&str>,
334    model: &str,
335) -> Result<ModelShow> {
336    let uri = format!("{api_url}/api/show");
337    let request = HttpRequest::builder()
338        .method(Method::POST)
339        .uri(uri)
340        .header("Content-Type", "application/json")
341        .when_some(api_key, |builder, api_key| {
342            builder.header("Authorization", format!("Bearer {api_key}"))
343        })
344        .body(AsyncBody::from(
345            serde_json::json!({ "model": model }).to_string(),
346        ))?;
347
348    let mut response = client.send(request).await?;
349    let mut body = String::new();
350    response.body_mut().read_to_string(&mut body).await?;
351
352    anyhow::ensure!(
353        response.status().is_success(),
354        "Failed to connect to Ollama API: {} {}",
355        response.status(),
356        body,
357    );
358    let details: ModelShow = serde_json::from_str(body.as_str())?;
359    Ok(details)
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    // A complete (done: true) response deserializes; extra timing and
    // duration fields from the server are ignored by serde.
    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
        "model": "llama3.2",
        "created_at": "2023-12-12T14:13:43.416799Z",
        "message": {
            "role": "assistant",
            "content": "Hello! How are you today?"
        },
        "done": true,
        "total_duration": 5191566416u64,
        "load_duration": 2154458,
        "prompt_eval_count": 26,
        "prompt_eval_duration": 383809000,
        "eval_count": 298,
        "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    // Streaming sends partial chunks (done: false, no counters) followed
    // by a final chunk carrying the summary fields; both shapes must parse.
    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
        "model": "llama3.2",
        "created_at": "2023-08-04T08:52:19.385406455-07:00",
        "message": {
            "role": "assistant",
            "content": "The",
            "images": null
        },
        "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
        "model": "llama3.2",
        "created_at": "2023-08-04T19:22:45.499127Z",
        "message": {
            "role": "assistant",
            "content": ""
        },
        "done": true,
        "total_duration": 4883583458u64,
        "load_duration": 1334875,
        "prompt_eval_count": 26,
        "prompt_eval_duration": 342546000,
        "eval_count": 282,
        "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    // An assistant message may carry tool_calls alongside empty content;
    // it must deserialize into the Assistant variant with the calls intact.
    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "id": "call_llama3.2:3b_145155",
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    // The custom ModelShow deserializer extracts capabilities, the
    // architecture, and the architecture-prefixed context length from a
    // realistic /api/show payload, ignoring everything else.
    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));

        assert_eq!(result.architecture, Some("llama".to_string()));
        assert_eq!(result.context_length, Some(131072));
    }

    // When images are present they appear in the serialized JSON.
    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    // `images: None` is skipped entirely (skip_serializing_if), so the key
    // never appears in the payload.
    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    // Round-trip check: the serialized request places images as an array
    // of base64 strings inside the message object.
    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
}