use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub use settings::KeepAlive;
use std::time::Duration;

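/// Default base URL for a local Ollama server.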
pub const OLLAMA_API_URL: &str = "http://localhost:11434";

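/// An Ollama model as configured in settings, together with the capabilities
/// it advertises.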
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
    pub supports_vision: Option<bool>,
    pub supports_thinking: Option<bool>,
}

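/// Look up a conservative context length for `name`, keyed on the base model
/// name (everything before the first `:`), clamped to `MAXIMUM_TOKENS`.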
fn get_max_tokens(name: &str) -> u64 {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: u64 = 4096;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    /// Models that support contexts beyond 16k, such as codestral (32k) or
    /// devstral (128k), are clamped down to 16k.
    const MAXIMUM_TOKENS: u64 = 16384;

    match name.split(':').next().unwrap() {
        "granite-code" | "phi" | "tinyllama" => 2048,
        "llama2" | "stablelm2" | "vicuna" | "yi" => 4096,
        "aya" | "codegemma" | "gemma" | "gemma2" | "llama3" | "starcoder" => 8192,
        "codellama" | "starcoder2" => 16384,
        "codestral" | "dolphin-mixtral" | "llava" | "magistral" | "mistral" | "mixtral"
        | "qwen2" | "qwen2.5-coder" => 32768,
        "cogito" | "command-r" | "deepseek-coder-v2" | "deepseek-r1" | "deepseek-v3"
        | "devstral" | "gemma3" | "gpt-oss" | "granite3.3" | "llama3.1" | "llama3.2"
        | "llama3.3" | "mistral-nemo" | "phi3" | "phi3.5" | "phi4" | "qwen3" | "yi-coder" => 128000,
        "qwen3-coder" => 256000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tools: Option<bool>,
        supports_vision: Option<bool>,
        supports_thinking: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
            supports_vision,
            supports_thinking,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }
}

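/// A chat message, tagged by `role` to match the wire format of Ollama's chat API.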
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        thinking: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    System {
        content: String,
    },
    Tool {
        tool_name: String,
        content: String,
    },
}

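/// A tool invocation requested by the model in an assistant message.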
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

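/// Request body for Ollama's `/api/chat` endpoint.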
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    pub think: Option<bool>,
}

impl ChatRequest {
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
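        // Attaching tools forces the request to be non-streaming; not every
        // Ollama version can stream responses that contain tool calls.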
        self.stream = false;
        self.tools = tools;
        self
    }
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<u64>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

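/// One chunk of a chat response. When streaming, Ollama emits one of these per
/// line, and the final chunk has `done: true` along with the token counts.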
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
    pub prompt_eval_count: Option<u64>,
    pub eval_count: Option<u64>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

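/// Response to a `/api/show` query; only the capability list is retained.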
#[derive(Deserialize, Debug)]
pub struct ModelShow {
    #[serde(default)]
    pub capabilities: Vec<String>,
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.iter().any(|v| v == "vision")
    }

    pub fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|v| v == "thinking")
    }
}

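/// Send a non-streaming request to `{api_url}/api/chat` and return the full
/// response in one piece.
///
/// A minimal sketch of a call site (the `client` and model name here are
/// hypothetical):
///
/// ```ignore
/// let request = ChatRequest {
///     model: "llama3.2".to_string(),
///     messages: vec![ChatMessage::User {
///         content: "Hello!".to_string(),
///         images: None,
///     }],
///     stream: false,
///     keep_alive: KeepAlive::default(),
///     options: None,
///     tools: vec![],
///     think: None,
/// };
/// let response = complete(client, OLLAMA_API_URL, request).await?;
/// ```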
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;

    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        );
    }
}

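/// Send a streaming request to `{api_url}/api/chat`. Ollama streams
/// newline-delimited JSON, so each line of the body parses as one
/// [`ChatResponseDelta`].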
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(serde_json::to_string(&request)?))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}

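/// List the models available locally, via `{api_url}/api/tags`.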
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let response: LocalModelsResponse =
        serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
    Ok(response.models)
}

/// Fetch details of a model, used to determine model capabilities.
pub async fn show_model(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    model: &str,
) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to Ollama API: {} {}",
        response.status(),
        body,
    );
    let details: ModelShow = serde_json::from_str(body.as_str())?;
    Ok(details)
}

#[cfg(test)]
mod tests {
    use super::*;

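    // Sanity checks for `get_max_tokens`, added as an illustration: known
    // prefixes map to their table entries, contexts larger than MAXIMUM_TOKENS
    // are clamped, and unknown names fall back to DEFAULT_TOKENS.
    #[test]
    fn get_max_tokens_clamps_and_defaults() {
        assert_eq!(get_max_tokens("phi"), 2048);
        assert_eq!(get_max_tokens("llama3.1:8b"), 16384);
        assert_eq!(get_max_tokens("some-unknown-model"), 4096);
    }
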
    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));
    }

    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    #[test]
    fn json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }
}