use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::time::Duration;

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}

fn get_max_tokens(name: &str) -> u64 {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: u64 = 4096;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    const MAXIMUM_TOKENS: u64 = 16384;

    match name.split(':').next().unwrap() {
        "phi" | "tinyllama" | "granite-code" => 2048,
        "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
        "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
        "codellama" | "starcoder2" => 16384,
        "mistral" | "codestral" | "mixtral" | "llava" | "qwen2" | "qwen2.5-coder"
        | "dolphin-mixtral" => 32768,
        "magistral" => 40000,
        "llama3.1" | "llama3.2" | "llama3.3" | "phi3" | "phi3.5" | "phi4" | "command-r"
        | "qwen3" | "gemma3" | "deepseek-coder-v2" | "deepseek-v3" | "deepseek-r1" | "yi-coder"
        | "devstral" | "gpt-oss" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tools: Option<bool>,
        supports_vision: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_vision,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        thinking: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

impl ChatRequest {
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
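        // This client parses tool calls from a single, non-streamed reply, so
        // attaching tools forces `stream: false` for the request.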
        self.stream = false;
        self.tools = tools;
        self
    }
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<u64>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
    pub prompt_eval_count: Option<u64>,
    pub eval_count: Option<u64>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

#[derive(Deserialize, Debug)]
pub struct ModelShow {
    #[serde(default)]
    pub capabilities: Vec<String>,
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.iter().any(|v| v == "vision")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}

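/// Send a chat request and wait for the complete, non-streaming response.
///
/// A minimal usage sketch (illustrative only; `client` stands in for whatever
/// `HttpClient` implementation the caller already has):
///
/// ```ignore
/// let request = ChatRequest {
///     model: "llama3.2".to_string(),
///     messages: vec![ChatMessage::User {
///         content: "Hello!".to_string(),
///         images: None,
///     }],
///     stream: false,
///     keep_alive: KeepAlive::default(),
///     options: None,
///     tools: vec![],
///     think: None,
/// };
/// let response = complete(client, OLLAMA_API_URL, request).await?;
/// ```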
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;

    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        );
    }
}

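/// Send a chat request and return the newline-delimited JSON deltas as a
/// stream.
///
/// A minimal consumption sketch (illustrative only):
///
/// ```ignore
/// let mut stream = stream_chat_completion(client, OLLAMA_API_URL, request).await?;
/// while let Some(delta) = stream.next().await {
///     if let ChatMessage::Assistant { content, .. } = delta?.message {
///         print!("{content}");
///     }
/// }
/// ```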
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

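/// List the models available locally, via Ollama's `/api/tags` endpoint.
///
/// A minimal usage sketch (illustrative only):
///
/// ```ignore
/// let models = get_models(client, OLLAMA_API_URL, None).await?;
/// for model in &models {
///     println!("{} ({})", model.name, model.details.parameter_size);
/// }
/// ```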
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetch details of a model, used to determine model capabilities
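///
/// A minimal usage sketch (illustrative only):
///
/// ```ignore
/// let show = show_model(client, OLLAMA_API_URL, "llama3.2").await?;
/// if show.supports_tools() {
///     // Offer tool use for this model.
/// }
/// ```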
pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));
    }

    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
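
    // The tests below exercise the pure, local helpers above; the expected
    // values are read directly from `get_max_tokens`, `Model::new`, and
    // KeepAlive's untagged serde representation.
    #[test]
    fn max_tokens_are_clamped_and_defaulted() {
        // A known prefix at or below the cap passes through unchanged.
        assert_eq!(get_max_tokens("codellama:13b"), 16384);
        // Larger known context windows are clamped to MAXIMUM_TOKENS.
        assert_eq!(get_max_tokens("mistral:7b"), 16384);
        assert_eq!(get_max_tokens("llama3.1:70b"), 16384);
        // Unknown models fall back to DEFAULT_TOKENS.
        assert_eq!(get_max_tokens("some-unknown-model"), 4096);
    }

    #[test]
    fn keep_alive_serializes_untagged() {
        // The indefinite default serializes as a bare number...
        assert_eq!(serde_json::to_string(&KeepAlive::default()).unwrap(), "-1");
        // ...and a duration serializes as a plain string, matching the two
        // forms Ollama accepts for `keep_alive`.
        assert_eq!(
            serde_json::to_string(&KeepAlive::Duration("5m".to_string())).unwrap(),
            "\"5m\""
        );
    }

    #[test]
    fn display_name_strips_latest_suffix() {
        let model = Model::new("llama3.2:latest", None, None, None, None, None);
        assert_eq!(model.display_name(), "llama3.2");
        // An explicit display name is preserved as-is.
        let model = Model::new("llama3.2:latest", Some("Llama 3.2"), None, None, None, None);
        assert_eq!(model.display_name(), "Llama 3.2");
    }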
}