1use anyhow::{Context, Result};
2
3use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
4use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7pub use settings::KeepAlive;
8
/// Default base URL of a locally running Ollama server.
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
10
/// Description of an Ollama model and its advertised capabilities.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    /// Model identifier as known to Ollama (e.g. "llama3.2:latest").
    pub name: String,
    /// Optional human-readable name; `display_name()` falls back to `name`.
    pub display_name: Option<String>,
    /// Context window size, in tokens.
    pub max_tokens: u64,
    /// How long Ollama should keep the model loaded after a request.
    pub keep_alive: Option<KeepAlive>,
    /// Whether the model supports tool calling, if known.
    pub supports_tools: Option<bool>,
    /// Whether the model accepts image input, if known.
    pub supports_vision: Option<bool>,
    /// Whether the model can emit "thinking" output, if known.
    pub supports_thinking: Option<bool>,
}
22
/// Fallback context length used when a model does not report one.
/// The model name is currently unused but kept for future per-model defaults.
fn get_max_tokens(_name: &str) -> u64 {
    4096
}
27
28impl Model {
29 pub fn new(
30 name: &str,
31 display_name: Option<&str>,
32 max_tokens: Option<u64>,
33 supports_tools: Option<bool>,
34 supports_vision: Option<bool>,
35 supports_thinking: Option<bool>,
36 ) -> Self {
37 Self {
38 name: name.to_owned(),
39 display_name: display_name
40 .map(ToString::to_string)
41 .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
42 max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
43 keep_alive: Some(KeepAlive::indefinite()),
44 supports_tools,
45 supports_vision,
46 supports_thinking,
47 }
48 }
49
50 pub fn id(&self) -> &str {
51 &self.name
52 }
53
54 pub fn display_name(&self) -> &str {
55 self.display_name.as_ref().unwrap_or(&self.name)
56 }
57
58 pub fn max_token_count(&self) -> u64 {
59 self.max_tokens
60 }
61}
62
/// A single message in an Ollama chat conversation.
///
/// Serialized with a `"role"` tag field (`assistant`/`user`/`system`/`tool`),
/// matching Ollama's chat API wire format.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    /// A message produced by the model.
    Assistant {
        content: String,
        /// Tool invocations requested by the model, if any.
        tool_calls: Option<Vec<OllamaToolCall>>,
        /// Base64-encoded images; omitted from JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        /// Reasoning text, present for models with thinking enabled.
        thinking: Option<String>,
    },
    /// A message from the end user.
    User {
        content: String,
        /// Base64-encoded images; omitted from JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    /// A system prompt message.
    System {
        content: String,
    },
    /// The result of a tool invocation, fed back to the model.
    Tool {
        tool_name: String,
        content: String,
    },
}
86
/// A tool invocation requested by the model in an assistant message.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaToolCall {
    /// Identifier for this call (e.g. "call_llama3.2:3b_145155").
    pub id: String,
    /// The function to invoke and its arguments.
    pub function: OllamaFunctionCall,
}
92
/// The function name and arguments of a tool call.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    /// Arbitrary JSON arguments, as produced by the model.
    pub arguments: Value,
}
98
/// Declaration of a function tool made available to the model.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    /// JSON Schema describing the function's parameters, if any.
    pub parameters: Option<Value>,
}
105
/// A tool offered to the model; currently only function tools exist.
/// Serialized with a `"type"` tag field (`{"type": "function", ...}`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}
111
/// Request body for Ollama's `/api/chat` endpoint.
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    /// When true, the response arrives as newline-delimited JSON deltas.
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    /// Enables "thinking" output on models that support it.
    pub think: Option<bool>,
}
122
/// Per-request sampling/runtime options. Fields set to `None` are omitted
/// from the JSON so Ollama falls back to the model's Modelfile defaults.
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    /// Context window size to use for this request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_ctx: Option<u64>,
    /// Maximum number of tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_predict: Option<isize>,
    /// Stop sequences; omitted when `None` so model defaults apply.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}
