use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub use settings::KeepAlive;

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}

/// Returns the context length to assume for a model, based on its base name.
fn get_max_tokens(name: &str) -> u64 {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: u64 = 4096;
    /// Magic number. Lets many Ollama models work with ~16 GB of RAM.
    /// Models that support contexts beyond 16k, such as codestral (32k) or
    /// devstral (128k), are clamped down to 16k.
    const MAXIMUM_TOKENS: u64 = 16384;

    // `split` always yields at least one item, so this `unwrap` cannot fail.
    match name.split(':').next().unwrap() {
        "granite-code" | "phi" | "tinyllama" => 2048,
        "llama2" | "stablelm2" | "vicuna" | "yi" => 4096,
        "aya" | "codegemma" | "gemma" | "gemma2" | "llama3" | "starcoder" => 8192,
        "codellama" | "starcoder2" => 16384,
34 "codestral" | "dolphin-mixtral" | "llava" | "magistral" | "mistral" | "mixstral"
35 | "qwen2" | "qwen2.5-coder" => 32768,
36 "cogito" | "command-r" | "deepseek-coder-v2" | "deepseek-r1" | "deepseek-v3"
37 | "devstral" | "gemma3" | "gpt-oss" | "granite3.3" | "llama3.1" | "llama3.2"
38 | "llama3.3" | "mistral-nemo" | "phi3" | "phi3.5" | "phi4" | "qwen3" | "yi-coder" => 128000,
39 "qwen3-coder" => 256000,
40 _ => DEFAULT_TOKENS,
41 }
42 .clamp(1, MAXIMUM_TOKENS)
43}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tools: Option<bool>,
        supports_vision: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_vision,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        thinking: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
    Tool {
        tool_name: String,
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<u64>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    pub model: String,
    pub created_at: String,
    pub message: ChatMessage,
    pub done_reason: Option<String>,
    pub done: bool,
    pub prompt_eval_count: Option<u64>,
    pub eval_count: Option<u64>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

#[derive(Debug)]
pub struct ModelShow {
    pub capabilities: Vec<String>,
    pub context_length: Option<u64>,
    pub architecture: Option<String>,
}

impl<'de> Deserialize<'de> for ModelShow {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};
        use std::fmt;

        struct ModelShowVisitor;

        impl<'de> Visitor<'de> for ModelShowVisitor {
            type Value = ModelShow;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a ModelShow object")
            }

            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                let mut capabilities: Vec<String> = Vec::new();
                let mut architecture: Option<String> = None;
                let mut context_length: Option<u64> = None;

                while let Some(key) = map.next_key::<String>()? {
                    match key.as_str() {
                        "capabilities" => {
                            capabilities = map.next_value()?;
                        }
                        "model_info" => {
                            let model_info: Value = map.next_value()?;
                            if let Value::Object(obj) = model_info {
                                architecture = obj
                                    .get("general.architecture")
                                    .and_then(|v| v.as_str())
                                    .map(String::from);

                                // The context length key is scoped to the
                                // architecture, e.g. "llama.context_length".
                                if let Some(arch) = &architecture {
                                    context_length = obj
                                        .get(&format!("{}.context_length", arch))
                                        .and_then(|v| v.as_u64());
                                }
                            }
                        }
                        _ => {
                            let _: de::IgnoredAny = map.next_value()?;
                        }
                    }
                }

                Ok(ModelShow {
                    capabilities,
                    context_length,
                    architecture,
                })
            }
        }

        deserializer.deserialize_map(ModelShowVisitor)
    }
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.iter().any(|v| v == "vision")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}
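/// Streams chat completion deltas from Ollama's `/api/chat` endpoint, which
/// responds with one JSON object per line.
///
/// A minimal usage sketch (assumes a running Ollama server and some
/// `client: &dyn HttpClient`; marked `ignore` because it needs a live
/// endpoint):
///
/// ```ignore
/// use futures::StreamExt as _;
///
/// let request = ChatRequest {
///     model: "llama3.2".to_string(),
///     messages: vec![ChatMessage::User {
///         content: "Hello!".to_string(),
///         images: None,
///     }],
///     stream: true,
///     keep_alive: KeepAlive::default(),
///     options: None,
///     tools: vec![],
///     think: None,
/// };
///
/// let mut deltas = stream_chat_completion(client, OLLAMA_API_URL, None, request).await?;
/// while let Some(delta) = deltas.next().await {
///     if delta?.done {
///         break;
///     }
/// }
/// ```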
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(serde_json::to_string(&request)?))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetches the details of a model; used to determine the model's capabilities.
pub async fn show_model(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    model: &str,
) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));

        assert_eq!(result.architecture, Some("llama".to_string()));
        assert_eq!(result.context_length, Some(131072));
    }

    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
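
    // A few spot checks of `get_max_tokens` against the table above. Values
    // follow directly from the match arms and the 16k clamp; not exhaustive.
    #[test]
    fn max_tokens_follow_table_and_clamp() {
        // Known model below the ceiling: the table value is used as-is.
        assert_eq!(get_max_tokens("llama3:latest"), 8192);
        // codestral supports 32k but is clamped to MAXIMUM_TOKENS (16384).
        assert_eq!(get_max_tokens("codestral"), 16384);
        // Unknown models fall back to the 4096-token default.
        assert_eq!(get_max_tokens("some-unknown-model"), 4096);
    }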
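
    // Exercises `Model::new`'s display-name fallback: a ":latest" suffix is
    // stripped, while any other tag is kept verbatim.
    #[test]
    fn model_display_name_strips_latest_tag() {
        let model = Model::new("llama3.2:latest", None, None, None, None, None);
        assert_eq!(model.display_name(), "llama3.2");
        // max_tokens falls back to get_max_tokens(), which clamps to 16k.
        assert_eq!(model.max_token_count(), 16384);

        let model = Model::new("llama3.2:3b", None, None, None, None, None);
        assert_eq!(model.display_name(), "llama3.2:3b");
    }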
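
    // Checks the serde `tag = "role"` representation of ChatMessage: each
    // variant serializes with a lowercase "role" field alongside its data.
    // A small sketch covering the Tool and System variants.
    #[test]
    fn serialize_message_roles() {
        let tool_result = ChatMessage::Tool {
            tool_name: "weather".to_string(),
            content: "{\"temperature\": 21}".to_string(),
        };
        let parsed = serde_json::to_value(&tool_result).unwrap();
        assert_eq!(parsed["role"], "tool");
        assert_eq!(parsed["tool_name"], "weather");

        let system = ChatMessage::System {
            content: "You are a helpful assistant.".to_string(),
        };
        let parsed = serde_json::to_value(&system).unwrap();
        assert_eq!(parsed["role"], "system");
    }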
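
    // Documents that ChatOptions carries no `skip_serializing_if` attributes,
    // so unset options are serialized as explicit JSON nulls.
    #[test]
    fn serialize_default_chat_options() {
        let options = serde_json::to_value(ChatOptions::default()).unwrap();
        assert!(options["num_ctx"].is_null());
        assert!(options["temperature"].is_null());
        assert!(options["top_p"].is_null());
    }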
}