use anyhow::{Context as _, Result, anyhow};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{sync::Arc, time::Duration};

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

/// The `keep_alive` value sent to Ollama. Serialized untagged, so it appears
/// on the wire as either a bare number of seconds or a duration string.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds (negative values keep it loaded indefinitely)
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub keep_alive: Option<KeepAlive>,
    pub supports_tools: Option<bool>,
}

fn get_max_tokens(name: &str) -> usize {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: usize = 2048;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    const MAXIMUM_TOKENS: usize = 16384;

    // Only the family prefix (everything before ':') determines the default;
    // `split` always yields at least one item, so `unwrap` cannot panic.
    match name.split(':').next().unwrap() {
        "phi" | "tinyllama" | "granite-code" => 2048,
        "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
        "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
        "codellama" | "starcoder2" => 16384,
        "mistral" | "codestral" | "mixtral" | "llava" | "qwen2" | "qwen2.5-coder"
        | "dolphin-mixtral" => 32768,
        "llama3.1" | "llama3.2" | "llama3.3" | "phi3" | "phi3.5" | "phi4" | "command-r"
        | "qwen3" | "gemma3" | "deepseek-coder-v2" | "deepseek-v3" | "deepseek-r1" | "yi-coder" => {
            128000
        }
        _ => DEFAULT_TOKENS,
    }
    // Defaults above MAXIMUM_TOKENS are clamped down to keep memory use in check.
    .clamp(1, MAXIMUM_TOKENS)
}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<usize>,
        supports_tools: Option<bool>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
            supports_tools,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> usize {
        self.max_tokens
    }
}

/// A message in an Ollama chat exchange. Serialized with the role inlined,
/// e.g. `{"role": "user", "content": "..."}`.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
}

impl ChatRequest {
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        // Tool calls are parsed from a single complete response (see the
        // `parse_tool_call` test below), so streaming is disabled whenever
        // tools are supplied.
        self.stream = false;
        self.tools = tools;
        self
    }
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<usize>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

#[derive(Deserialize, Debug)]
pub struct ModelShow {
    #[serde(default)]
    pub capabilities: Vec<String>,
}

impl ModelShow {
    pub fn supports_tools(&self) -> bool {
        // .contains expects &String, which would require an additional allocation
        self.capabilities.iter().any(|v| v == "tools")
    }
}

pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;

    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        ))
    }
}

pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        // The streaming response body is newline-delimited JSON; parse each
        // line as one delta.
        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}

pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    if response.status().is_success() {
        let response: LocalModelsResponse =
            serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;

        Ok(response.models)
    } else {
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}

/// Fetch details of a model, used to determine model capabilities
pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<ModelShow> {
    let uri = format!("{api_url}/api/show");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    if response.status().is_success() {
        let details: ModelShow = serde_json::from_str(body.as_str())?;
        Ok(details)
    } else {
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}

/// Sends a prompt-less request to Ollama to trigger loading the model into memory
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
    let uri = format!("{api_url}/api/generate");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(
            serde_json::json!({
                "model": model,
                "keep_alive": "15m",
            })
            .to_string(),
        ))?;

    let mut response = client.send(request).await?;

    if response.status().is_success() {
        Ok(())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));
    }
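
    // Sanity check: `KeepAlive` is `#[serde(untagged)]`, so it should
    // serialize to a bare JSON value (seconds as a number, or a duration
    // string), which is the shape Ollama accepts for `keep_alive`.
    #[test]
    fn serialize_keep_alive() {
        assert_eq!(
            serde_json::to_value(KeepAlive::indefinite()).unwrap(),
            serde_json::json!(-1)
        );
        assert_eq!(
            serde_json::to_value(KeepAlive::Duration("5m".to_string())).unwrap(),
            serde_json::json!("5m")
        );
    }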
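
    // Sanity check: context-window defaults are keyed on the model family
    // (the name before ':') and clamped to MAXIMUM_TOKENS; unknown families
    // fall back to DEFAULT_TOKENS.
    #[test]
    fn default_max_tokens() {
        assert_eq!(get_max_tokens("llama2:7b"), 4096);
        // llama3.1's 128000 default is clamped down to MAXIMUM_TOKENS (16384).
        assert_eq!(get_max_tokens("llama3.1:latest"), 16384);
        assert_eq!(get_max_tokens("some-unknown-model:latest"), 2048);
    }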
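
    // Sanity check: messages serialize with the role inlined via
    // `#[serde(tag = "role")]`, and `with_tools` forces `stream: false`.
    #[test]
    fn serialize_chat_request_with_tools() {
        let message = serde_json::to_value(ChatMessage::User {
            content: "Hello!".to_string(),
        })
        .unwrap();
        assert_eq!(
            message,
            serde_json::json!({ "role": "user", "content": "Hello!" })
        );

        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: Vec::new(),
            stream: true,
            keep_alive: KeepAlive::default(),
            options: None,
            tools: Vec::new(),
        }
        .with_tools(vec![OllamaTool::Function {
            function: OllamaFunctionTool {
                name: "weather".to_string(),
                description: None,
                parameters: None,
            },
        }]);
        assert!(!request.stream);
        assert_eq!(request.tools.len(), 1);
    }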
}