ollama.rs

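//! Types and HTTP helpers for talking to an Ollama server: chat completion
//! (blocking and streaming), model listing, and model capability lookup.
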
use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub use settings::KeepAlive;
use std::time::Duration;

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}

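/// Returns the default context length for a model, keyed by the base name
/// before any ':' tag (so "llama3.2:3b" is looked up as "llama3.2").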
fn get_max_tokens(name: &str) -> u64 {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: u64 = 4096;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    /// Models that support longer contexts, such as codestral (32k) or
    /// devstral (128k), are clamped down to 16k.
    const MAXIMUM_TOKENS: u64 = 16384;

    match name.split(':').next().unwrap() {
        "granite-code" | "phi" | "tinyllama" => 2048,
        "llama2" | "stablelm2" | "vicuna" | "yi" => 4096,
        "aya" | "codegemma" | "gemma" | "gemma2" | "llama3" | "starcoder" => 8192,
        "codellama" | "starcoder2" => 16384,
        "codestral" | "dolphin-mixtral" | "llava" | "magistral" | "mistral" | "mixtral"
        | "qwen2" | "qwen2.5-coder" => 32768,
        "cogito" | "command-r" | "deepseek-coder-v2" | "deepseek-r1" | "deepseek-v3"
        | "devstral" | "gemma3" | "gpt-oss" | "granite3.3" | "llama3.1" | "llama3.2"
        | "llama3.3" | "mistral-nemo" | "phi3" | "phi3.5" | "phi4" | "qwen3" | "yi-coder" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tools: Option<bool>,
        supports_vision: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            // Prefer an explicit display name; otherwise drop a ":latest" tag
            // from the model name, if present.
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_vision,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        thinking: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
    Tool {
        tool_name: String,
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

impl ChatRequest {
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        // Tool calls are expected in a single, complete response, so attaching
        // tools forces a non-streaming request.
        self.stream = false;
        self.tools = tools;
        self
    }
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<u64>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
    pub prompt_eval_count: Option<u64>,
    pub eval_count: Option<u64>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

#[derive(Deserialize, Debug)]
pub struct ModelShow {
    #[serde(default)]
    pub capabilities: Vec<String>,
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.iter().any(|v| v == "vision")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}

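/// Sends a non-streaming chat request to `/api/chat` and returns the complete response.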
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;

    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        );
    }
}

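/// Sends a streaming chat request to `/api/chat` and returns a stream of
/// deltas parsed from Ollama's newline-delimited JSON response.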
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(serde_json::to_string(&request)?))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

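/// Lists the models installed on the Ollama server via `/api/tags`.
/// The timeout parameter is currently unused.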
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetch details of a model, used to determine model capabilities.
pub async fn show_model(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    model: &str,
) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));
    }

    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
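
    // Added sanity checks (a minimal sketch grounded in the code above):
    // `get_max_tokens` clamps table values to MAXIMUM_TOKENS and ignores any
    // tag after ':', and `with_tools` forces a non-streaming request.
    #[test]
    fn max_tokens_lookup_and_clamping() {
        // A known small model keeps its table value.
        assert_eq!(get_max_tokens("phi"), 2048);
        // The tag after ':' is ignored; llama3.2's 128k entry clamps to 16384.
        assert_eq!(get_max_tokens("llama3.2:3b"), 16384);
        // Unknown models fall back to the 4096-token default.
        assert_eq!(get_max_tokens("totally-unknown-model"), 4096);
    }

    #[test]
    fn with_tools_disables_streaming() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![],
            stream: true,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        }
        .with_tools(vec![]);

        assert!(!request.stream);
        assert!(request.tools.is_empty());
    }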
}