1use anyhow::{anyhow, Context, Result};
2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use isahc::config::Configurable;
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use serde_json::{value::RawValue, Value};
8use std::{convert::TryFrom, sync::Arc, time::Duration};
9
/// Default base URL of a locally running Ollama server.
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
11
/// Author of a chat message. Serialized in lowercase
/// ("user" / "assistant" / "system") to match Ollama's wire format.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
19
20impl TryFrom<String> for Role {
21 type Error = anyhow::Error;
22
23 fn try_from(value: String) -> Result<Self> {
24 match value.as_str() {
25 "user" => Ok(Self::User),
26 "assistant" => Ok(Self::Assistant),
27 "system" => Ok(Self::System),
28 _ => Err(anyhow!("invalid role '{value}'")),
29 }
30 }
31}
32
33impl From<Role> for String {
34 fn from(val: Role) -> Self {
35 match val {
36 Role::User => "user".to_owned(),
37 Role::Assistant => "assistant".to_owned(),
38 Role::System => "system".to_owned(),
39 }
40 }
41}
42
/// How long Ollama should keep a model loaded after serving a request.
/// Serialized untagged: either a bare integer (seconds) or a duration string.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}
51
impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    // A negative keep-alive value signals "never unload" to Ollama.
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}
58
impl Default for KeepAlive {
    /// Defaults to keeping the model loaded indefinitely.
    fn default() -> Self {
        Self::indefinite()
    }
}
64
/// An Ollama model together with the settings used when talking to it.
// NOTE(review): `JsonSchema` is derived unconditionally on `KeepAlive` but
// feature-gated here — confirm whether the `schemars` feature gate is intentional.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    // Model identifier as known to Ollama, e.g. "llama3.1:latest".
    pub name: String,
    // Context window size, in tokens.
    pub max_tokens: usize,
    // How long Ollama should keep the model loaded after a request.
    pub keep_alive: Option<KeepAlive>,
}
72
// This could be dynamically retrieved via the API (1 call per model)
// curl -s http://localhost:11434/api/show -d '{"model": "llama3.1:latest"}' | jq '.model_info."llama.context_length"'
/// Returns the context window size (in tokens) for a named Ollama model.
///
/// One specific tag carries its own window; every other name is keyed on the
/// model family — the part before the ':' tag separator.
fn get_max_tokens(name: &str) -> usize {
    if name == "dolphin-llama3:8b-256k" {
        return 262144; // 256K
    }
    // `split` always yields at least one item, so `unwrap` cannot panic here.
    match name.split(':').next().unwrap() {
        "mistral-nemo" => 1024000,                                      // 1M
        "deepseek-coder-v2" => 163840,                                  // 160K
        "llama3.1" | "phi3" | "command-r" | "command-r-plus" => 131072, // 128K
        "codeqwen" => 65536,                                            // 64K
        "mistral" | "mistral-large" | "dolphin-mistral" | "codestral"
        | "mistral-openorca" | "dolphin-mixtral" | "mixstral" | "llava"
        | "qwen" | "qwen2" | "wizardlm2" | "wizard-math" => 32768, // 32K
        "codellama" | "stable-code" | "deepseek-coder" | "starcoder2"
        | "wizardcoder" => 16384, // 16K
        "llama3" | "gemma2" | "gemma" | "codegemma" | "dolphin-llama3"
        | "llava-llama3" | "starcoder" | "openchat" | "aya" => 8192, // 8K
        "llama2" | "yi" | "llama2-chinese" | "vicuna" | "nous-hermes2"
        | "stablelm2" => 4096, // 4K
        // "phi", "orca-mini", "tinyllama", and "granite-code" share the
        // conservative 2K default used for any unrecognized model.
        _ => 2048, // 2K
    }
}
97
98impl Model {
99 pub fn new(name: &str) -> Self {
100 Self {
101 name: name.to_owned(),
102 max_tokens: get_max_tokens(name),
103 keep_alive: Some(KeepAlive::indefinite()),
104 }
105 }
106
107 pub fn id(&self) -> &str {
108 &self.name
109 }
110
111 pub fn display_name(&self) -> &str {
112 &self.name
113 }
114
115 pub fn max_token_count(&self) -> usize {
116 self.max_tokens
117 }
118}
119
/// A single message in a chat exchange, tagged on the wire by its "role" field.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        // Tool invocations requested by the model, if any.
        tool_calls: Option<Vec<OllamaToolCall>>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
}
134
/// A tool invocation emitted by the model; currently only function calls exist.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}
140
/// A function call requested by the model.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    // Name of the function the model wants invoked.
    pub name: String,
    // Arguments kept as raw JSON so the caller can deserialize them
    // against the matching tool's parameter schema.
    pub arguments: Box<RawValue>,
}
146
/// Definition of a function the model may call, advertised in [`ChatRequest::tools`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    // JSON Schema describing the function's parameters — presumably; verify
    // against the Ollama tool-calling docs.
    pub parameters: Option<Value>,
}
153
/// A tool the model may use, tagged on the wire by its "type" field.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}
159
/// Request payload for Ollama's `/api/chat` endpoint.
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    // When true, the response arrives as newline-delimited JSON chunks.
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
}
169
170impl ChatRequest {
171 pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
172 self.stream = false;
173 self.tools = tools;
174 self
175 }
176}
177
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
/// Per-request sampling and runtime options; `None` fields defer to Ollama's defaults.
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    // Context window size (tokens) to use for this request.
    pub num_ctx: Option<usize>,
    // Maximum number of tokens to generate.
    pub num_predict: Option<isize>,
    // Sequences that stop generation when produced.
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}
187
/// One chunk of `/api/chat` output: a message (fragment, when streaming)
/// plus completion metadata.
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    // The message content carried by this chunk.
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
}
200
/// Response body of `/api/tags`: the list of locally available models.
#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}
205
/// One entry in the `/api/tags` model listing.
#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    // Size on disk, in bytes — presumably; confirm against the Ollama API docs.
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}
214
/// Detailed model information, as returned by Ollama's `/api/show` endpoint.
#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}
222
/// Metadata describing a model's file format, family, and quantization.
#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    // e.g. "8B" — parameter count as a display string.
    pub parameter_size: String,
    pub quantization_level: String,
}
231
232pub async fn complete(
233 client: &dyn HttpClient,
234 api_url: &str,
235 request: ChatRequest,
236) -> Result<ChatResponseDelta> {
237 let uri = format!("{api_url}/api/chat");
238 let request_builder = HttpRequest::builder()
239 .method(Method::POST)
240 .uri(uri)
241 .header("Content-Type", "application/json");
242
243 let serialized_request = serde_json::to_string(&request)?;
244 let request = request_builder.body(AsyncBody::from(serialized_request))?;
245
246 let mut response = client.send(request).await?;
247 if response.status().is_success() {
248 let mut body = Vec::new();
249 response.body_mut().read_to_end(&mut body).await?;
250 let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
251 Ok(response_message)
252 } else {
253 let mut body = Vec::new();
254 response.body_mut().read_to_end(&mut body).await?;
255 let body_str = std::str::from_utf8(&body)?;
256 Err(anyhow!(
257 "Failed to connect to API: {} {}",
258 response.status(),
259 body_str
260 ))
261 }
262}
263
/// Starts a streaming chat completion against Ollama's `/api/chat` endpoint.
///
/// On success, returns a stream yielding one parsed [`ChatResponseDelta`] per
/// line of the newline-delimited JSON response body. A non-success HTTP status
/// is reported as an error containing the status and response body.
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
    low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    // Abort the transfer if throughput stays below 100 bytes/sec for the given duration.
    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    };

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        // Each line of the body is an independent JSON-encoded delta, parsed
        // lazily as the stream is consumed. I/O errors are forwarded as `Err` items.
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        Some(serde_json::from_str(&line).context("Unable to parse chat response"))
                    }
                    Err(e) => Some(Err(e.into())),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}
