1use anyhow::{anyhow, Context, Result};
2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use isahc::config::Configurable;
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::{convert::TryFrom, sync::Arc, time::Duration};
9
/// Default base URL of a locally running Ollama server.
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
11
/// Author of a chat message; serialized lowercase ("user", "assistant", "system").
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
19
20impl TryFrom<String> for Role {
21 type Error = anyhow::Error;
22
23 fn try_from(value: String) -> Result<Self> {
24 match value.as_str() {
25 "user" => Ok(Self::User),
26 "assistant" => Ok(Self::Assistant),
27 "system" => Ok(Self::System),
28 _ => Err(anyhow!("invalid role '{value}'")),
29 }
30 }
31}
32
33impl From<Role> for String {
34 fn from(val: Role) -> Self {
35 match val {
36 Role::User => "user".to_owned(),
37 Role::Assistant => "assistant".to_owned(),
38 Role::System => "system".to_owned(),
39 }
40 }
41}
42
/// Controls how long Ollama keeps a model loaded after serving a request.
///
/// Serialized untagged: emitted either as a bare integer (seconds) or a
/// duration string.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}
51
52impl KeepAlive {
53 /// Keep model alive until a new model is loaded or until Ollama shuts down
54 fn indefinite() -> Self {
55 Self::Seconds(-1)
56 }
57}
58
59impl Default for KeepAlive {
60 fn default() -> Self {
61 Self::indefinite()
62 }
63}
64
/// A user-configurable Ollama model entry.
// NOTE(review): `JsonSchema` is feature-gated here but derived unconditionally
// on `KeepAlive` (and imported unconditionally at the top of the file) —
// confirm the `schemars` feature wiring is intentional.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    // Model name as Ollama knows it, e.g. "llama3:latest".
    pub name: String,
    // Context-window size used for token budgeting.
    pub max_tokens: usize,
    // How long the server keeps the model loaded; `None` defers to Ollama's default.
    pub keep_alive: Option<KeepAlive>,
}
72
73impl Model {
74 pub fn new(name: &str) -> Self {
75 Self {
76 name: name.to_owned(),
77 max_tokens: 2048,
78 keep_alive: Some(KeepAlive::indefinite()),
79 }
80 }
81
82 pub fn id(&self) -> &str {
83 &self.name
84 }
85
86 pub fn display_name(&self) -> &str {
87 &self.name
88 }
89
90 pub fn max_token_count(&self) -> usize {
91 self.max_tokens
92 }
93}
94
/// One chat message, discriminated by the `"role"` field on the wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        // Present when the model responded with tool invocations.
        tool_calls: Option<Vec<OllamaToolCall>>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
}
109
/// A tool invocation requested by the model; currently only function calls.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}
115
/// The name and JSON arguments of a function the model wants invoked.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Value,
}
121
/// Definition of a function tool offered to the model.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    // JSON schema describing the function's parameters, if any.
    pub parameters: Option<Value>,
}
128
/// A tool made available to the model, discriminated by `"type"` on the wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}
134
/// Request body for Ollama's `/api/chat` endpoint.
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    // When true, the server streams newline-delimited JSON response chunks.
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
}
144
145impl ChatRequest {
146 pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
147 self.stream = false;
148 self.tools = tools;
149 self
150 }
151}
152
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
/// Sampling and generation options forwarded with a chat request.
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    // Context window size in tokens.
    pub num_ctx: Option<usize>,
    // Maximum number of tokens to predict.
    pub num_predict: Option<isize>,
    // Sequences that stop generation when produced.
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}
162
/// One chunk of a chat response. When streaming, one chunk arrives per
/// newline-delimited JSON line; the final chunk has `done == true`.
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    // The (partial) message payload for this chunk.
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
}
175
/// Response body of `/api/tags`: the locally installed models.
#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}
180
/// Summary of one locally installed model, as listed by `/api/tags`.
#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    // Size of the model — presumably bytes on disk; confirm against Ollama docs.
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}
189
/// Full description of one local model.
// NOTE(review): not constructed anywhere in this file — presumably maps to the
// `/api/show` response; confirm against callers.
#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}
197
/// Metadata about a model's storage format, family, and quantization.
#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}
206
207pub async fn complete(
208 client: &dyn HttpClient,
209 api_url: &str,
210 request: ChatRequest,
211) -> Result<ChatResponseDelta> {
212 let uri = format!("{api_url}/api/chat");
213 let request_builder = HttpRequest::builder()
214 .method(Method::POST)
215 .uri(uri)
216 .header("Content-Type", "application/json");
217
218 let serialized_request = serde_json::to_string(&request)?;
219 let request = request_builder.body(AsyncBody::from(serialized_request))?;
220
221 let mut response = client.send(request).await?;
222 if response.status().is_success() {
223 let mut body = Vec::new();
224 response.body_mut().read_to_end(&mut body).await?;
225 let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
226 Ok(response_message)
227 } else {
228 let mut body = Vec::new();
229 response.body_mut().read_to_end(&mut body).await?;
230 let body_str = std::str::from_utf8(&body)?;
231 Err(anyhow!(
232 "Failed to connect to API: {} {}",
233 response.status(),
234 body_str
235 ))
236 }
237}
238
239pub async fn stream_chat_completion(
240 client: &dyn HttpClient,
241 api_url: &str,
242 request: ChatRequest,
243 low_speed_timeout: Option<Duration>,
244) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
245 let uri = format!("{api_url}/api/chat");
246 let mut request_builder = HttpRequest::builder()
247 .method(Method::POST)
248 .uri(uri)
249 .header("Content-Type", "application/json");
250
251 if let Some(low_speed_timeout) = low_speed_timeout {
252 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
253 };
254
255 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
256 let mut response = client.send(request).await?;
257 if response.status().is_success() {
258 let reader = BufReader::new(response.into_body());
259
260 Ok(reader
261 .lines()
262 .filter_map(|line| async move {
263 match line {
264 Ok(line) => {
265 Some(serde_json::from_str(&line).context("Unable to parse chat response"))
266 }
267 Err(e) => Some(Err(e.into())),
268 }
269 })
270 .boxed())
271 } else {
272 let mut body = String::new();
273 response.body_mut().read_to_string(&mut body).await?;
274
275 Err(anyhow!(
276 "Failed to connect to Ollama API: {} {}",
277 response.status(),
278 body,
279 ))
280 }
281}
282
283pub async fn get_models(
284 client: &dyn HttpClient,
285 api_url: &str,
286 low_speed_timeout: Option<Duration>,
287) -> Result<Vec<LocalModelListing>> {
288 let uri = format!("{api_url}/api/tags");
289 let mut request_builder = HttpRequest::builder()
290 .method(Method::GET)
291 .uri(uri)
292 .header("Accept", "application/json");
293
294 if let Some(low_speed_timeout) = low_speed_timeout {
295 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
296 };
297
298 let request = request_builder.body(AsyncBody::default())?;
299
300 let mut response = client.send(request).await?;
301
302 let mut body = String::new();
303 response.body_mut().read_to_string(&mut body).await?;
304
305 if response.status().is_success() {
306 let response: LocalModelsResponse =
307 serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
308
309 Ok(response.models)
310 } else {
311 Err(anyhow!(
312 "Failed to connect to Ollama API: {} {}",
313 response.status(),
314 body,
315 ))
316 }
317}
318
319/// Sends an empty request to Ollama to trigger loading the model
320pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
321 let uri = format!("{api_url}/api/generate");
322 let request = HttpRequest::builder()
323 .method(Method::POST)
324 .uri(uri)
325 .header("Content-Type", "application/json")
326 .body(AsyncBody::from(serde_json::to_string(
327 &serde_json::json!({
328 "model": model,
329 "keep_alive": "15m",
330 }),
331 )?))?;
332
333 let mut response = match client.send(request).await {
334 Ok(response) => response,
335 Err(err) => {
336 // Be ok with a timeout during preload of the model
337 if err.is_timeout() {
338 return Ok(());
339 } else {
340 return Err(err.into());
341 }
342 }
343 };
344
345 if response.status().is_success() {
346 Ok(())
347 } else {
348 let mut body = String::new();
349 response.body_mut().read_to_string(&mut body).await?;
350
351 Err(anyhow!(
352 "Failed to connect to Ollama API: {} {}",
353 response.status(),
354 body,
355 ))
356 }
357}