ollama.rs

use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub use settings::KeepAlive;
use std::time::Duration;

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}

fn get_max_tokens(name: &str) -> u64 {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: u64 = 4096;
    /// Magic number: lets many Ollama models work with ~16 GB of RAM.
    /// Models that support contexts beyond 16k, such as codestral (32k) or
    /// devstral (128k), are clamped down to 16k.
    const MAXIMUM_TOKENS: u64 = 16384;

    match name.split(':').next().unwrap() {
        "granite-code" | "phi" | "tinyllama" => 2048,
        "llama2" | "stablelm2" | "vicuna" | "yi" => 4096,
        "aya" | "codegemma" | "gemma" | "gemma2" | "llama3" | "starcoder" => 8192,
        "codellama" | "starcoder2" => 16384,
        "codestral" | "dolphin-mixtral" | "llava" | "magistral" | "mistral" | "mixtral"
        | "qwen2" | "qwen2.5-coder" => 32768,
        "cogito" | "command-r" | "deepseek-coder-v2" | "deepseek-r1" | "deepseek-v3"
        | "devstral" | "gemma3" | "gpt-oss" | "granite3.3" | "llama3.1" | "llama3.2"
        | "llama3.3" | "mistral-nemo" | "phi3" | "phi3.5" | "phi4" | "qwen3" | "yi-coder" => 128000,
        "qwen3-coder" => 256000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tools: Option<bool>,
        supports_vision: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_vision,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        thinking: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
    Tool {
        tool_name: String,
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

impl ChatRequest {
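    /// Attach tool definitions to the request. Note that this also sets
    /// `stream` to `false`, forcing a non-streaming response.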
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        self.stream = false;
        self.tools = tools;
        self
    }
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<u64>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
    pub prompt_eval_count: Option<u64>,
    pub eval_count: Option<u64>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

#[derive(Deserialize, Debug)]
pub struct ModelShow {
    #[serde(default)]
    pub capabilities: Vec<String>,
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.iter().any(|v| v == "vision")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}

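/// Send a chat request to Ollama's `/api/chat` endpoint and parse the response
/// body as a single `ChatResponseDelta`.
///
/// A minimal usage sketch (not compiled as a doc test: it assumes an `HttpClient`
/// implementation named `client` and a reachable Ollama server):
///
/// ```ignore
/// let request = ChatRequest {
///     model: "llama3.2".to_string(),
///     messages: vec![ChatMessage::User {
///         content: "Why is the sky blue?".to_string(),
///         images: None,
///     }],
///     stream: false,
///     keep_alive: KeepAlive::default(),
///     options: None,
///     tools: vec![],
///     think: None,
/// };
/// let response = complete(client, OLLAMA_API_URL, request).await?;
/// ```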
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;

    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        );
    }
}

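/// Send a chat request to Ollama's `/api/chat` endpoint, returning a stream of
/// newline-delimited `ChatResponseDelta` values parsed line by line.
///
/// A minimal consumption sketch (not compiled as a doc test: `client` and
/// `request` are assumed to be in scope):
///
/// ```ignore
/// let mut stream = stream_chat_completion(client, OLLAMA_API_URL, None, request).await?;
/// while let Some(delta) = stream.next().await {
///     if let ChatMessage::Assistant { content, .. } = delta?.message {
///         print!("{content}");
///     }
/// }
/// ```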
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(serde_json::to_string(&request)?))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

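/// List the models available locally, as reported by Ollama's `/api/tags` endpoint.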
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetch details of a model, used to determine model capabilities
pub async fn show_model(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    model: &str,
) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));
    }

    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
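
    // A small sketch of additional tests for the context-length defaults and the
    // `:latest` display-name handling above; the expected values follow directly
    // from `get_max_tokens` and `Model::new`.
    #[test]
    fn default_context_lengths_are_clamped() {
        // Small models keep their published context length.
        assert_eq!(get_max_tokens("phi"), 2048);
        // Unknown models fall back to the 4096-token default.
        assert_eq!(get_max_tokens("some-unknown-model"), 4096);
        // Large-context models are clamped to the 16k ceiling.
        assert_eq!(get_max_tokens("devstral:latest"), 16384);
    }

    #[test]
    fn model_display_name_strips_latest_tag() {
        let model = Model::new("llama3.2:latest", None, None, None, None, None);
        assert_eq!(model.display_name(), "llama3.2");
        assert_eq!(model.max_token_count(), 16384);
    }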
}