use anyhow::{Context as _, Result, anyhow};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{sync::Arc, time::Duration};

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

/// The `keep_alive` value sent to Ollama. Serialized untagged, so it appears
/// on the wire as either a bare number of seconds or a duration string.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds (negative values keep it loaded indefinitely)
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
}

fn get_max_tokens(name: &str) -> usize {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: usize = 2048;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    const MAXIMUM_TOKENS: usize = 16384;

    // Only the family prefix (everything before ':') determines the default;
    // `split` always yields at least one item, so `unwrap` cannot panic.
    match name.split(':').next().unwrap() {
        "phi" | "tinyllama" | "granite-code" => 2048,
        "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
        "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
        "codellama" | "starcoder2" => 16384,
        "mistral" | "codestral" | "mixtral" | "llava" | "qwen2" | "qwen2.5-coder"
        | "dolphin-mixtral" => 32768,
        "llama3.1" | "llama3.2" | "llama3.3" | "phi3" | "phi3.5" | "phi4" | "command-r"
        | "qwen3" | "gemma3" | "deepseek-coder-v2" | "deepseek-v3" | "deepseek-r1" | "yi-coder" => {
            128000
        }
        _ => DEFAULT_TOKENS,
    }
    // Defaults above MAXIMUM_TOKENS are clamped down to keep memory use in check.
    .clamp(1, MAXIMUM_TOKENS)
}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<usize>,
        supports_tools: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> usize {
        self.max_tokens
    }
}

/// A message in an Ollama chat exchange. Serialized with the role inlined,
/// e.g. `{"role": "user", "content": "..."}`.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
}

impl ChatRequest {
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        // Tool calls are parsed from a single complete response (see the
        // `parse_tool_call` test below), so streaming is disabled whenever
        // tools are supplied.
        self.stream = false;
        self.tools = tools;
        self
    }
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<usize>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

#[derive(Deserialize, Debug)]
pub struct ModelShow {
    #[serde(default)]
    pub capabilities: Vec<String>,
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }
}

pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;

    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        ))
    }
}

pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        // The streaming response body is newline-delimited JSON; parse each
        // line as one delta.
        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}

pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    if response.status().is_success() {
        let response: LocalModelsResponse =
            serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;

        Ok(response.models)
    } else {
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}

/// Fetch details of a model, used to determine model capabilities
pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    if response.status().is_success() {
        let details: ModelShow = serde_json::from_str(body.as_str())?;
        Ok(details)
    } else {
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}

/// Sends a prompt-less request to Ollama to trigger loading the model into memory
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
    let uri = format!("{api_url}/api/generate");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(
            serde_json::json!({
                "model": model,
                "keep_alive": "15m",
            })
            .to_string(),
        ))?;

    let mut response = client.send(request).await?;

    if response.status().is_success() {
        Ok(())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));
    }
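
    // Sanity check: `KeepAlive` is `#[serde(untagged)]`, so it should
    // serialize to a bare JSON value (seconds as a number, or a duration
    // string), which is the shape Ollama accepts for `keep_alive`.
    #[test]
    fn serialize_keep_alive() {
        assert_eq!(
            serde_json::to_value(KeepAlive::indefinite()).unwrap(),
            serde_json::json!(-1)
        );
        assert_eq!(
            serde_json::to_value(KeepAlive::Duration("5m".to_string())).unwrap(),
            serde_json::json!("5m")
        );
    }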
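
    // Sanity check: context-window defaults are keyed on the model family
    // (the name before ':') and clamped to MAXIMUM_TOKENS; unknown families
    // fall back to DEFAULT_TOKENS.
    #[test]
    fn default_max_tokens() {
        assert_eq!(get_max_tokens("llama2:7b"), 4096);
        // llama3.1's 128000 default is clamped down to MAXIMUM_TOKENS (16384).
        assert_eq!(get_max_tokens("llama3.1:latest"), 16384);
        assert_eq!(get_max_tokens("some-unknown-model:latest"), 2048);
    }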
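
    // Sanity check: messages serialize with the role inlined via
    // `#[serde(tag = "role")]`, and `with_tools` forces `stream: false`.
    #[test]
    fn serialize_chat_request_with_tools() {
        let message = serde_json::to_value(ChatMessage::User {
            content: "Hello!".to_string(),
        })
        .unwrap();
        assert_eq!(
            message,
            serde_json::json!({ "role": "user", "content": "Hello!" })
        );

        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: Vec::new(),
            stream: true,
            keep_alive: KeepAlive::default(),
            options: None,
            tools: Vec::new(),
        }
        .with_tools(vec![OllamaTool::Function {
            function: OllamaFunctionTool {
                name: "weather".to_string(),
                description: None,
                parameters: None,
            },
        }]);
        assert!(!request.stream);
        assert_eq!(request.tools.len(), 1);
    }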
}