307
308pub async fn get_models(
309 client: &dyn HttpClient,
310 api_url: &str,
311 low_speed_timeout: Option<Duration>,
312) -> Result<Vec<LocalModelListing>> {
313 let uri = format!("{api_url}/api/tags");
314 let mut request_builder = HttpRequest::builder()
315 .method(Method::GET)
316 .uri(uri)
317 .header("Accept", "application/json");
318
319 if let Some(low_speed_timeout) = low_speed_timeout {
320 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
321 };
322
323 let request = request_builder.body(AsyncBody::default())?;
324
325 let mut response = client.send(request).await?;
326
327 let mut body = String::new();
328 response.body_mut().read_to_string(&mut body).await?;
329
330 if response.status().is_success() {
331 let response: LocalModelsResponse =
332 serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
333
334 Ok(response.models)
335 } else {
336 Err(anyhow!(
337 "Failed to connect to Ollama API: {} {}",
338 response.status(),
339 body,
340 ))
341 }
342}
343
344/// Sends an empty request to Ollama to trigger loading the model
345pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
346 let uri = format!("{api_url}/api/generate");
347 let request = HttpRequest::builder()
348 .method(Method::POST)
349 .uri(uri)
350 .header("Content-Type", "application/json")
351 .body(AsyncBody::from(serde_json::to_string(
352 &serde_json::json!({
353 "model": model,
354 "keep_alive": "15m",
355 }),
356 )?))?;
357
358 let mut response = match client.send(request).await {
359 Ok(response) => response,
360 Err(err) => {
361 // Be ok with a timeout during preload of the model
362 if err.is_timeout() {
363 return Ok(());
364 } else {
365 return Err(err.into());
366 }
367 }
368 };
369
370 if response.status().is_success() {
371 Ok(())
372 } else {
373 let mut body = String::new();
374 response.body_mut().read_to_string(&mut body).await?;
375
376 Err(anyhow!(
377 "Failed to connect to Ollama API: {} {}",
378 response.status(),
379 body,
380 ))
381 }
382}