1use anyhow::{Context, Result};
2
3use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
4use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7pub use settings::KeepAlive;
8
/// Default base URL of a locally-running Ollama server.
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
10
/// An Ollama model entry, including capability flags and context size.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    /// Raw Ollama model name, e.g. `llama3.2:latest`.
    pub name: String,
    /// Optional human-friendly name; `display_name()` falls back to `name`.
    pub display_name: Option<String>,
    /// Context window size, in tokens.
    pub max_tokens: u64,
    /// How long the server should keep the model loaded between requests.
    pub keep_alive: Option<KeepAlive>,
    /// Whether the model supports tool calls (`None` = unknown).
    pub supports_tools: Option<bool>,
    /// Whether the model accepts image inputs (`None` = unknown).
    pub supports_vision: Option<bool>,
    /// Whether the model can emit "thinking" output (`None` = unknown).
    pub supports_thinking: Option<bool>,
}
22
/// Returns the context-window size (in tokens) to assume for a model name.
///
/// The tag listing does not report a per-model context length, so every
/// model currently gets the same conservative default.
fn get_max_tokens(_name: &str) -> u64 {
    4096
}
27
28impl Model {
29 pub fn new(
30 name: &str,
31 display_name: Option<&str>,
32 max_tokens: Option<u64>,
33 supports_tools: Option<bool>,
34 supports_vision: Option<bool>,
35 supports_thinking: Option<bool>,
36 ) -> Self {
37 Self {
38 name: name.to_owned(),
39 display_name: display_name
40 .map(ToString::to_string)
41 .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
42 max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
43 keep_alive: Some(KeepAlive::indefinite()),
44 supports_tools,
45 supports_vision,
46 supports_thinking,
47 }
48 }
49
50 pub fn id(&self) -> &str {
51 &self.name
52 }
53
54 pub fn display_name(&self) -> &str {
55 self.display_name.as_ref().unwrap_or(&self.name)
56 }
57
58 pub fn max_token_count(&self) -> u64 {
59 self.max_tokens
60 }
61}
62
/// A single message in an Ollama chat conversation.
///
/// Serialized with a `role` tag (`assistant` / `user` / `system` / `tool`)
/// to match Ollama's chat message wire format.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        /// Tool invocations requested by the model, if any.
        tool_calls: Option<Vec<OllamaToolCall>>,
        /// Base64-encoded images; omitted from JSON entirely when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        /// Reasoning text emitted by models with thinking enabled.
        thinking: Option<String>,
    },
    User {
        content: String,
        /// Base64-encoded images; omitted from JSON entirely when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
    Tool {
        /// Name of the tool whose output this message carries back.
        tool_name: String,
        content: String,
    },
}
86
/// A tool invocation attached to an assistant message.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaToolCall {
    // TODO: Remove `Option` after most users have updated to Ollama v0.12.10,
    // which was released on the 4th of November 2025
    pub id: Option<String>,
    /// The function being called, with its arguments.
    pub function: OllamaFunctionCall,
}
94
/// The name and arguments of a single function invocation within a tool call.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    /// Arbitrary JSON arguments as produced by the model.
    pub arguments: Value,
}
100
/// Definition of a callable function advertised to the model in a request.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    /// JSON-schema-style parameter description, if the tool takes arguments.
    pub parameters: Option<Value>,
}
107
/// A tool made available to the model; serialized with a `type` tag.
/// Currently only `function` tools exist.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}
113
/// Request body for Ollama's `/api/chat` endpoint.
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    /// The conversation history to continue.
    pub messages: Vec<ChatMessage>,
    /// Whether the server should stream newline-delimited JSON chunks.
    pub stream: bool,
    /// How long the model should stay loaded after this request.
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    /// Tools the model may call; an empty vec advertises none.
    pub tools: Vec<OllamaTool>,
    /// Enables "thinking" output for models that support it.
    pub think: Option<bool>,
}
124
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
/// Per-request sampling and context overrides; `None` fields fall back to the
/// model's own defaults (see the link above for parameter semantics).
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    /// Context window size, in tokens.
    pub num_ctx: Option<u64>,
    /// Maximum number of tokens to generate. Signed — see the linked docs
    /// for the meaning of negative values.
    pub num_predict: Option<isize>,
    /// Stop sequences that terminate generation.
    pub stop: Option<Vec<String>>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Nucleus-sampling cutoff.
    pub top_p: Option<f32>,
}
134
/// One parsed chunk of a `/api/chat` response (or the whole response when
/// streaming is disabled). Unknown server fields are ignored by serde.
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    pub model: String,
    /// Timestamp string reported by the server for this chunk.
    pub created_at: String,
    /// The (partial) message carried by this chunk.
    pub message: ChatMessage,
    /// Why generation stopped; present on the final chunk.
    pub done_reason: Option<String>,
    /// `true` on the last chunk of a stream.
    pub done: bool,
    /// Prompt tokens evaluated — populated on the final chunk.
    pub prompt_eval_count: Option<u64>,
    /// Tokens generated — populated on the final chunk.
    pub eval_count: Option<u64>,
}
145
/// Response body of `/api/tags`: the models available locally.
#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}
150
/// One entry in the `/api/tags` model listing.
#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    /// On-disk size in bytes.
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}
159
/// Full information about a locally-installed model.
// NOTE(review): not constructed anywhere in this file — presumably mirrors
// an `/api/show`-style response; verify against callers before relying on it.
#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}
167
/// Format/architecture metadata shared by model listing responses.
#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    /// Model file format, e.g. `gguf`.
    pub format: String,
    /// Primary model family, e.g. `llama`.
    pub family: String,
    pub families: Option<Vec<String>>,
    /// Human-readable parameter count, e.g. `3.2B`.
    pub parameter_size: String,
    /// Quantization label, e.g. `Q4_K_M`.
    pub quantization_level: String,
}
176
/// Capability and architecture details extracted from `/api/show`.
///
/// Deserialized by a custom `Deserialize` impl below that flattens the
/// nested `model_info` object into these three fields.
#[derive(Debug)]
pub struct ModelShow {
    /// Capability strings, e.g. `"completion"`, `"tools"`, `"vision"`.
    pub capabilities: Vec<String>,
    /// Taken from `model_info["<architecture>.context_length"]`, if present.
    pub context_length: Option<u64>,
    /// Taken from `model_info["general.architecture"]`, e.g. `"llama"`.
    pub architecture: Option<String>,
}
183
184impl<'de> Deserialize<'de> for ModelShow {
185 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
186 where
187 D: serde::Deserializer<'de>,
188 {
189 use serde::de::{self, MapAccess, Visitor};
190 use std::fmt;
191
192 struct ModelShowVisitor;
193
194 impl<'de> Visitor<'de> for ModelShowVisitor {
195 type Value = ModelShow;
196
197 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
198 formatter.write_str("a ModelShow object")
199 }
200
201 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
202 where
203 A: MapAccess<'de>,
204 {
205 let mut capabilities: Vec<String> = Vec::new();
206 let mut architecture: Option<String> = None;
207 let mut context_length: Option<u64> = None;
208
209 while let Some(key) = map.next_key::<String>()? {
210 match key.as_str() {
211 "capabilities" => {
212 capabilities = map.next_value()?;
213 }
214 "model_info" => {
215 let model_info: Value = map.next_value()?;
216 if let Value::Object(obj) = model_info {
217 architecture = obj
218 .get("general.architecture")
219 .and_then(|v| v.as_str())
220 .map(String::from);
221
222 if let Some(arch) = &architecture {
223 context_length = obj
224 .get(&format!("{}.context_length", arch))
225 .and_then(|v| v.as_u64());
226 }
227 }
228 }
229 _ => {
230 let _: de::IgnoredAny = map.next_value()?;
231 }
232 }
233 }
234
235 Ok(ModelShow {
236 capabilities,
237 context_length,
238 architecture,
239 })
240 }
241 }
242
243 deserializer.deserialize_map(ModelShowVisitor)
244 }
245}
246
247impl ModelShow {
248 pub fn supports_tools(&self) -> bool {
249 // .contains expects &String, which would require an additional allocation
250 self.capabilities.iter().any(|v| v == "tools")
251 }
252
253 pub fn supports_vision(&self) -> bool {
254 self.capabilities.iter().any(|v| v == "vision")
255 }
256
257 pub fn supports_thinking(&self) -> bool {
258 self.capabilities.iter().any(|v| v == "thinking")
259 }
260}
261
262pub async fn stream_chat_completion(
263 client: &dyn HttpClient,
264 api_url: &str,
265 api_key: Option<&str>,
266 request: ChatRequest,
267) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
268 let uri = format!("{api_url}/api/chat");
269 let request = HttpRequest::builder()
270 .method(Method::POST)
271 .uri(uri)
272 .header("Content-Type", "application/json")
273 .when_some(api_key, |builder, api_key| {
274 builder.header("Authorization", format!("Bearer {api_key}"))
275 })
276 .body(AsyncBody::from(serde_json::to_string(&request)?))?;
277
278 let mut response = client.send(request).await?;
279 if response.status().is_success() {
280 let reader = BufReader::new(response.into_body());
281
282 Ok(reader
283 .lines()
284 .map(|line| match line {
285 Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
286 Err(e) => Err(e.into()),
287 })
288 .boxed())
289 } else {
290 let mut body = String::new();
291 response.body_mut().read_to_string(&mut body).await?;
292 anyhow::bail!(
293 "Failed to connect to Ollama API: {} {}",
294 response.status(),
295 body,
296 );
297 }
298}
299
300pub async fn get_models(
301 client: &dyn HttpClient,
302 api_url: &str,
303 api_key: Option<&str>,
304) -> Result<Vec<LocalModelListing>> {
305 let uri = format!("{api_url}/api/tags");
306 let request = HttpRequest::builder()
307 .method(Method::GET)
308 .uri(uri)
309 .header("Accept", "application/json")
310 .when_some(api_key, |builder, api_key| {
311 builder.header("Authorization", format!("Bearer {api_key}"))
312 })
313 .body(AsyncBody::default())?;
314
315 let mut response = client.send(request).await?;
316
317 let mut body = String::new();
318 response.body_mut().read_to_string(&mut body).await?;
319
320 anyhow::ensure!(
321 response.status().is_success(),
322 "Failed to connect to Ollama API: {} {}",
323 response.status(),
324 body,
325 );
326 let response: LocalModelsResponse =
327 serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
328 Ok(response.models)
329}
330
331/// Fetch details of a model, used to determine model capabilities
332pub async fn show_model(
333 client: &dyn HttpClient,
334 api_url: &str,
335 api_key: Option<&str>,
336 model: &str,
337) -> Result<ModelShow> {
338 let uri = format!("{api_url}/api/show");
339 let request = HttpRequest::builder()
340 .method(Method::POST)
341 .uri(uri)
342 .header("Content-Type", "application/json")
343 .when_some(api_key, |builder, api_key| {
344 builder.header("Authorization", format!("Bearer {api_key}"))
345 })
346 .body(AsyncBody::from(
347 serde_json::json!({ "model": model }).to_string(),
348 ))?;
349
350 let mut response = client.send(request).await?;
351 let mut body = String::new();
352 response.body_mut().read_to_string(&mut body).await?;
353
354 anyhow::ensure!(
355 response.status().is_success(),
356 "Failed to connect to Ollama API: {} {}",
357 response.status(),
358 body,
359 );
360 let details: ModelShow = serde_json::from_str(body.as_str())?;
361 Ok(details)
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    // A complete (non-streaming) chat response: timing metadata and token
    // counts accompany the assistant message and must be tolerated.
    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    // Streaming sends partial chunks (`done: false`, no counters) followed by
    // a final chunk (`done: true`) carrying the evaluation statistics.
    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    // An assistant message carrying a tool call (with `id`, as emitted by
    // Ollama v0.12.10+) deserializes into `ChatMessage::Assistant`.
    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "id": "call_llama3.2:3b_145155",
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    // Backwards compatibility with Ollama versions prior to v0.12.10 November 2025
    // This test is a copy of `parse_tool_call()` with the `id` field omitted.
    #[test]
    fn parse_tool_call_pre_0_12_10() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls: Some(tool_calls),
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(thinking.is_none());

                // When the `Option` around `id` is removed, this test should complain
                // and be subsequently deleted in favor of `parse_tool_call()`
                assert!(tool_calls.first().is_some_and(|call| call.id.is_none()))
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    // The custom `ModelShow` deserializer must pull `capabilities` plus the
    // architecture-prefixed context length out of a realistic `/api/show`
    // payload while ignoring everything else.
    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));

        assert_eq!(result.architecture, Some("llama".to_string()));
        assert_eq!(result.context_length, Some(131072));
    }

    // `images: Some(_)` must appear in the serialized request payload.
    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    // `images: None` must be skipped entirely (not serialized as null),
    // per the `skip_serializing_if` attribute on the field.
    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    // Round-trips a request through serde_json to check the exact JSON shape
    // of the `images` array inside a message.
    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
}