1use anyhow::{Context, Result};
2
3use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
4use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7pub use settings::KeepAlive;
8
/// Default base URL where a locally running Ollama server listens.
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
10
/// Client-side configuration for a single Ollama model.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    // Model identifier as known to the Ollama server (e.g. "llama3.2:latest").
    pub name: String,
    // Optional human-readable name; `display_name()` falls back to `name`.
    pub display_name: Option<String>,
    // Context window size in tokens.
    pub max_tokens: u64,
    // How long the server should keep the model loaded between requests.
    pub keep_alive: Option<KeepAlive>,
    // Capability flags; `None` means "unknown" (see `ModelShow` probing).
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}
22
/// Default context length assumed when the caller does not specify one.
/// The model name is accepted for future per-model defaults but is
/// currently unused.
fn get_max_tokens(_name: &str) -> u64 {
    4096
}
27
28impl Model {
29 pub fn new(
30 name: &str,
31 display_name: Option<&str>,
32 max_tokens: Option<u64>,
33 supports_tools: Option<bool>,
34 supports_vision: Option<bool>,
35 supports_thinking: Option<bool>,
36 ) -> Self {
37 Self {
38 name: name.to_owned(),
39 display_name: display_name
40 .map(ToString::to_string)
41 .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
42 max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
43 keep_alive: Some(KeepAlive::indefinite()),
44 supports_tools,
45 supports_vision,
46 supports_thinking,
47 }
48 }
49
50 pub fn id(&self) -> &str {
51 &self.name
52 }
53
54 pub fn display_name(&self) -> &str {
55 self.display_name.as_ref().unwrap_or(&self.name)
56 }
57
58 pub fn max_token_count(&self) -> u64 {
59 self.max_tokens
60 }
61}
62
/// One message in a chat exchange; the JSON `role` tag selects the variant.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    // A model-generated message, possibly carrying tool calls and
    // "thinking" text.
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        // Image attachments (base64 strings per the tests below);
        // omitted from JSON entirely when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        thinking: Option<String>,
    },
    // A user-authored message, optionally with attached images.
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    // A system prompt.
    System {
        content: String,
    },
    // The result of a tool invocation, fed back to the model.
    Tool {
        tool_name: String,
        content: String,
    },
}
86
/// A tool invocation requested by the model.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaToolCall {
    pub id: String,
    pub function: OllamaFunctionCall,
}
92
/// The function name and JSON arguments of a single tool call.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    // Arbitrary JSON arguments exactly as produced by the model.
    pub arguments: Value,
}
98
/// Declaration of a function the model is allowed to call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    // Parameter schema as free-form JSON (presumably JSON Schema —
    // not enforced here).
    pub parameters: Option<Value>,
}
105
/// A tool exposed to the model; tagged by `type` in JSON.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}
111
/// Request body for `POST {api_url}/api/chat`.
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    // When true the server streams newline-delimited JSON deltas
    // (see `stream_chat_completion`).
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    // Requests "thinking" output on models that support it.
    pub think: Option<bool>,
}
122
/// Per-request sampling and context options.
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    // Context window size in tokens.
    pub num_ctx: Option<u64>,
    // Maximum number of tokens to generate.
    pub num_predict: Option<isize>,
    // Stop sequences that end generation.
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}
132
/// One streamed (or final) chunk of a chat response.
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    pub model: String,
    pub created_at: String,
    pub message: ChatMessage,
    // Why generation ended; present on the final chunk (e.g. "stop").
    pub done_reason: Option<String>,
    // True on the last chunk of a stream.
    pub done: bool,
    // Token accounting, reported on the final chunk.
    pub prompt_eval_count: Option<u64>,
    pub eval_count: Option<u64>,
}
143
/// Response envelope for `GET /api/tags` (locally installed models).
#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}
148
/// A single entry from the `/api/tags` model listing.
#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    // On-disk size in bytes.
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}
157
/// Full local model record (modelfile text, parameters and template).
#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}
165
/// Format/family/quantization metadata reported alongside a model.
#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    // e.g. "gguf"
    pub format: String,
    // e.g. "llama"
    pub family: String,
    pub families: Option<Vec<String>>,
    // e.g. "3.2B"
    pub parameter_size: String,
    // e.g. "Q4_K_M"
    pub quantization_level: String,
}
174
/// Capability and architecture details extracted from `POST /api/show`
/// (see the custom `Deserialize` impl below).
#[derive(Debug)]
pub struct ModelShow {
    // Capability strings reported by the server, e.g. "completion", "tools".
    pub capabilities: Vec<String>,
    // Context window length from `model_info`, when reported.
    pub context_length: Option<u64>,
    // Value of `model_info["general.architecture"]`, when reported.
    pub architecture: Option<String>,
}
181
182impl<'de> Deserialize<'de> for ModelShow {
183 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
184 where
185 D: serde::Deserializer<'de>,
186 {
187 use serde::de::{self, MapAccess, Visitor};
188 use std::fmt;
189
190 struct ModelShowVisitor;
191
192 impl<'de> Visitor<'de> for ModelShowVisitor {
193 type Value = ModelShow;
194
195 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
196 formatter.write_str("a ModelShow object")
197 }
198
199 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
200 where
201 A: MapAccess<'de>,
202 {
203 let mut capabilities: Vec<String> = Vec::new();
204 let mut architecture: Option<String> = None;
205 let mut context_length: Option<u64> = None;
206
207 while let Some(key) = map.next_key::<String>()? {
208 match key.as_str() {
209 "capabilities" => {
210 capabilities = map.next_value()?;
211 }
212 "model_info" => {
213 let model_info: Value = map.next_value()?;
214 if let Value::Object(obj) = model_info {
215 architecture = obj
216 .get("general.architecture")
217 .and_then(|v| v.as_str())
218 .map(String::from);
219
220 if let Some(arch) = &architecture {
221 context_length = obj
222 .get(&format!("{}.context_length", arch))
223 .and_then(|v| v.as_u64());
224 }
225 }
226 }
227 _ => {
228 let _: de::IgnoredAny = map.next_value()?;
229 }
230 }
231 }
232
233 Ok(ModelShow {
234 capabilities,
235 context_length,
236 architecture,
237 })
238 }
239 }
240
241 deserializer.deserialize_map(ModelShowVisitor)
242 }
243}
244
245impl ModelShow {
246 pub fn supports_tools(&self) -> bool {
247 // .contains expects &String, which would require an additional allocation
248 self.capabilities.iter().any(|v| v == "tools")
249 }
250
251 pub fn supports_vision(&self) -> bool {
252 self.capabilities.iter().any(|v| v == "vision")
253 }
254
255 pub fn supports_thinking(&self) -> bool {
256 self.capabilities.iter().any(|v| v == "thinking")
257 }
258}
259
/// Sends a chat request to `{api_url}/api/chat` and returns a stream of
/// parsed response deltas.
///
/// The server replies with newline-delimited JSON; each line is decoded
/// lazily into a [`ChatResponseDelta`]. On a non-success status the body is
/// read in full and returned inside the error.
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        // Attach the Authorization header only when an API key was provided.
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(serde_json::to_string(&request)?))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        // One JSON object per line; parse each as it arrives.
        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}
