use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub use settings::KeepAlive;
use std::time::Duration;

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}

/// Look up a conservative context length for a model by its base name
/// (the part before any `:tag` suffix).
fn get_max_tokens(name: &str) -> u64 {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: u64 = 4096;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    /// Models that support context beyond 16k, such as codestral (32k) or
    /// devstral (128k), will be clamped down to 16k.
    const MAXIMUM_TOKENS: u64 = 16384;

    // `split` always yields at least one item, so `unwrap` cannot panic.
    match name.split(':').next().unwrap() {
        "granite-code" | "phi" | "tinyllama" => 2048,
        "llama2" | "stablelm2" | "vicuna" | "yi" => 4096,
        "aya" | "codegemma" | "gemma" | "gemma2" | "llama3" | "starcoder" => 8192,
        "codellama" | "starcoder2" => 16384,
        "codestral" | "dolphin-mixtral" | "llava" | "magistral" | "mistral" | "mixtral"
        | "qwen2" | "qwen2.5-coder" => 32768,
        "cogito" | "command-r" | "deepseek-coder-v2" | "deepseek-r1" | "deepseek-v3"
        | "devstral" | "gemma3" | "gpt-oss" | "granite3.3" | "llama3.1" | "llama3.2"
        | "llama3.3" | "mistral-nemo" | "phi3" | "phi3.5" | "phi4" | "qwen3" | "yi-coder" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tools: Option<bool>,
        supports_vision: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_vision,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        thinking: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
    Tool {
        tool_name: String,
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

impl ChatRequest {
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        // Requests with tools disable streaming, so the whole response
        // (including any tool calls) arrives in a single delta.
        self.stream = false;
        self.tools = tools;
        self
    }
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<u64>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
    pub prompt_eval_count: Option<u64>,
    pub eval_count: Option<u64>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

#[derive(Deserialize, Debug)]
pub struct ModelShow {
    #[serde(default)]
    pub capabilities: Vec<String>,
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // `.contains` expects a `&String`, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.iter().any(|v| v == "vision")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}

pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;

    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        );
    }
}

pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let mut request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    if let Some(api_key) = api_key {
        request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
    }

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        // Ollama streams newline-delimited JSON: one `ChatResponseDelta` per line.
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let mut request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    if let Some(api_key) = api_key {
        request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
    }

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetch details of a model, used to determine model capabilities.
pub async fn show_model(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    model: &str,
) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    if let Some(api_key) = api_key {
        request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
    }

    let request = request_builder.body(AsyncBody::from(
        serde_json::json!({ "model": model }).to_string(),
    ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }
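
    // Minimal sanity check for `Model::new`'s naming rules (grounded in the
    // impl above, not in any Ollama API shape): an explicit display name wins,
    // and a ":latest" suffix is stripped when inferring one from the model name.
    #[test]
    fn model_display_name_fallback() {
        let named = Model::new("llama3.2:latest", Some("Llama 3.2"), None, None, None, None);
        assert_eq!(named.display_name(), "Llama 3.2");

        let inferred = Model::new("llama3.2:latest", None, None, None, None, None);
        assert_eq!(inferred.display_name(), "llama3.2");

        let plain = Model::new("llama3.2", None, None, None, None, None);
        assert_eq!(plain.display_name(), "llama3.2");
    }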

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));
    }
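
    // Exercises `get_max_tokens` directly: tags after ':' are ignored, table
    // entries above MAXIMUM_TOKENS are clamped to 16384, and unknown models
    // fall back to the 4096-token default.
    #[test]
    fn max_tokens_lookup_and_clamping() {
        assert_eq!(get_max_tokens("phi:latest"), 2048);
        assert_eq!(get_max_tokens("llama2:7b"), 4096);
        // codestral's table entry is 32768, clamped down to MAXIMUM_TOKENS.
        assert_eq!(get_max_tokens("codestral"), 16384);
        assert_eq!(get_max_tokens("llama3.3"), 16384);
        assert_eq!(get_max_tokens("definitely-not-a-real-model"), 4096);
    }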

    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }
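
    // Checks the wire shape of a tool definition: `#[serde(tag = "type")]`
    // produces the `"type": "function"` tag alongside the nested `function`
    // payload. The tool name and JSON schema here are made up for the test.
    #[test]
    fn serialize_tool_definition() {
        let tool = OllamaTool::Function {
            function: OllamaFunctionTool {
                name: "weather".to_string(),
                description: Some("Get the current weather for a city".to_string()),
                parameters: Some(serde_json::json!({
                    "type": "object",
                    "properties": { "city": { "type": "string" } },
                    "required": ["city"]
                })),
            },
        };

        let serialized = serde_json::to_value(&tool).unwrap();
        assert_eq!(serialized["type"], "function");
        assert_eq!(serialized["function"]["name"], "weather");
    }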

    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
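
    // `capabilities` is `#[serde(default)]`, so an /api/show response that
    // omits the field (as older Ollama servers may) should still deserialize,
    // with every capability probe returning false.
    #[test]
    fn parse_show_model_without_capabilities() {
        let response = serde_json::json!({
            "details": {
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            }
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.capabilities.is_empty());
        assert!(!result.supports_tools());
        assert!(!result.supports_vision());
        assert!(!result.supports_thinking());
    }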
}