ollama.rs

use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub use settings::KeepAlive;
use std::time::Duration;

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}

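/// Picks a default context length from the model's base name (the part
/// before any `:tag`), then clamps the result into `1..=MAXIMUM_TOKENS`; see
/// the `max_tokens_clamped_to_maximum` test below for concrete values.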
fn get_max_tokens(name: &str) -> u64 {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: u64 = 4096;
    /// Magic number that lets many Ollama models work with ~16GB of RAM.
    /// Models that support contexts beyond 16k, such as codestral (32k) or
    /// devstral (128k), are clamped down to 16k.
    const MAXIMUM_TOKENS: u64 = 16384;

    match name.split(':').next().unwrap() {
        "granite-code" | "phi" | "tinyllama" => 2048,
        "llama2" | "stablelm2" | "vicuna" | "yi" => 4096,
        "aya" | "codegemma" | "gemma" | "gemma2" | "llama3" | "starcoder" => 8192,
        "codellama" | "starcoder2" => 16384,
        "codestral" | "dolphin-mixtral" | "llava" | "magistral" | "mistral" | "mixtral"
        | "qwen2" | "qwen2.5-coder" => 32768,
        "cogito" | "command-r" | "deepseek-coder-v2" | "deepseek-r1" | "deepseek-v3"
        | "devstral" | "gemma3" | "gpt-oss" | "granite3.3" | "llama3.1" | "llama3.2"
        | "llama3.3" | "mistral-nemo" | "phi3" | "phi3.5" | "phi4" | "qwen3" | "yi-coder" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}

impl Model {
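    /// Builds a model entry, deriving a display name (any `:latest` suffix is
    /// stripped) and a default token limit via `get_max_tokens` when they are
    /// not provided. A minimal sketch of the fallbacks:
    ///
    /// ```ignore
    /// let model = Model::new("codellama:latest", None, None, None, None, None);
    /// assert_eq!(model.display_name(), "codellama");
    /// assert_eq!(model.max_token_count(), 16384);
    /// ```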
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tools: Option<bool>,
        supports_vision: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_vision,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }
}

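/// A single message in Ollama's `/api/chat` conversation format. The `role`
/// tag is inlined into the object, and `images` is omitted entirely when it
/// is `None`. A quick sketch of the expected wire shape:
///
/// ```ignore
/// let msg = ChatMessage::User { content: "Hello".into(), images: None };
/// assert_eq!(
///     serde_json::to_string(&msg).unwrap(),
///     r#"{"role":"user","content":"Hello"}"#,
/// );
/// ```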
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        thinking: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
    Tool {
        tool_name: String,
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

impl ChatRequest {
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
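        // Some Ollama versions do not stream responses when tools are
        // supplied, so tool-enabled requests fall back to a single response.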
        self.stream = false;
        self.tools = tools;
        self
    }
}

/// <https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values>
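///
/// Every field is optional; unset fields fall back to the model's own
/// defaults. A hedged construction sketch (`ChatOptions` derives `Default`):
///
/// ```ignore
/// let options = ChatOptions {
///     num_ctx: Some(16384),  // override the context window
///     temperature: Some(0.7),
///     ..Default::default()
/// };
/// ```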
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<u64>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
    pub prompt_eval_count: Option<u64>,
    pub eval_count: Option<u64>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

#[derive(Deserialize, Debug)]
pub struct ModelShow {
    #[serde(default)]
    pub capabilities: Vec<String>,
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.iter().any(|v| v == "vision")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}

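/// Sends a single, non-streaming chat request and returns the complete
/// response. A minimal usage sketch, assuming some `client` implementing
/// `HttpClient` is already in scope:
///
/// ```ignore
/// let request = ChatRequest {
///     model: "llama3.2".to_string(),
///     messages: vec![ChatMessage::User {
///         content: "Hello!".to_string(),
///         images: None,
///     }],
///     stream: false,
///     keep_alive: KeepAlive::default(),
///     options: None,
///     tools: vec![],
///     think: None,
/// };
/// let response = complete(&client, OLLAMA_API_URL, request).await?;
/// ```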
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;

    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        );
    }
}

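/// Streams chat completions; Ollama replies with newline-delimited JSON, one
/// `ChatResponseDelta` per line. A hedged consumption sketch (same assumed
/// `client` and `request` as in `complete` above):
///
/// ```ignore
/// let mut deltas = stream_chat_completion(&client, OLLAMA_API_URL, None, request).await?;
/// while let Some(delta) = deltas.next().await {
///     if delta?.done {
///         break;
///     }
/// }
/// ```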
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let mut request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    if let Some(api_key) = api_key {
        request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
    }

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

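/// Lists locally installed models via Ollama's `/api/tags` endpoint. A brief
/// usage sketch (the trailing `Option<Duration>` parameter is currently
/// unused):
///
/// ```ignore
/// let models = get_models(&client, OLLAMA_API_URL, None, None).await?;
/// for model in &models {
///     println!("{} ({})", model.name, model.details.parameter_size);
/// }
/// ```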
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let mut request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    if let Some(api_key) = api_key {
        request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
    }

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetches a model's details, used to determine its capabilities.
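///
/// A typical capability check, sketched with the same assumed `client`:
///
/// ```ignore
/// let show = show_model(&client, OLLAMA_API_URL, None, "llama3.2").await?;
/// if show.supports_tools() {
///     // safe to include `tools` in the ChatRequest
/// }
/// ```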
pub async fn show_model(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    model: &str,
) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    if let Some(api_key) = api_key {
        request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
    }

    let request = request_builder.body(AsyncBody::from(
        serde_json::json!({ "model": model }).to_string(),
    ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }
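
    // Exercises the lookup-and-clamp behavior documented on `get_max_tokens`.
    #[test]
    fn max_tokens_clamped_to_maximum() {
        // A known small model keeps its table value.
        assert_eq!(get_max_tokens("tinyllama"), 2048);
        // The tag after ':' is ignored when resolving the base name.
        assert_eq!(get_max_tokens("codellama:13b"), 16384);
        // Large-context models (devstral is listed at 128k) clamp to 16384.
        assert_eq!(get_max_tokens("devstral"), 16384);
        // Unknown names fall back to DEFAULT_TOKENS.
        assert_eq!(get_max_tokens("not-a-real-model"), 4096);
    }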

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));
    }

    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
}