use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use http_client::{http, AsyncBody, HttpClient, Method, Request as HttpRequest};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{value::RawValue, Value};
use std::{convert::TryFrom, sync::Arc, time::Duration};

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            _ => Err(anyhow!("invalid role '{value}'")),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
        }
    }
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}
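
// Illustrative sketch (not in the original crate): because `KeepAlive` is
// `#[serde(untagged)]`, it serializes to either a bare integer or a plain
// string, which is the shape Ollama's `keep_alive` field expects.
#[cfg(test)]
mod keep_alive_tests {
    use super::*;

    #[test]
    fn serializes_untagged() {
        // -1 keeps the model loaded indefinitely.
        assert_eq!(serde_json::to_string(&KeepAlive::indefinite()).unwrap(), "-1");
        // Duration strings pass through verbatim.
        assert_eq!(
            serde_json::to_string(&KeepAlive::Duration("5m".to_string())).unwrap(),
            "\"5m\""
        );
    }
}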

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub keep_alive: Option<KeepAlive>,
}

fn get_max_tokens(name: &str) -> usize {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: usize = 2048;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    const MAXIMUM_TOKENS: usize = 16384;

    match name.split(':').next().unwrap() {
        "phi" | "tinyllama" | "granite-code" => 2048,
        "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
        "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
        "codellama" | "starcoder2" => 16384,
        "mistral" | "codestral" | "mixtral" | "llava" | "qwen2" | "dolphin-mixtral" => 32768,
        "llama3.1" | "phi3" | "phi3.5" | "command-r" | "deepseek-coder-v2" | "yi-coder" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}
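
// Illustrative sketch (not in the original crate): context lengths above
// MAXIMUM_TOKENS are clamped, so even models with very large native context
// windows are capped at 16384 tokens here.
#[cfg(test)]
mod max_tokens_tests {
    use super::*;

    #[test]
    fn clamps_large_contexts() {
        // llama3.1 nominally supports 128k tokens, but the table is clamped.
        assert_eq!(get_max_tokens("llama3.1:8b"), 16384);
        // Unknown models fall back to the conservative default.
        assert_eq!(get_max_tokens("some-unknown-model"), 2048);
    }
}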

impl Model {
    pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> usize {
        self.max_tokens
    }
}
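
// Illustrative sketch (not in the original crate): `Model::new` strips the
// implicit `:latest` tag when deriving a display name, and falls back to the
// context-length table when no explicit max_tokens is given.
#[cfg(test)]
mod model_tests {
    use super::*;

    #[test]
    fn derives_display_name_and_max_tokens() {
        let model = Model::new("llama3:latest", None, None);
        assert_eq!(model.display_name(), "llama3");
        assert_eq!(model.max_token_count(), 8192);

        // An explicit display name takes precedence.
        let model = Model::new("llama3:8b", Some("Llama 3 8B"), None);
        assert_eq!(model.display_name(), "Llama 3 8B");
    }
}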

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Box<RawValue>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
}

impl ChatRequest {
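    /// Attaches tools to the request. Note: streaming is disabled here
    /// because, at the time of writing, Ollama's chat API does not stream
    /// responses for requests that include tools.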
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        self.stream = false;
        self.tools = tools;
        self
    }
}
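
// Illustrative sketch (not in the original crate): `ChatMessage` is internally
// tagged on `role`, and `with_tools` flips `stream` off, so a serialized
// request matches the wire shape expected by Ollama's /api/chat endpoint.
#[cfg(test)]
mod chat_request_tests {
    use super::*;

    #[test]
    fn serializes_chat_request() {
        let request = ChatRequest {
            model: "llama3".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello!".to_string(),
            }],
            stream: true,
            keep_alive: KeepAlive::default(),
            options: None,
            tools: Vec::new(),
        }
        .with_tools(Vec::new());

        // `with_tools` always disables streaming.
        assert!(!request.stream);

        let json: Value = serde_json::to_value(&request).unwrap();
        assert_eq!(json["messages"][0]["role"], "user");
        assert_eq!(json["keep_alive"], -1);
    }
}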

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<usize>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}
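
// Illustrative sketch (not in the original crate): an abbreviated /api/tags
// payload, shaped after the struct fields above, deserializes into
// `LocalModelsResponse`. The concrete values here are made up.
#[cfg(test)]
mod tag_listing_tests {
    use super::*;

    #[test]
    fn parses_model_listing() {
        let body = r#"{
            "models": [{
                "name": "llama3:latest",
                "modified_at": "2024-05-01T00:00:00Z",
                "size": 4661224676,
                "digest": "sha256:abc123",
                "details": {
                    "format": "gguf",
                    "family": "llama",
                    "families": ["llama"],
                    "parameter_size": "8B",
                    "quantization_level": "Q4_0"
                }
            }]
        }"#;

        let response: LocalModelsResponse = serde_json::from_str(body).unwrap();
        assert_eq!(response.models.len(), 1);
        assert_eq!(response.models[0].name, "llama3:latest");
    }
}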

pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        ))
    }
}

pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
    _: Option<Duration>,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        Some(serde_json::from_str(&line).context("Unable to parse chat response"))
                    }
                    Err(e) => Some(Err(e.into())),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}
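
// Illustrative sketch (not in the original crate): the streaming endpoint
// returns newline-delimited JSON, one `ChatResponseDelta` per line. This test
// feeds a canned line through the same BufReader + lines() pipeline used
// above. It assumes the `futures` crate's default `executor` feature.
#[cfg(test)]
mod stream_tests {
    use super::*;
    use futures::io::Cursor;

    #[test]
    fn parses_newline_delimited_deltas() {
        let body = concat!(
            r#"{"model":"llama3","created_at":"2024-05-01T00:00:00Z","#,
            r#""message":{"role":"assistant","content":"Hi","tool_calls":null},"#,
            r#""done_reason":"stop","done":true}"#,
            "\n",
        );

        futures::executor::block_on(async {
            let reader = BufReader::new(Cursor::new(body.as_bytes()));
            let mut lines = reader.lines();
            let line = lines.next().await.unwrap().unwrap();
            let delta: ChatResponseDelta = serde_json::from_str(&line).unwrap();
            assert!(delta.done);
            assert_eq!(delta.done_reason.as_deref(), Some("stop"));
        });
    }
}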

pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    if response.status().is_success() {
        let response: LocalModelsResponse =
            serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;

        Ok(response.models)
    } else {
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}

/// Sends an empty request to Ollama to trigger loading the model
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
    let uri = format!("{api_url}/api/generate");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(serde_json::to_string(
            &serde_json::json!({
                "model": model,
                "keep_alive": "15m",
            }),
        )?))?;

    let mut response = client.send(request).await?;

    if response.status().is_success() {
        Ok(())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}