use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{value::RawValue, Value};
use std::{convert::TryFrom, sync::Arc, time::Duration};

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            _ => Err(anyhow!("invalid role '{value}'")),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
        }
    }
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep the model alive for N seconds; a negative value keeps it loaded indefinitely.
    Seconds(isize),
    /// Keep the model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep the model alive until a new model is loaded or until Ollama shuts down.
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}
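
#[cfg(test)]
mod keep_alive_tests {
    use super::*;

    // A minimal sketch of the `#[serde(untagged)]` wire format: the `Seconds`
    // variant serializes as a bare JSON number and `Duration` as a plain
    // string, matching what Ollama expects for `keep_alive`.
    #[test]
    fn serializes_untagged() {
        assert_eq!(serde_json::to_string(&KeepAlive::indefinite()).unwrap(), "-1");
        assert_eq!(
            serde_json::to_string(&KeepAlive::Duration("5m".to_string())).unwrap(),
            "\"5m\""
        );
    }
}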

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub keep_alive: Option<KeepAlive>,
}

fn get_max_tokens(name: &str) -> usize {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: usize = 2048;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    const MAXIMUM_TOKENS: usize = 16384;

    match name.split(':').next().unwrap() {
        "phi" | "tinyllama" | "granite-code" => 2048,
        "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
        "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
        "codellama" | "starcoder2" => 16384,
        "mistral" | "codestral" | "mixtral" | "llava" | "qwen2" | "dolphin-mixtral" => 32768,
        "llama3.1" | "phi3" | "phi3.5" | "command-r" | "deepseek-coder-v2" | "yi-coder"
        | "qwen2.5-coder" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}
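
#[cfg(test)]
mod max_tokens_tests {
    use super::*;

    // Sketch of the clamping behavior above: even models whose advertised
    // context window is 128k are capped at `MAXIMUM_TOKENS`, and any tag
    // after the first ':' is ignored when looking up the base model name.
    #[test]
    fn known_and_unknown_models() {
        assert_eq!(get_max_tokens("llama3:8b"), 8192);
        assert_eq!(get_max_tokens("llama3.1:latest"), 16384); // 128000, clamped
        assert_eq!(get_max_tokens("some-unknown-model"), 2048);
    }
}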

impl Model {
    pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> usize {
        self.max_tokens
    }
}
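
#[cfg(test)]
mod model_tests {
    use super::*;

    // Sketch: `Model::new` derives a display name by stripping the ":latest"
    // tag when no explicit display name is given, and falls back to
    // `get_max_tokens` for the context length.
    #[test]
    fn new_derives_defaults() {
        let model = Model::new("llama3:latest", None, None);
        assert_eq!(model.display_name(), "llama3");
        assert_eq!(model.max_token_count(), 8192);
        assert_eq!(model.keep_alive, Some(KeepAlive::indefinite()));
    }
}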

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
}
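
#[cfg(test)]
mod chat_message_tests {
    use super::*;

    // Sketch of the internally tagged representation: the variant name is
    // emitted as a lowercase "role" field alongside the message content.
    #[test]
    fn serializes_with_role_tag() {
        let message = ChatMessage::User {
            content: "Hello".to_string(),
        };
        assert_eq!(
            serde_json::to_value(&message).unwrap(),
            serde_json::json!({"role": "user", "content": "Hello"})
        );
    }
}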

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Box<RawValue>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}
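
#[cfg(test)]
mod tool_schema_tests {
    use super::*;

    // Sketch of the tool JSON shape sent to Ollama; `get_weather` is a
    // hypothetical tool name used purely for illustration.
    #[test]
    fn serializes_function_tool() {
        let tool = OllamaTool::Function {
            function: OllamaFunctionTool {
                name: "get_weather".to_string(),
                description: Some("Get the current weather".to_string()),
                parameters: None,
            },
        };
        assert_eq!(
            serde_json::to_value(&tool).unwrap(),
            serde_json::json!({
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the current weather",
                    "parameters": null,
                }
            })
        );
    }
}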

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
}

impl ChatRequest {
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        // Ollama only reports tool calls on non-streaming responses, so
        // attaching tools forces `stream: false`.
        self.stream = false;
        self.tools = tools;
        self
    }
}
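
#[cfg(test)]
mod chat_request_tests {
    use super::*;

    // Sketch confirming that attaching tools forces a non-streaming request.
    #[test]
    fn with_tools_disables_streaming() {
        let request = ChatRequest {
            model: "llama3.1".to_string(),
            messages: vec![ChatMessage::User {
                content: "What is the weather in Paris?".to_string(),
            }],
            stream: true,
            keep_alive: KeepAlive::default(),
            options: None,
            tools: Vec::new(),
        };
        assert!(!request.with_tools(Vec::new()).stream);
    }
}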

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<usize>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
}
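
#[cfg(test)]
mod response_parsing_tests {
    use super::*;

    // Sketch parsing one line of the newline-delimited stream; the sample
    // payload shape is an assumption based on Ollama's /api/chat responses.
    #[test]
    fn parses_streaming_delta() {
        let line = r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hi"},"done":false}"#;
        let delta: ChatResponseDelta = serde_json::from_str(line).unwrap();
        assert!(!delta.done);
        match delta.message {
            ChatMessage::Assistant { content, .. } => assert_eq!(content, "Hi"),
            _ => panic!("expected an assistant message"),
        }
    }
}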

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;
    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta =
            serde_json::from_slice(&body).context("Unable to parse chat response")?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        ))
    }
}
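
// A minimal usage sketch for the non-streaming tool-call path, assuming a
// `client: Arc<dyn HttpClient>` is in scope (names here are illustrative):
//
//     let request = ChatRequest { /* ... */ }.with_tools(tools);
//     let delta = complete(client.as_ref(), OLLAMA_API_URL, request).await?;
//     if let ChatMessage::Assistant { tool_calls: Some(calls), .. } = delta.message {
//         for OllamaToolCall::Function(call) in &calls {
//             // `call.arguments` is raw JSON; parsing is left to the caller.
//         }
//     }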

pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
    low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    }

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        Some(serde_json::from_str(&line).context("Unable to parse chat response"))
                    }
                    Err(e) => Some(Err(e.into())),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}
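
// A minimal consumption sketch for the stream above, again assuming a
// `client: Arc<dyn HttpClient>` and an async context (illustrative only):
//
//     let mut stream =
//         stream_chat_completion(client.as_ref(), OLLAMA_API_URL, request, None).await?;
//     while let Some(delta) = stream.next().await {
//         if let ChatMessage::Assistant { content, .. } = delta?.message {
//             print!("{content}");
//         }
//     }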

pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    if response.status().is_success() {
        let response: LocalModelsResponse =
            serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;

        Ok(response.models)
    } else {
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}

/// Sends an empty request to Ollama to trigger loading the model.
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
    let uri = format!("{api_url}/api/generate");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(serde_json::to_string(
            &serde_json::json!({
                "model": model,
                "keep_alive": "15m",
            }),
        )?))?;

    let mut response = client.send(request).await?;

    if response.status().is_success() {
        Ok(())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}
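
// Preloading is typically fired off in the background when the user picks a
// model, e.g. (sketch; `executor`, its `spawn` method, and `detach` are
// assumptions about the caller's async runtime, not part of this crate):
//
//     let client = client.clone();
//     executor
//         .spawn(async move { preload_model(client, OLLAMA_API_URL, "llama3").await })
//         .detach();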