use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{convert::TryFrom, time::Duration};

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            _ => Err(anyhow!("invalid role '{value}'")),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
        }
    }
}
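
// A quick sketch of a check that the `Role` string conversions above
// round-trip; the test module name is ours, not part of the upstream API.
#[cfg(test)]
mod role_conversion_tests {
    use super::Role;

    #[test]
    fn role_string_round_trip() {
        for role in [Role::User, Role::Assistant, Role::System] {
            let s: String = role.into();
            assert_eq!(Role::try_from(s).unwrap(), role);
        }
    }
}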

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}
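
// A minimal sketch showing how `#[serde(untagged)]` flattens `KeepAlive`:
// `Seconds(-1)` serializes as the bare JSON number -1, and `Duration("10m")`
// as the JSON string "10m" — the two forms Ollama accepts for `keep_alive`.
#[cfg(test)]
mod keep_alive_serde_tests {
    use super::KeepAlive;

    #[test]
    fn keep_alive_serializes_untagged() {
        assert_eq!(
            serde_json::to_string(&KeepAlive::indefinite()).unwrap(),
            "-1"
        );
        assert_eq!(
            serde_json::to_string(&KeepAlive::Duration("10m".to_owned())).unwrap(),
            "\"10m\""
        );
    }
}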

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct Model {
    pub name: String,
    pub max_tokens: usize,
    pub keep_alive: Option<KeepAlive>,
}

impl Model {
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_owned(),
            max_tokens: 2048,
            keep_alive: Some(KeepAlive::indefinite()),
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        &self.name
    }

    pub fn max_token_count(&self) -> usize {
        self.max_tokens
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant { content: String },
    User { content: String },
    System { content: String },
}

#[derive(Serialize)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default)]
pub struct ChatOptions {
    pub num_ctx: Option<usize>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}
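
// A minimal sketch of the request body this crate sends to Ollama's
// /api/chat endpoint; the model name and option values are illustrative,
// not defaults.
#[cfg(test)]
mod chat_request_serde_tests {
    use super::*;

    #[test]
    fn chat_request_serializes_as_expected() {
        let request = ChatRequest {
            model: "llama3".to_owned(),
            messages: vec![ChatMessage::User {
                content: "Hello!".to_owned(),
            }],
            stream: true,
            keep_alive: KeepAlive::default(),
            options: Some(ChatOptions {
                temperature: Some(1.0),
                ..Default::default()
            }),
        };
        let json = serde_json::to_value(&request).unwrap();
        // `ChatMessage` is internally tagged on "role".
        assert_eq!(json["messages"][0]["role"], "user");
        assert_eq!(json["messages"][0]["content"], "Hello!");
        // The default keep-alive is indefinite, i.e. the number -1.
        assert_eq!(json["keep_alive"], -1);
    }
}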

/// One newline-delimited JSON chunk of a streaming `/api/chat` response.
#[derive(Deserialize)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
}
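
// A sketch of parsing one streamed line; the payload below is modeled on
// Ollama's documented streaming format, not captured from a live server.
#[cfg(test)]
mod chat_response_delta_tests {
    use super::*;

    #[test]
    fn parses_streamed_delta_line() {
        let line = r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hi"},"done":false}"#;
        let delta: ChatResponseDelta = serde_json::from_str(line).unwrap();
        assert_eq!(
            delta.message,
            ChatMessage::Assistant {
                content: "Hi".to_owned()
            }
        );
        // `done_reason` is absent until the final chunk, so it defaults to None.
        assert!(!delta.done);
    }
}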

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

/// Streams a chat completion from Ollama's `/api/chat` endpoint, yielding one
/// `ChatResponseDelta` per newline-delimited JSON chunk.
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
    low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    }

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        Some(serde_json::from_str(&line).context("Unable to parse chat response"))
                    }
                    Err(e) => Some(Err(e.into())),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}
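
// A minimal usage sketch: drain the stream into a single reply string. The
// function name is ours for illustration; `client` is whatever `HttpClient`
// implementation the caller supplies.
#[allow(dead_code)]
async fn example_stream_chat(client: &dyn HttpClient, request: ChatRequest) -> Result<String> {
    let mut stream = stream_chat_completion(client, OLLAMA_API_URL, request, None).await?;
    let mut reply = String::new();
    while let Some(delta) = stream.next().await {
        // Only assistant chunks carry generated text we want to accumulate.
        if let ChatMessage::Assistant { content } = delta?.message {
            reply.push_str(&content);
        }
    }
    Ok(reply)
}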

/// Lists the models available locally via Ollama's `/api/tags` endpoint.
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    low_speed_timeout: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let mut request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    }

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    if response.status().is_success() {
        let response: LocalModelsResponse =
            serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;

        Ok(response.models)
    } else {
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}
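
// A minimal usage sketch that lists locally installed model names; the
// function name is ours for illustration.
#[allow(dead_code)]
async fn example_list_model_names(client: &dyn HttpClient) -> Result<Vec<String>> {
    let models = get_models(client, OLLAMA_API_URL, None).await?;
    Ok(models.into_iter().map(|model| model.name).collect())
}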

/// Sends an empty request to Ollama's `/api/generate` endpoint to trigger
/// loading the model, keeping it resident for 15 minutes.
pub async fn preload_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<()> {
    let uri = format!("{api_url}/api/generate");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(serde_json::to_string(
            &serde_json::json!({
                "model": model,
                "keep_alive": "15m",
            }),
        )?))?;

    let mut response = match client.send(request).await {
        Ok(response) => response,
        Err(err) => {
            // Tolerate a timeout during preload: the request has already
            // triggered the model load.
            if err.is_timeout() {
                return Ok(());
            } else {
                return Err(err.into());
            }
        }
    };

    if response.status().is_success() {
        Ok(())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}