ollama.rs

use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{sync::Arc, time::Duration};

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_thinking: Option<bool>,
}

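/// Picks a context window from the model name's base prefix (the part before ':').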
fn get_max_tokens(name: &str) -> usize {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: usize = 4096;
    /// Magic number: lets many Ollama models work within ~16 GB of RAM.
    const MAXIMUM_TOKENS: usize = 16384;

    match name.split(':').next().unwrap() {
        "phi" | "tinyllama" | "granite-code" => 2048,
        "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
        "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
        "codellama" | "starcoder2" => 16384,
        "mistral" | "codestral" | "mixtral" | "llava" | "qwen2" | "qwen2.5-coder"
        | "dolphin-mixtral" => 32768,
        "llama3.1" | "llama3.2" | "llama3.3" | "phi3" | "phi3.5" | "phi4" | "command-r"
        | "qwen3" | "gemma3" | "deepseek-coder-v2" | "deepseek-v3" | "deepseek-r1" | "yi-coder"
        | "devstral" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<usize>,
        supports_tools: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> usize {
        self.max_tokens
    }
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        thinking: Option<String>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

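/// Request body for Ollama's `/api/chat` endpoint.
///
/// A minimal construction sketch (illustrative values; assumes this module is
/// exposed as the `ollama` crate):
///
/// ```no_run
/// # use ollama::{ChatMessage, ChatRequest, KeepAlive};
/// let request = ChatRequest {
///     model: "llama3.2".to_string(),
///     messages: vec![ChatMessage::User {
///         content: "Hello!".to_string(),
///     }],
///     stream: true,
///     keep_alive: KeepAlive::default(),
///     options: None,
///     tools: Vec::new(),
///     think: None,
/// };
/// ```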
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

impl ChatRequest {
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
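        // Attaching tools forces a non-streaming request, presumably because
        // Ollama reports tool calls only on complete responses.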
        self.stream = false;
        self.tools = tools;
        self
    }
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<usize>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

#[derive(Deserialize, Debug)]
pub struct ModelShow {
    #[serde(default)]
    pub capabilities: Vec<String>,
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}

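/// Sends a chat request to `/api/chat` and parses a single, complete response.
/// The request's `stream` flag should be `false`; a streamed body would not
/// deserialize as one `ChatResponseDelta`.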
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;

    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        );
    }
}
260
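/// Sends a chat request to `/api/chat` and returns a stream of deltas, one per
/// newline-delimited JSON line of the response body.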
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

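/// Lists the models available locally, via Ollama's `/api/tags` endpoint.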
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetch details of a model, used to determine model capabilities
pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}

/// Sends a prompt-less generate request, which makes Ollama load the model into memory
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
    let uri = format!("{api_url}/api/generate");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(
            serde_json::json!({
                "model": model,
                "keep_alive": "15m",
            })
            .to_string(),
        ))?;

    let mut response = client.send(request).await?;

    if response.status().is_success() {
        Ok(())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }
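
    // Checks KeepAlive's untagged serde representation: a Seconds value
    // serializes as a bare integer and a Duration as a plain string, matching
    // the number-or-duration-string forms Ollama accepts for `keep_alive`.
    #[test]
    fn serialize_keep_alive() {
        assert_eq!(
            serde_json::to_value(KeepAlive::Seconds(-1)).unwrap(),
            serde_json::json!(-1)
        );
        assert_eq!(
            serde_json::to_value(KeepAlive::Duration("5m".to_string())).unwrap(),
            serde_json::json!("5m")
        );

        // Round-trip: a bare string deserializes back into the Duration variant.
        let parsed: KeepAlive = serde_json::from_value(serde_json::json!("10m")).unwrap();
        assert_eq!(parsed, KeepAlive::Duration("10m".to_string()));
    }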

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }
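
    // Spot-checks the context-window table in `get_max_tokens`: known prefixes
    // resolve to their table entry (clamped to MAXIMUM_TOKENS) and unknown
    // names fall back to the 4096-token default.
    #[test]
    fn context_window_lookup() {
        assert_eq!(get_max_tokens("phi"), 2048);
        assert_eq!(get_max_tokens("llama2:7b"), 4096);
        // 128k-context models are clamped down to the 16384 maximum.
        assert_eq!(get_max_tokens("llama3.1:8b"), 16384);
        assert_eq!(get_max_tokens("some-unknown-model"), 4096);
    }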

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }
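
    // Exercises `Model::new`'s display-name fallback: a ":latest" tag is
    // stripped for display while other tags are shown verbatim.
    #[test]
    fn display_name_strips_latest_tag() {
        let model = Model::new("qwen3:latest", None, None, None, None);
        assert_eq!(model.id(), "qwen3:latest");
        assert_eq!(model.display_name(), "qwen3");

        let model = Model::new("qwen3:8b", None, None, None, None);
        assert_eq!(model.display_name(), "qwen3:8b");
    }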

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));
    }
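
    // Checks that `ChatRequest::with_tools` switches the request to
    // non-streaming mode while recording the provided tools.
    #[test]
    fn with_tools_disables_streaming() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: Vec::new(),
            stream: true,
            keep_alive: KeepAlive::default(),
            options: None,
            tools: Vec::new(),
            think: None,
        };

        let request = request.with_tools(vec![OllamaTool::Function {
            function: OllamaFunctionTool {
                name: "weather".to_string(),
                description: None,
                parameters: None,
            },
        }]);
        assert!(!request.stream);
        assert_eq!(request.tools.len(), 1);
    }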
}