use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::time::Duration;

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}
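
// Because `KeepAlive` is `#[serde(untagged)]`, each variant serializes to the
// bare value Ollama expects. A quick sketch of the two JSON shapes:
//
//   KeepAlive::Seconds(-1)             // serializes to: -1
//   KeepAlive::Duration("5m".into())   // serializes to: "5m"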

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}

fn get_max_tokens(name: &str) -> u64 {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: u64 = 4096;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    /// Models that support context beyond 16k, such as codestral (32k) or
    /// devstral (128k), are clamped down to 16k.
    const MAXIMUM_TOKENS: u64 = 16384;

    match name.split(':').next().unwrap() {
        "granite-code" | "phi" | "tinyllama" => 2048,
        "llama2" | "stablelm2" | "vicuna" | "yi" => 4096,
        "aya" | "codegemma" | "gemma" | "gemma2" | "llama3" | "starcoder" => 8192,
        "codellama" | "starcoder2" => 16384,
        "codestral" | "dolphin-mixtral" | "llava" | "magistral" | "mistral" | "mixtral"
        | "qwen2" | "qwen2.5-coder" => 32768,
        "cogito" | "command-r" | "deepseek-coder-v2" | "deepseek-r1" | "deepseek-v3"
        | "devstral" | "gemma3" | "gpt-oss" | "granite3.3" | "llama3.1" | "llama3.2"
        | "llama3.3" | "mistral-nemo" | "phi3" | "phi3.5" | "phi4" | "qwen3" | "yi-coder" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}
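
// For example, `get_max_tokens("codestral:22b")` matches the 32768 arm but is
// clamped to MAXIMUM_TOKENS, returning 16384, while an unrecognized name such
// as "some-new-model" (hypothetical) falls through to DEFAULT_TOKENS (4096).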

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tools: Option<bool>,
        supports_vision: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_vision,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }
}
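
// Construction sketch: `Model::new("llama3.2:latest", None, None, None, None, None)`
// produces a display name of "llama3.2" (the ":latest" suffix is stripped) and
// `max_tokens == 16384` (the table's 128000 entry, clamped by `get_max_tokens`).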

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        thinking: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
    Tool {
        tool_name: String,
        content: String,
    },
}
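
// With `#[serde(tag = "role", rename_all = "lowercase")]`, each message
// serializes into Ollama's wire shape; e.g. a user message without images
// (the `images` field is skipped when `None`):
//
//   {"role":"user","content":"Hello!"}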

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}
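
// A tool serializes to the nested shape Ollama expects; e.g. for a
// hypothetical `get_weather` function:
//
//   {"type":"function","function":{"name":"get_weather","description":"...","parameters":{...}}}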

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

impl ChatRequest {
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        // Attaching tools forces a non-streaming request.
        self.stream = false;
        self.tools = tools;
        self
    }
}
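
// A minimal request sketch (model name and message content are illustrative):
//
//   let request = ChatRequest {
//       model: "llama3.2".to_string(),
//       messages: vec![ChatMessage::User {
//           content: "Hello!".to_string(),
//           images: None,
//       }],
//       stream: true,
//       keep_alive: KeepAlive::default(),
//       options: None,
//       tools: Vec::new(),
//       think: None,
//   };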

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<u64>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}
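
// `ChatOptions` derives `Default`, so callers can set just the fields they
// need, e.g. pinning the context window and lowering the sampling temperature:
//
//   ChatOptions { num_ctx: Some(16384), temperature: Some(0.2), ..Default::default() }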

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
    pub prompt_eval_count: Option<u64>,
    pub eval_count: Option<u64>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

#[derive(Deserialize, Debug)]
pub struct ModelShow {
    #[serde(default)]
    pub capabilities: Vec<String>,
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.iter().any(|v| v == "vision")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}

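/// Send a non-streaming chat request and wait for the complete response.
///
/// Usage sketch (`client` is any `HttpClient` implementation; error handling
/// is elided):
///
/// ```ignore
/// let delta = complete(client, OLLAMA_API_URL, request).await?;
/// if let ChatMessage::Assistant { content, .. } = delta.message {
///     println!("{content}");
/// }
/// ```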
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;

    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        );
    }
}

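/// Send a chat request and stream the response as newline-delimited JSON,
/// yielding one `ChatResponseDelta` per chunk.
///
/// Usage sketch (error handling elided):
///
/// ```ignore
/// let mut deltas = stream_chat_completion(client, OLLAMA_API_URL, None, request).await?;
/// while let Some(delta) = deltas.next().await {
///     let delta = delta?;
///     // Accumulate `delta.message` content until `delta.done` is true.
/// }
/// ```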
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let mut request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    if let Some(api_key) = api_key {
        request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
    }

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

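/// List the models available locally via Ollama's `/api/tags` endpoint.
///
/// Usage sketch:
///
/// ```ignore
/// let models = get_models(client, OLLAMA_API_URL, None, None).await?;
/// for model in &models {
///     println!("{} ({})", model.name, model.details.parameter_size);
/// }
/// ```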
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let mut request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    if let Some(api_key) = api_key {
        request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
    }

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetch details of a model, used to determine model capabilities.
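///
/// Usage sketch, feeding the reported capabilities into a `Model`:
///
/// ```ignore
/// let show = show_model(client, OLLAMA_API_URL, None, "llama3.2").await?;
/// let model = Model::new(
///     "llama3.2",
///     None,
///     None,
///     Some(show.supports_tools()),
///     Some(show.supports_vision()),
///     Some(show.supports_thinking()),
/// );
/// ```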
pub async fn show_model(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    model: &str,
) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    if let Some(api_key) = api_key {
        request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
    }

    let request = request_builder.body(AsyncBody::from(
        serde_json::json!({ "model": model }).to_string(),
    ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));
    }

    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
}