137
/// One chunk of a (possibly streaming) `/api/chat` response.
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    pub model: String,
    pub created_at: String,
    /// The message fragment carried by this delta.
    pub message: ChatMessage,
    /// Why generation finished (e.g. "stop"); only set on the final delta.
    pub done_reason: Option<String>,
    /// True on the final delta of a response.
    pub done: bool,
    /// Token counts reported by Ollama on the final delta.
    pub prompt_eval_count: Option<u64>,
    pub eval_count: Option<u64>,
}
148
/// Response body of `/api/tags`: the models available locally.
#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}
153
/// One entry in the `/api/tags` model listing.
#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    /// On-disk size in bytes.
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}
162
/// Full description of a local model (Modelfile, parameters, template).
#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}
170
/// Format/family/quantization metadata reported for a model.
#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    /// Model file format (e.g. "gguf").
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    /// Human-readable parameter count (e.g. "3.2B").
    pub parameter_size: String,
    /// Quantization label (e.g. "Q4_K_M").
    pub quantization_level: String,
}
179
/// Subset of the `/api/show` response used to determine model capabilities.
/// Deserialized via a hand-written `Deserialize` impl below, because the
/// context length lives under an architecture-dependent key.
#[derive(Debug)]
pub struct ModelShow {
    /// Capability strings such as "completion", "tools", "vision", "thinking".
    pub capabilities: Vec<String>,
    /// Context length extracted from `model_info`, when present.
    pub context_length: Option<u64>,
    /// Value of `model_info["general.architecture"]`, when present.
    pub architecture: Option<String>,
}
186
/// Custom deserializer for `ModelShow`.
///
/// A derived impl cannot express this: the context length is nested inside
/// `model_info` under a key whose prefix is the architecture name reported
/// by `model_info["general.architecture"]` (e.g. `"llama.context_length"`).
impl<'de> Deserialize<'de> for ModelShow {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};
        use std::fmt;

        struct ModelShowVisitor;

        impl<'de> Visitor<'de> for ModelShowVisitor {
            type Value = ModelShow;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a ModelShow object")
            }

            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                // Missing "capabilities" yields an empty Vec rather than an error.
                let mut capabilities: Vec<String> = Vec::new();
                let mut architecture: Option<String> = None;
                let mut context_length: Option<u64> = None;

                while let Some(key) = map.next_key::<String>()? {
                    match key.as_str() {
                        "capabilities" => {
                            capabilities = map.next_value()?;
                        }
                        "model_info" => {
                            // `model_info` is a flat metadata map; look up the
                            // architecture first, then use it to build the
                            // "<arch>.context_length" key.
                            let model_info: Value = map.next_value()?;
                            if let Value::Object(obj) = model_info {
                                architecture = obj
                                    .get("general.architecture")
                                    .and_then(|v| v.as_str())
                                    .map(String::from);

                                if let Some(arch) = &architecture {
                                    context_length = obj
                                        .get(&format!("{}.context_length", arch))
                                        .and_then(|v| v.as_u64());
                                }
                            }
                        }
                        _ => {
                            // Skip unrelated fields (license, details, tensors, ...)
                            // without buffering their contents.
                            let _: de::IgnoredAny = map.next_value()?;
                        }
                    }
                }

                Ok(ModelShow {
                    capabilities,
                    context_length,
                    architecture,
                })
            }
        }

        deserializer.deserialize_map(ModelShowVisitor)
    }
}
249
250impl ModelShow {
251 pub fn supports_tools(&self) -> bool {
252 // .contains expects &String, which would require an additional allocation
253 self.capabilities.iter().any(|v| v == "tools")
254 }
255
256 pub fn supports_vision(&self) -> bool {
257 self.capabilities.iter().any(|v| v == "vision")
258 }
259
260 pub fn supports_thinking(&self) -> bool {
261 self.capabilities.iter().any(|v| v == "thinking")
262 }
263}
264
265pub async fn stream_chat_completion(
266 client: &dyn HttpClient,
267 api_url: &str,
268 api_key: Option<&str>,
269 request: ChatRequest,
270) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
271 let uri = format!("{api_url}/api/chat");
272 let request = HttpRequest::builder()
273 .method(Method::POST)
274 .uri(uri)
275 .header("Content-Type", "application/json")
276 .when_some(api_key, |builder, api_key| {
277 builder.header("Authorization", format!("Bearer {api_key}"))
278 })
279 .body(AsyncBody::from(serde_json::to_string(&request)?))?;
280
281 let mut response = client.send(request).await?;
282 if response.status().is_success() {
283 let reader = BufReader::new(response.into_body());
284
285 Ok(reader
286 .lines()
287 .map(|line| match line {
288 Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
289 Err(e) => Err(e.into()),
290 })
291 .boxed())
292 } else {
293 let mut body = String::new();
294 response.body_mut().read_to_string(&mut body).await?;
295 anyhow::bail!(
296 "Failed to connect to Ollama API: {} {}",
297 response.status(),
298 body,
299 );
300 }
301}
302
303pub async fn get_models(
304 client: &dyn HttpClient,
305 api_url: &str,
306 api_key: Option<&str>,
307) -> Result<Vec<LocalModelListing>> {
308 let uri = format!("{api_url}/api/tags");
309 let request = HttpRequest::builder()
310 .method(Method::GET)
311 .uri(uri)
312 .header("Accept", "application/json")
313 .when_some(api_key, |builder, api_key| {
314 builder.header("Authorization", format!("Bearer {api_key}"))
315 })
316 .body(AsyncBody::default())?;
317
318 let mut response = client.send(request).await?;
319
320 let mut body = String::new();
321 response.body_mut().read_to_string(&mut body).await?;
322
323 anyhow::ensure!(
324 response.status().is_success(),
325 "Failed to connect to Ollama API: {} {}",
326 response.status(),
327 body,
328 );
329 let response: LocalModelsResponse =
330 serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
331 Ok(response.models)
332}
333
334/// Fetch details of a model, used to determine model capabilities
335pub async fn show_model(
336 client: &dyn HttpClient,
337 api_url: &str,
338 api_key: Option<&str>,
339 model: &str,
340) -> Result<ModelShow> {
341 let uri = format!("{api_url}/api/show");
342 let request = HttpRequest::builder()
343 .method(Method::POST)
344 .uri(uri)
345 .header("Content-Type", "application/json")
346 .when_some(api_key, |builder, api_key| {
347 builder.header("Authorization", format!("Bearer {api_key}"))
348 })
349 .body(AsyncBody::from(
350 serde_json::json!({ "model": model }).to_string(),
351 ))?;
352
353 let mut response = client.send(request).await?;
354 let mut body = String::new();
355 response.body_mut().read_to_string(&mut body).await?;
356
357 anyhow::ensure!(
358 response.status().is_success(),
359 "Failed to connect to Ollama API: {} {}",
360 response.status(),
361 body,
362 );
363 let details: ModelShow = serde_json::from_str(body.as_str())?;
364 Ok(details)
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    // A complete (non-streaming) chat response deserializes; unknown
    // fields like total_duration are ignored.
    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    // Both an intermediate delta (done: false, no counts) and the final
    // delta (done: true, with counts) deserialize.
    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    // An assistant message carrying tool_calls deserializes into the
    // Assistant variant with a non-empty tool_calls list.
    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "id": "call_llama3.2:3b_145155",
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    // The custom ModelShow deserializer extracts capabilities, the
    // architecture, and the architecture-prefixed context length.
    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));

        assert_eq!(result.architecture, Some("llama".to_string()));
        assert_eq!(result.context_length, Some(131072));
    }

    // When a user message has images, they are serialized into the request.
    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    // `images: None` is skipped entirely (skip_serializing_if).
    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    // Round-trip check: images land under messages[0].images as an array.
    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }

    #[test]
    fn test_chat_options_serialization() {
        // When stop is None, it should not appear in JSON at all
        // This allows Ollama to use the model's default stop tokens
        let options_no_stop = ChatOptions {
            num_ctx: Some(4096),
            stop: None,
            temperature: Some(0.7),
            ..Default::default()
        };
        let serialized = serde_json::to_string(&options_no_stop).unwrap();
        assert!(
            !serialized.contains("stop"),
            "stop should not be in JSON when None"
        );
        assert!(serialized.contains("num_ctx"));
        assert!(serialized.contains("temperature"));

        // When stop has values, they should be serialized
        let options_with_stop = ChatOptions {
            stop: Some(vec!["<|eot_id|>".to_string()]),
            ..Default::default()
        };
        let serialized = serde_json::to_string(&options_with_stop).unwrap();
        assert!(serialized.contains("stop"));
        assert!(serialized.contains("<|eot_id|>"));

        // All None options should result in empty object
        let options_all_none = ChatOptions::default();
        let serialized = serde_json::to_string(&options_all_none).unwrap();
        assert_eq!(serialized, "{}");
    }

    // Stop tokens supplied in options appear verbatim in the request JSON.
    #[test]
    fn test_chat_request_with_stop_tokens() {
        let request = ChatRequest {
            model: "rnj-1:8b".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello".to_string(),
                images: None,
            }],
            stream: true,
            keep_alive: KeepAlive::default(),
            options: Some(ChatOptions {
                stop: Some(vec!["<|eot_id|>".to_string(), "<|end|>".to_string()]),
                ..Default::default()
            }),
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();

        let stop = parsed["options"]["stop"].as_array().unwrap();
        assert_eq!(stop.len(), 2);
        assert_eq!(stop[0].as_str().unwrap(), "<|eot_id|>");
        assert_eq!(stop[1].as_str().unwrap(), "<|end|>");
    }

    #[test]
    fn test_chat_request_without_stop_tokens_omits_field() {
        // This tests the fix for issue #47798
        // When no stop tokens are provided, the field should be omitted
        // so Ollama uses the model's default stop tokens from Modelfile
        let request = ChatRequest {
            model: "rnj-1:8b".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello".to_string(),
                images: None,
            }],
            stream: true,
            keep_alive: KeepAlive::default(),
            options: Some(ChatOptions {
                num_ctx: Some(4096),
                stop: None, // No stop tokens - should be omitted from JSON
                ..Default::default()
            }),
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        // The key check: "stop" should not appear in the serialized JSON
        assert!(
            !serialized.contains("\"stop\""),
            "stop field should be omitted when None, got: {}",
            serialized
        );
    }
}