ollama.rs

use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub use settings::KeepAlive;
use std::time::Duration;

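/// Default base URL of a local Ollama server.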
pub const OLLAMA_API_URL: &str = "http://localhost:11434";

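/// A locally available Ollama model, along with optional overrides for its
/// display name, context length, and capabilities.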
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}

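/// Looks up a known context length for the model's base name (the part before
/// the first ':'), falling back to `DEFAULT_TOKENS` and clamping the result to
/// `MAXIMUM_TOKENS`.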
fn get_max_tokens(name: &str) -> u64 {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: u64 = 4096;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    /// Models that support longer contexts, such as codestral (32k) or devstral (128k), are clamped down to 16k.
    const MAXIMUM_TOKENS: u64 = 16384;

    match name.split(':').next().unwrap() {
        "granite-code" | "phi" | "tinyllama" => 2048,
        "llama2" | "stablelm2" | "vicuna" | "yi" => 4096,
        "aya" | "codegemma" | "gemma" | "gemma2" | "llama3" | "starcoder" => 8192,
        "codellama" | "starcoder2" => 16384,
        "codestral" | "dolphin-mixtral" | "llava" | "magistral" | "mistral" | "mixtral"
        | "qwen2" | "qwen2.5-coder" => 32768,
        "cogito" | "command-r" | "deepseek-coder-v2" | "deepseek-r1" | "deepseek-v3"
        | "devstral" | "gemma3" | "gpt-oss" | "granite3.3" | "llama3.1" | "llama3.2"
        | "llama3.3" | "mistral-nemo" | "phi3" | "phi3.5" | "phi4" | "qwen3" | "yi-coder" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}

impl Model {
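    /// Builds a `Model`, inferring the display name (by stripping a `:latest`
    /// suffix) and the maximum token count when they are not provided.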
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tools: Option<bool>,
        supports_vision: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_vision,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }
}

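/// A single message in an Ollama chat conversation, tagged by role to match
/// the `/api/chat` wire format.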
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        thinking: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
    Tool {
        tool_name: String,
        content: String,
    },
}

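/// A tool invocation requested by the model in an assistant message.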
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

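/// The function name and JSON arguments of a tool call.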
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

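/// Definition of a function the model may call: its name, description, and
/// parameters expressed as a JSON value.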
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

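/// A tool made available to the model; only function tools are represented here.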
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

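/// Request body for Ollama's `POST /api/chat` endpoint.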
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

impl ChatRequest {
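    /// Adds tool definitions to the request and disables streaming.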
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        self.stream = false;
        self.tools = tools;
        self
    }
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
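/// Generation options forwarded with a chat request; see the modelfile docs linked above.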
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<u64>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

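/// One parsed chunk of a chat response; the final chunk has `done` set to true
/// and may carry prompt/eval token counts.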
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
    pub prompt_eval_count: Option<u64>,
    pub eval_count: Option<u64>,
}

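/// Response payload of `GET /api/tags`.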
#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

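/// Summary of a locally installed model, as listed by `/api/tags`.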
#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

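/// Detailed model record (modelfile, parameters, template) as returned by `/api/show`.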
#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

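/// Format, family, and quantization metadata reported for a model.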
#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

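/// Capability flags reported by `/api/show`, used to detect tool, vision, and
/// thinking support.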
#[derive(Deserialize, Debug)]
pub struct ModelShow {
    #[serde(default)]
    pub capabilities: Vec<String>,
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.iter().any(|v| v == "vision")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}

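/// Sends a non-streaming chat request to `POST {api_url}/api/chat` and returns
/// the parsed response.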
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;

    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        );
    }
}

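/// Sends a streaming chat request and returns a stream of chunks parsed from
/// the newline-delimited JSON response body.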
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

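/// Lists locally installed models via `GET {api_url}/api/tags`.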
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetch details of a model, used to determine model capabilities
pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));
    }

    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
}