297
/// Lists the models installed on the Ollama server via `GET /api/tags`.
///
/// Returns an error (including the response body) on a non-success status
/// or if the listing fails to parse.
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json")
        // Attach the Authorization header only when an API key was provided.
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}
328
329/// Fetch details of a model, used to determine model capabilities
330pub async fn show_model(
331 client: &dyn HttpClient,
332 api_url: &str,
333 api_key: Option<&str>,
334 model: &str,
335) -> Result<ModelShow> {
336 let uri = format!("{api_url}/api/show");
337 let request = HttpRequest::builder()
338 .method(Method::POST)
339 .uri(uri)
340 .header("Content-Type", "application/json")
341 .when_some(api_key, |builder, api_key| {
342 builder.header("Authorization", format!("Bearer {api_key}"))
343 })
344 .body(AsyncBody::from(
345 serde_json::json!({ "model": model }).to_string(),
346 ))?;
347
348 let mut response = client.send(request).await?;
349 let mut body = String::new();
350 response.body_mut().read_to_string(&mut body).await?;
351
352 anyhow::ensure!(
353 response.status().is_success(),
354 "Failed to connect to Ollama API: {} {}",
355 response.status(),
356 body,
357 );
358 let details: ModelShow = serde_json::from_str(body.as_str())?;
359 Ok(details)
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    // A complete (non-streaming) chat response deserializes; extra
    // accounting fields not present in `ChatResponseDelta` are ignored.
    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    // Both an intermediate streaming chunk (done: false, null images) and a
    // final chunk (done: true with counters) deserialize.
    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    // An assistant message carrying `tool_calls` deserializes into the
    // Assistant variant with populated tool calls and no thinking text.
    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "id": "call_llama3.2:3b_145155",
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    // The custom ModelShow deserializer extracts capabilities, the
    // architecture, and the architecture-prefixed context length from a
    // realistic `/api/show` payload, ignoring all other keys.
    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));

        assert_eq!(result.architecture, Some("llama".to_string()));
        assert_eq!(result.context_length, Some(131072));
    }

    // User messages with images serialize the `images` array verbatim.
    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    // `skip_serializing_if` drops the `images` key entirely when `None`.
    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    // Round-trips the serialized request through serde_json::Value to check
    // the images land at messages[0].images as a one-element array.
    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
}