ollama.rs

use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{sync::Arc, time::Duration};

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}
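
// Because `KeepAlive` is `#[serde(untagged)]`, it serializes as a bare JSON
// value rather than a tagged object. A minimal sketch of the wire forms
// (exercised by `keep_alive_serialization` in the tests module below):
//
//   KeepAlive::Seconds(-1)                => -1
//   KeepAlive::Duration("5m".to_string()) => "5m"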

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}

fn get_max_tokens(name: &str) -> usize {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: usize = 4096;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    const MAXIMUM_TOKENS: usize = 16384;

    match name.split(':').next().unwrap() {
        "phi" | "tinyllama" | "granite-code" => 2048,
        "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
        "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
        "codellama" | "starcoder2" => 16384,
        "mistral" | "codestral" | "mixtral" | "llava" | "qwen2" | "qwen2.5-coder"
        | "dolphin-mixtral" => 32768,
        "llama3.1" | "llama3.2" | "llama3.3" | "phi3" | "phi3.5" | "phi4" | "command-r"
        | "qwen3" | "gemma3" | "deepseek-coder-v2" | "deepseek-v3" | "deepseek-r1" | "yi-coder"
        | "devstral" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}
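
// Examples of the mapping above; the tag after ':' is ignored and anything
// larger than MAXIMUM_TOKENS is clamped down (see `context_window_defaults`
// in the tests module below):
//
//   get_max_tokens("phi")           => 2048
//   get_max_tokens("llama3:latest") => 8192
//   get_max_tokens("llama3.1")      => 16384 (128000, clamped)
//   get_max_tokens("new-model")     => 4096  (DEFAULT_TOKENS)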

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<usize>,
        supports_tools: Option<bool>,
        supports_vision: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_vision,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> usize {
        self.max_tokens
    }
}
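
// Construction sketch: with no explicit display name a ":latest" suffix is
// stripped for display, and the context window falls back to `get_max_tokens`
// (see `display_name_strips_latest_suffix` in the tests module below):
//
//   let model = Model::new("qwen2.5-coder:latest", None, None, Some(true), None, None);
//   assert_eq!(model.display_name(), "qwen2.5-coder");
//   assert_eq!(model.max_token_count(), 16384);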

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        thinking: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

impl ChatRequest {
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        self.stream = false;
        self.tools = tools;
        self
    }
}
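
// `with_tools` switches the request to non-streaming before attaching the
// tool definitions. A hypothetical usage sketch, where `weather_tool` is any
// `OllamaTool::Function` value:
//
//   let request = chat_request.with_tools(vec![weather_tool]);
//   assert!(!request.stream);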

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<usize>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

#[derive(Deserialize, Debug)]
pub struct ModelShow {
    #[serde(default)]
    pub capabilities: Vec<String>,
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.iter().any(|v| v == "vision")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}

pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;

    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        );
    }
}
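
// A minimal call sketch, assuming `client: Arc<dyn HttpClient>` and an Ollama
// server at the default address:
//
//   let request = ChatRequest {
//       model: "llama3.2".to_string(),
//       messages: vec![ChatMessage::User { content: "Hello".to_string(), images: None }],
//       stream: false,
//       keep_alive: KeepAlive::default(),
//       options: None,
//       tools: vec![],
//       think: None,
//   };
//   let delta = complete(client.as_ref(), OLLAMA_API_URL, request).await?;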

pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}
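
// Ollama streams newline-delimited JSON, one `ChatResponseDelta` per line. A
// consumption sketch (assuming `futures::StreamExt` in scope, as above):
//
//   let mut deltas = stream_chat_completion(client.as_ref(), OLLAMA_API_URL, request).await?;
//   while let Some(delta) = deltas.next().await {
//       if let ChatMessage::Assistant { content, .. } = delta?.message {
//           print!("{content}");
//       }
//   }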

pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}
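
// Listing sketch: `/api/tags` reports the locally installed models, so the
// result can be mapped straight into `Model` values:
//
//   let listings = get_models(client.as_ref(), OLLAMA_API_URL, None).await?;
//   for listing in &listings {
//       println!("{} ({})", listing.name, listing.details.parameter_size);
//   }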

/// Fetch details of a model, used to determine model capabilities
pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}
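
// Capability-probe sketch, pairing `show_model` with the `ModelShow` helpers
// to fill in the `Model` capability flags:
//
//   let show = show_model(client.as_ref(), OLLAMA_API_URL, "llama3.2").await?;
//   let model = Model::new(
//       "llama3.2",
//       None,
//       None,
//       Some(show.supports_tools()),
//       Some(show.supports_vision()),
//       Some(show.supports_thinking()),
//   );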

/// Sends an empty request to Ollama to trigger loading the model
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
    let uri = format!("{api_url}/api/generate");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(
            serde_json::json!({
                "model": model,
                "keep_alive": "15m",
            })
            .to_string(),
        ))?;

    let mut response = client.send(request).await?;

    if response.status().is_success() {
        Ok(())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}
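
// Warm-up sketch: calling this at startup means the first real completion
// does not pay the model-load cost; the "15m" keep_alive above then holds
// the model in memory:
//
//   preload_model(client.clone(), OLLAMA_API_URL, "llama3.2").await?;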

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));
    }

    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
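
    // The checks below pin down behavior relied on in the sketches above: the
    // untagged `KeepAlive` wire format, ":latest" display-name stripping, and
    // context-window clamping in `get_max_tokens`.
    #[test]
    fn keep_alive_serialization() {
        // Untagged enums serialize to a bare JSON value.
        assert_eq!(serde_json::to_string(&KeepAlive::default()).unwrap(), "-1");
        assert_eq!(
            serde_json::to_string(&KeepAlive::Duration("5m".to_string())).unwrap(),
            "\"5m\""
        );
    }

    #[test]
    fn display_name_strips_latest_suffix() {
        let model = Model::new("qwen2.5-coder:latest", None, None, None, None, None);
        assert_eq!(model.display_name(), "qwen2.5-coder");
        assert_eq!(model.max_token_count(), 16384);
    }

    #[test]
    fn context_window_defaults() {
        // The tag after ':' is ignored; oversize contexts clamp to MAXIMUM_TOKENS.
        assert_eq!(get_max_tokens("phi"), 2048);
        assert_eq!(get_max_tokens("llama3:latest"), 8192);
        assert_eq!(get_max_tokens("llama3.1"), 16384);
        assert_eq!(get_max_tokens("totally-unknown-model"), 4096);
    }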
}