ollama.rs

use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{sync::Arc, time::Duration};

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

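/// How long Ollama should keep the model loaded after a request.
///
/// Serialized untagged, so the wire value is either a bare integer (seconds)
/// or a duration string. A minimal serialization sketch (not run as a
/// doctest):
///
/// ```ignore
/// assert_eq!(serde_json::to_string(&KeepAlive::Seconds(-1)).unwrap(), "-1");
/// assert_eq!(serde_json::to_string(&KeepAlive::Duration("5m".into())).unwrap(), "\"5m\"");
/// ```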
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
}

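/// Looks up an approximate context length by the model's base name (the part
/// before the first `:` tag), clamped to a memory-friendly ceiling.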
fn get_max_tokens(name: &str) -> usize {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: usize = 4096;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    const MAXIMUM_TOKENS: usize = 16384;

    // `split` always yields at least one item, so `next().unwrap()` cannot panic.
    match name.split(':').next().unwrap() {
        "phi" | "tinyllama" | "granite-code" => 2048,
        "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
        "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
        "codellama" | "starcoder2" => 16384,
        "mistral" | "codestral" | "mixtral" | "llava" | "qwen2" | "qwen2.5-coder"
        | "dolphin-mixtral" => 32768,
        "llama3.1" | "llama3.2" | "llama3.3" | "phi3" | "phi3.5" | "phi4" | "command-r"
        | "qwen3" | "gemma3" | "deepseek-coder-v2" | "deepseek-v3" | "deepseek-r1" | "yi-coder"
        | "devstral" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}

impl Model {
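    /// Builds a model entry, deriving the display name and context length
    /// from the model name when they are not supplied.
    ///
    /// A minimal usage sketch (not run as a doctest):
    ///
    /// ```ignore
    /// let model = Model::new("llama3.2:latest", None, None, Some(true));
    /// assert_eq!(model.display_name(), "llama3.2"); // ":latest" suffix stripped
    /// assert_eq!(model.max_token_count(), 16384); // 128k clamped to MAXIMUM_TOKENS
    /// ```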
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<usize>,
        supports_tools: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> usize {
        self.max_tokens
    }
}

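/// A message in the `/api/chat` conversation, tagged by `role` on the wire
/// (e.g. `{"role": "user", "content": "..."}`).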
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

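/// Request body for `POST /api/chat`.
///
/// A minimal construction sketch (not run as a doctest):
///
/// ```ignore
/// let request = ChatRequest {
///     model: "llama3.2".to_string(),
///     messages: vec![ChatMessage::User { content: "Hello!".to_string() }],
///     stream: true,
///     keep_alive: KeepAlive::default(),
///     options: None,
///     tools: Vec::new(),
/// };
/// ```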
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
}

impl ChatRequest {
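    /// Attaches tool definitions to the request and disables streaming;
    /// Ollama's chat endpoint has not supported streaming tool calls, so a
    /// single complete response is expected (worth re-checking against the
    /// Ollama version in use).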
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        self.stream = false;
        self.tools = tools;
        self
    }
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<usize>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

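/// Subset of the `POST /api/show` response used to probe model capabilities.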
#[derive(Deserialize, Debug)]
pub struct ModelShow {
    #[serde(default)]
    pub capabilities: Vec<String>,
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }
}

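/// Sends a non-streaming chat request and returns the full response.
///
/// A usage sketch, assuming an `HttpClient` implementation is at hand (not
/// run as a doctest):
///
/// ```ignore
/// let delta = complete(client.as_ref(), OLLAMA_API_URL, request).await?;
/// if let ChatMessage::Assistant { content, .. } = delta.message {
///     println!("{content}");
/// }
/// ```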
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;

    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        );
    }
}

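/// Streams chat completions from `POST /api/chat`.
///
/// Ollama emits one JSON object per line of the response body, so each line
/// is parsed into a `ChatResponseDelta` independently.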
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

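/// Lists the locally available models via `GET /api/tags`.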
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetch details of a model, used to determine model capabilities.
pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}

/// Sends an empty request to Ollama to trigger loading the model.
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
    let uri = format!("{api_url}/api/generate");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(
            // A generate request with no prompt makes Ollama load the model
            // without producing output; `keep_alive` keeps it warm afterwards.
            serde_json::json!({
                "model": model,
                "keep_alive": "15m",
            })
            .to_string(),
        ))?;

    let mut response = client.send(request).await?;

    if response.status().is_success() {
        Ok(())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));
    }
}