use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{value::RawValue, Value};
use std::{convert::TryFrom, sync::Arc, time::Duration};

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            _ => Err(anyhow!("invalid role '{value}'")),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
        }
    }
}

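/// With `#[serde(untagged)]`, `KeepAlive::Seconds(-1)` serializes to the bare
/// number `-1` and `KeepAlive::Duration("5m".into())` to the string `"5m"`,
/// matching the two forms Ollama accepts for `keep_alive`.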
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub keep_alive: Option<KeepAlive>,
}

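/// Infers a default context length from the model name, ignoring any tag
/// suffix, so e.g. `get_max_tokens("llama3:latest")` yields 8192. Results are
/// clamped to `MAXIMUM_TOKENS` regardless of what the model itself supports.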
fn get_max_tokens(name: &str) -> usize {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: usize = 2048;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    const MAXIMUM_TOKENS: usize = 16384;

    match name.split(':').next().unwrap() {
        "phi" | "tinyllama" | "granite-code" => 2048,
        "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
        "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
        "codellama" | "starcoder2" => 16384,
        "mistral" | "codestral" | "mixtral" | "llava" | "qwen2" | "dolphin-mixtral" => 32768,
        "llama3.1" | "phi3" | "phi3.5" | "command-r" | "deepseek-coder-v2" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}

impl Model {
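    /// Builds a `Model`. When no display name is given and the model name ends
    /// in `:latest`, the suffix is stripped to form one; the context length
    /// falls back to `get_max_tokens(name)` when unspecified.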
    pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> usize {
        self.max_tokens
    }
}

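/// With `#[serde(tag = "role")]`, each variant carries its role inline, so
/// `ChatMessage::User { content }` round-trips as
/// `{"role":"user","content":"..."}`.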
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Box<RawValue>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
}

impl ChatRequest {
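    /// Attaches tools to the request. Note that this also disables streaming,
    /// since (at the time of writing) Ollama only returns tool calls on
    /// non-streaming responses.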
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        self.stream = false;
        self.tools = tools;
        self
    }
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<usize>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

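/// Sends a non-streaming chat request and returns the complete response.
///
/// A minimal usage sketch (any `HttpClient` implementation works):
///
/// ```ignore
/// let request = ChatRequest {
///     model: "llama3".into(),
///     messages: vec![ChatMessage::User { content: "Hello!".into() }],
///     stream: false,
///     keep_alive: KeepAlive::default(),
///     options: None,
///     tools: Vec::new(),
/// };
/// let response = complete(client.as_ref(), OLLAMA_API_URL, request).await?;
/// println!("{:?}", response.message);
/// ```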
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        ))
    }
}

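/// Sends a streaming chat request. Ollama streams `/api/chat` responses as
/// newline-delimited JSON, so each line of the body parses into one
/// `ChatResponseDelta`; the final delta has `done` set to `true`.
///
/// A minimal consumption sketch:
///
/// ```ignore
/// let mut deltas =
///     stream_chat_completion(client.as_ref(), OLLAMA_API_URL, request, None).await?;
/// while let Some(delta) = deltas.next().await {
///     let delta = delta?;
///     // delta.message holds the next chunk of the reply.
/// }
/// ```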
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
    low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    }

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        Some(serde_json::from_str(&line).context("Unable to parse chat response"))
                    }
                    Err(e) => Some(Err(e.into())),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}

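/// Lists the locally available models via Ollama's `/api/tags` endpoint.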
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    low_speed_timeout: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let mut request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    }

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    if response.status().is_success() {
        let response: LocalModelsResponse =
            serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;

        Ok(response.models)
    } else {
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}

/// Sends an empty request to Ollama to trigger loading the model
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
    let uri = format!("{api_url}/api/generate");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(serde_json::to_string(
            &serde_json::json!({
                "model": model,
                "keep_alive": "15m",
            }),
        )?))?;

    let mut response = match client.send(request).await {
        Ok(response) => response,
        Err(err) => {
            // A timeout during preload is fine; Ollama continues loading the
            // model in the background.
            if err.is_timeout() {
                return Ok(());
            } else {
                return Err(err.into());
            }
        }
    };

    if response.status().is_success() {
        Ok(())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}