use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{sync::Arc, time::Duration};

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_thinking: Option<bool>,
}

fn get_max_tokens(name: &str) -> usize {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: usize = 4096;
    /// Magic number. Lets many Ollama models work with ~16 GB of RAM.
    const MAXIMUM_TOKENS: usize = 16384;

    match name.split(':').next().unwrap() {
        "phi" | "tinyllama" | "granite-code" => 2048,
        "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
        "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
        "codellama" | "starcoder2" => 16384,
        "mistral" | "codestral" | "mixtral" | "llava" | "qwen2" | "qwen2.5-coder"
        | "dolphin-mixtral" => 32768,
        "llama3.1" | "llama3.2" | "llama3.3" | "phi3" | "phi3.5" | "phi4" | "command-r"
        | "qwen3" | "gemma3" | "deepseek-coder-v2" | "deepseek-v3" | "deepseek-r1" | "yi-coder"
        | "devstral" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<usize>,
        supports_tools: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> usize {
        self.max_tokens
    }
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        thinking: Option<String>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

impl ChatRequest {
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        // Streaming is disabled whenever tools are attached to the request.
        self.stream = false;
        self.tools = tools;
        self
    }
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<usize>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

#[derive(Deserialize, Debug)]
pub struct ModelShow {
    #[serde(default)]
    pub capabilities: Vec<String>,
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}

pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;

    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to API: {} {}",
            response.status(),
            body_str
        );
    }
}

pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetch details of a model, used to determine model capabilities
pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}

/// Sends an empty request to Ollama to trigger loading the model
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
    let uri = format!("{api_url}/api/generate");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(
            serde_json::json!({
                "model": model,
                "keep_alive": "15m",
            })
            .to_string(),
        ))?;

    let mut response = client.send(request).await?;

    if response.status().is_success() {
        Ok(())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }
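
    // Illustrative check added here for clarity (not part of the upstream suite):
    // exercises the context-length table in `get_max_tokens` and the ":latest"
    // handling in `Model::new`.
    #[test]
    fn model_defaults_and_clamping() {
        // Unknown models fall back to the 4096-token default.
        assert_eq!(get_max_tokens("some-unknown-model"), 4096);
        // Larger lookup results are clamped to the 16384-token maximum.
        assert_eq!(get_max_tokens("qwen2.5-coder:7b"), 16384);

        // A trailing ":latest" tag is stripped when deriving the display name.
        let model = Model::new("llama3.2:latest", None, None, None, None);
        assert_eq!(model.id(), "llama3.2:latest");
        assert_eq!(model.display_name(), "llama3.2");

        // An explicit max_tokens overrides the lookup table entirely.
        let model = Model::new("llama3.2", None, Some(8192), None, None);
        assert_eq!(model.max_token_count(), 8192);
    }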

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }
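
    // Illustrative sketch added here for clarity (not part of the upstream suite):
    // shows how the untagged `KeepAlive` enum maps to the JSON values Ollama
    // accepts for `keep_alive`.
    #[test]
    fn keep_alive_serialization() {
        // The indefinite default serializes as the -1 sentinel.
        assert_eq!(
            serde_json::to_value(KeepAlive::default()).unwrap(),
            serde_json::json!(-1)
        );
        // Duration strings such as "5m" pass through unchanged.
        assert_eq!(
            serde_json::to_value(KeepAlive::Duration("5m".to_string())).unwrap(),
            serde_json::json!("5m")
        );
        // Untagged deserialization picks the variant from the JSON type.
        assert_eq!(
            serde_json::from_value::<KeepAlive>(serde_json::json!(300)).unwrap(),
            KeepAlive::Seconds(300)
        );
    }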

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }
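
    // Illustrative sketch added here for clarity (not part of the upstream suite):
    // builds a request with a hypothetical "weather" tool and checks that
    // `with_tools` disables streaming and that messages carry a "role" tag.
    #[test]
    fn request_with_tools_disables_streaming() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "What is the weather in London?".to_string(),
            }],
            stream: true,
            keep_alive: KeepAlive::default(),
            options: None,
            tools: Vec::new(),
            think: None,
        }
        .with_tools(vec![OllamaTool::Function {
            function: OllamaFunctionTool {
                name: "weather".to_string(),
                description: Some("Get the current weather".to_string()),
                parameters: None,
            },
        }]);

        assert!(!request.stream);
        assert_eq!(request.tools.len(), 1);

        // Messages are tagged by role when serialized.
        let message = serde_json::to_value(&request.messages[0]).unwrap();
        assert_eq!(message["role"], "user");
    }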

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));
536 }
537}