mod ollama_edit_prediction_delegate;

pub use ollama_edit_prediction_delegate::OllamaEditPredictionDelegate;

use anyhow::{Context, Result};

use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub use settings::KeepAlive;

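/// Default base URL of a locally running Ollama server.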
pub const OLLAMA_API_URL: &str = "http://localhost:11434";

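/// An available Ollama model, along with its capabilities and context settings.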
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}

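/// Returns a known context length for the given model family, clamped to a
/// conservative maximum so models remain usable on machines with modest RAM.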
fn get_max_tokens(name: &str) -> u64 {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: u64 = 4096;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    /// Models that support context beyond 16k, such as codestral (32k) or devstral (128k), will be clamped down to 16k.
    const MAXIMUM_TOKENS: u64 = 16384;

    match name.split(':').next().unwrap() {
        "granite-code" | "phi" | "tinyllama" => 2048,
        "llama2" | "stablelm2" | "vicuna" | "yi" => 4096,
        "aya" | "codegemma" | "gemma" | "gemma2" | "llama3" | "starcoder" => 8192,
        "codellama" | "starcoder2" => 16384,
        "codestral" | "dolphin-mixtral" | "llava" | "magistral" | "mistral" | "mixtral"
        | "qwen2" | "qwen2.5-coder" => 32768,
        "cogito" | "command-r" | "deepseek-coder-v2" | "deepseek-r1" | "deepseek-v3"
        | "devstral" | "gemma3" | "gpt-oss" | "granite3.3" | "llama3.1" | "llama3.2"
        | "llama3.3" | "mistral-nemo" | "phi3" | "phi3.5" | "phi4" | "qwen3" | "yi-coder" => 128000,
        "qwen3-coder" => 256000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tools: Option<bool>,
        supports_vision: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_vision,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }
}

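/// A message in an Ollama chat exchange, tagged by its `role`.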
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        thinking: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
    Tool {
        tool_name: String,
        content: String,
    },
}

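/// A tool invocation requested by the model.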
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaToolCall {
    // TODO: Remove `Option` after most users have updated to Ollama v0.12.10,
    // which was released on the 4th of November 2025.
    pub id: Option<String>,
    pub function: OllamaFunctionCall,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

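/// Request body for Ollama's `/api/chat` endpoint.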
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<u64>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

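/// One line of the newline-delimited JSON stream returned by `/api/chat`.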
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    pub model: String,
    pub created_at: String,
    pub message: ChatMessage,
    pub done_reason: Option<String>,
    pub done: bool,
    pub prompt_eval_count: Option<u64>,
    pub eval_count: Option<u64>,
}

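/// Response from `/api/tags`, listing the locally installed models.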
#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

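/// Capabilities and metadata for a single model, as reported by `/api/show`.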
#[derive(Debug)]
pub struct ModelShow {
    pub capabilities: Vec<String>,
    pub context_length: Option<u64>,
    pub architecture: Option<String>,
}

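// Deserialized by hand because the interesting fields are buried in the nested
// `model_info` map: the architecture lives under `general.architecture`, and the
// context length under the architecture-specific `<arch>.context_length` key.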
impl<'de> Deserialize<'de> for ModelShow {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};
        use std::fmt;

        struct ModelShowVisitor;

        impl<'de> Visitor<'de> for ModelShowVisitor {
            type Value = ModelShow;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a ModelShow object")
            }

            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                let mut capabilities: Vec<String> = Vec::new();
                let mut architecture: Option<String> = None;
                let mut context_length: Option<u64> = None;

                while let Some(key) = map.next_key::<String>()? {
                    match key.as_str() {
                        "capabilities" => {
                            capabilities = map.next_value()?;
                        }
                        "model_info" => {
                            let model_info: Value = map.next_value()?;
                            if let Value::Object(obj) = model_info {
                                architecture = obj
                                    .get("general.architecture")
                                    .and_then(|v| v.as_str())
                                    .map(String::from);

                                if let Some(arch) = &architecture {
                                    context_length = obj
                                        .get(&format!("{}.context_length", arch))
                                        .and_then(|v| v.as_u64());
                                }
                            }
                        }
                        _ => {
                            let _: de::IgnoredAny = map.next_value()?;
                        }
                    }
                }

                Ok(ModelShow {
                    capabilities,
                    context_length,
                    architecture,
                })
            }
        }

        deserializer.deserialize_map(ModelShowVisitor)
    }
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // `.contains` expects a `&String`, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.iter().any(|v| v == "vision")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}

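/// Sends a chat request to `/api/chat` and streams the response as
/// newline-delimited JSON deltas.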
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(serde_json::to_string(&request)?))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Ollama API request failed: {} {}",
            response.status(),
            body,
        );
    }
}

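/// Lists the models installed locally, via `/api/tags`.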
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Ollama API request failed: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetches the details of a model; used to determine its capabilities.
pub async fn show_model(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    model: &str,
) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Ollama API request failed: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "id": "call_llama3.2:3b_145155",
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    // Backwards compatibility with Ollama versions prior to v0.12.10 (released November 2025).
    // This test is a copy of `parse_tool_call()` with the `id` field omitted.
    #[test]
    fn parse_tool_call_pre_0_12_10() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls: Some(tool_calls),
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(thinking.is_none());

                // When the `Option` around `id` is removed, this test will fail to compile
                // and should then be deleted in favor of `parse_tool_call()`.
                assert!(tool_calls.first().is_some_and(|call| call.id.is_none()))
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));

        assert_eq!(result.architecture, Some("llama".to_string()));
        assert_eq!(result.context_length, Some(131072));
    }

    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    #[test]
    fn json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
}