ollama.rs

use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::time::Duration;

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}

fn get_max_tokens(name: &str) -> u64 {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: u64 = 4096;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    const MAXIMUM_TOKENS: u64 = 16384;

    // `split` always yields at least one item, so this `unwrap` cannot panic.
    match name.split(':').next().unwrap() {
        "phi" | "tinyllama" | "granite-code" => 2048,
        "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
        "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
        "codellama" | "starcoder2" => 16384,
        "mistral" | "codestral" | "mixtral" | "llava" | "qwen2" | "qwen2.5-coder"
        | "dolphin-mixtral" => 32768,
        "magistral" => 40000,
        "llama3.1" | "llama3.2" | "llama3.3" | "phi3" | "phi3.5" | "phi4" | "command-r"
        | "qwen3" | "gemma3" | "deepseek-coder-v2" | "deepseek-v3" | "deepseek-r1" | "yi-coder"
        | "devstral" | "gpt-oss" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tools: Option<bool>,
        supports_vision: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_vision,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        thinking: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

impl ChatRequest {
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        // Tool calls are read from the complete response, so force a
        // non-streaming request whenever tools are attached.
        self.stream = false;
        self.tools = tools;
        self
    }
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<u64>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
    pub prompt_eval_count: Option<u64>,
    pub eval_count: Option<u64>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

#[derive(Deserialize, Debug)]
pub struct ModelShow {
    #[serde(default)]
    pub capabilities: Vec<String>,
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.iter().any(|v| v == "vision")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}

/// Send a chat request and return the complete (non-streamed) response.
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;

    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        );
    }
}

/// Send a chat request and stream the response as newline-delimited JSON deltas.
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

/// List the models available locally via Ollama's `/api/tags` endpoint.
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetch details of a model, used to determine model capabilities
pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }
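
    // Added check (not in the original suite): `get_max_tokens` clamps known
    // context lengths to MAXIMUM_TOKENS, falls back to DEFAULT_TOKENS for
    // unknown models, and ignores the tag after ':'.
    #[test]
    fn max_tokens_clamped_and_defaulted() {
        assert_eq!(get_max_tokens("phi:latest"), 2048);
        assert_eq!(get_max_tokens("mistral"), 16384); // 32768 clamped down
        assert_eq!(get_max_tokens("llama3.1:8b"), 16384); // 128000 clamped down
        assert_eq!(get_max_tokens("some-unknown-model"), 4096);
    }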

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }
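
    // Added check (not in the original suite): `KeepAlive` is untagged, so
    // seconds serialize as a bare number and named durations as a string.
    // The default is -1, the "indefinite" value described above.
    #[test]
    fn serialize_keep_alive() {
        assert_eq!(
            serde_json::to_value(KeepAlive::default()).unwrap(),
            serde_json::json!(-1)
        );
        assert_eq!(
            serde_json::to_value(KeepAlive::Duration("5m".to_string())).unwrap(),
            serde_json::json!("5m")
        );
    }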

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }
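
    // Added check (not in the original suite): `ChatMessage` is internally
    // tagged by "role", so variants serialize with the lowercase role names
    // Ollama expects.
    #[test]
    fn serialize_chat_message_roles() {
        let message = ChatMessage::System {
            content: "You are a helpful assistant.".to_string(),
        };
        let value = serde_json::to_value(&message).unwrap();
        assert_eq!(value["role"], "system");
        assert_eq!(value["content"], "You are a helpful assistant.");
    }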

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));
    }
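
    // Added check (not in the original suite): `capabilities` is
    // #[serde(default)], so a show response without that field still
    // deserializes and reports no capabilities.
    #[test]
    fn parse_show_model_without_capabilities() {
        let result: ModelShow = serde_json::from_value(serde_json::json!({})).unwrap();
        assert!(result.capabilities.is_empty());
        assert!(!result.supports_tools());
        assert!(!result.supports_vision());
        assert!(!result.supports_thinking());
    }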

    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }
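
    // Added check (not in the original suite): `Model::new` derives the
    // display name by stripping a ":latest" suffix when none is given, and
    // infers max_tokens from the base model name.
    #[test]
    fn model_display_name_and_max_tokens() {
        let model = Model::new("llama3:latest", None, None, None, None, None);
        assert_eq!(model.display_name(), "llama3");
        assert_eq!(model.max_token_count(), 8192);

        let model = Model::new("mistral:7b", Some("Mistral 7B"), None, None, None, None);
        assert_eq!(model.display_name(), "Mistral 7B");
    }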

    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
}
588}