// ollama.rs

  1use anyhow::{anyhow, Context, Result};
  2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use isahc::config::Configurable;
  5use schemars::JsonSchema;
  6use serde::{Deserialize, Serialize};
  7use serde_json::Value;
  8use std::{convert::TryFrom, sync::Arc, time::Duration};
  9
/// Default base URL of a locally running Ollama server.
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
 11
/// Role of a chat participant, serialized in lowercase ("user",
/// "assistant", "system") to match Ollama's wire format.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
 19
 20impl TryFrom<String> for Role {
 21    type Error = anyhow::Error;
 22
 23    fn try_from(value: String) -> Result<Self> {
 24        match value.as_str() {
 25            "user" => Ok(Self::User),
 26            "assistant" => Ok(Self::Assistant),
 27            "system" => Ok(Self::System),
 28            _ => Err(anyhow!("invalid role '{value}'")),
 29        }
 30    }
 31}
 32
 33impl From<Role> for String {
 34    fn from(val: Role) -> Self {
 35        match val {
 36            Role::User => "user".to_owned(),
 37            Role::Assistant => "assistant".to_owned(),
 38            Role::System => "system".to_owned(),
 39        }
 40    }
 41}
 42
/// How long Ollama should keep a model loaded after a request.
/// Serialized untagged: either a raw number of seconds or a duration string.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}
 51
impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    // NOTE(review): relies on Ollama treating a negative keep_alive as
    // "indefinite" — confirm against the Ollama API docs.
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}
 58
// The default keeps the model loaded indefinitely.
impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}
 64
/// A configured model: the Ollama model name plus local settings for
/// context size and keep-alive.
// NOTE(review): `KeepAlive` derives `JsonSchema` unconditionally while this
// struct gates it behind the "schemars" feature — confirm the feature wiring.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    /// Token budget reported by `max_token_count()`.
    pub max_tokens: usize,
    /// `None` falls back to the server-side keep-alive behavior.
    pub keep_alive: Option<KeepAlive>,
}
 72
 73impl Model {
 74    pub fn new(name: &str) -> Self {
 75        Self {
 76            name: name.to_owned(),
 77            max_tokens: 2048,
 78            keep_alive: Some(KeepAlive::indefinite()),
 79        }
 80    }
 81
 82    pub fn id(&self) -> &str {
 83        &self.name
 84    }
 85
 86    pub fn display_name(&self) -> &str {
 87        &self.name
 88    }
 89
 90    pub fn max_token_count(&self) -> usize {
 91        self.max_tokens
 92    }
 93}
 94
/// One message in a chat exchange, tagged by the "role" field on the wire
/// ("assistant", "user", or "system").
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        // Tool invocations requested by the model, when present.
        tool_calls: Option<Vec<OllamaToolCall>>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
}
109
/// A tool invocation emitted by the model; currently only function calls.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}
115
/// The name and JSON arguments of a function the model asked to invoke.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionCall {
    pub name: String,
    // Arbitrary JSON value; the argument schema is tool-specific.
    pub arguments: Value,
}
121
/// Definition of a function exposed to the model as a callable tool.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    // Presumably a JSON Schema describing the parameters — confirm against
    // the Ollama tool-calling docs.
    pub parameters: Option<Value>,
}
128
/// A tool the model may call, tagged by "type" on the wire; currently only
/// functions.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}
134
/// Request payload for Ollama's `/api/chat` endpoint.
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    // When true, the server streams the response; `with_tools` forces this
    // to false.
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
}
144
145impl ChatRequest {
146    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
147        self.stream = false;
148        self.tools = tools;
149        self
150    }
151}
152
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
/// Optional sampling/runtime parameters for a chat request; `None` fields
/// fall back to the server's defaults (see the linked parameter reference).
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    // Context window size, in tokens.
    pub num_ctx: Option<usize>,
    // Maximum number of tokens to generate.
    pub num_predict: Option<isize>,
    // Stop sequences that end generation.
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}
162
/// One response object from `/api/chat` — a streamed chunk, or the complete
/// reply for non-streaming requests.
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    // Presumably only set on the final object of a stream — confirm against
    // the Ollama API docs.
    #[allow(unused)]
    pub done_reason: Option<String>,
    // True on the last object of a stream.
    #[allow(unused)]
    pub done: bool,
}
175
/// Response from `/api/tags`: the models available locally.
#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}
180
/// Summary of one locally installed model, as returned by `/api/tags`.
#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    // Size on disk; presumably bytes — confirm against the Ollama API docs.
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}
189
/// Full details of a single local model.
// NOTE(review): not constructed anywhere in this file; presumably parsed
// from Ollama's `/api/show` endpoint — verify against callers.
#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}
197
/// Format, family, and quantization metadata reported for a model.
#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}
206
207pub async fn complete(
208    client: &dyn HttpClient,
209    api_url: &str,
210    request: ChatRequest,
211) -> Result<ChatResponseDelta> {
212    let uri = format!("{api_url}/api/chat");
213    let request_builder = HttpRequest::builder()
214        .method(Method::POST)
215        .uri(uri)
216        .header("Content-Type", "application/json");
217
218    let serialized_request = serde_json::to_string(&request)?;
219    let request = request_builder.body(AsyncBody::from(serialized_request))?;
220
221    let mut response = client.send(request).await?;
222    if response.status().is_success() {
223        let mut body = Vec::new();
224        response.body_mut().read_to_end(&mut body).await?;
225        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
226        Ok(response_message)
227    } else {
228        let mut body = Vec::new();
229        response.body_mut().read_to_end(&mut body).await?;
230        let body_str = std::str::from_utf8(&body)?;
231        Err(anyhow!(
232            "Failed to connect to API: {} {}",
233            response.status(),
234            body_str
235        ))
236    }
237}
238
239pub async fn stream_chat_completion(
240    client: &dyn HttpClient,
241    api_url: &str,
242    request: ChatRequest,
243    low_speed_timeout: Option<Duration>,
244) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
245    let uri = format!("{api_url}/api/chat");
246    let mut request_builder = HttpRequest::builder()
247        .method(Method::POST)
248        .uri(uri)
249        .header("Content-Type", "application/json");
250
251    if let Some(low_speed_timeout) = low_speed_timeout {
252        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
253    };
254
255    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
256    let mut response = client.send(request).await?;
257    if response.status().is_success() {
258        let reader = BufReader::new(response.into_body());
259
260        Ok(reader
261            .lines()
262            .filter_map(|line| async move {
263                match line {
264                    Ok(line) => {
265                        Some(serde_json::from_str(&line).context("Unable to parse chat response"))
266                    }
267                    Err(e) => Some(Err(e.into())),
268                }
269            })
270            .boxed())
271    } else {
272        let mut body = String::new();
273        response.body_mut().read_to_string(&mut body).await?;
274
275        Err(anyhow!(
276            "Failed to connect to Ollama API: {} {}",
277            response.status(),
278            body,
279        ))
280    }
281}
282
283pub async fn get_models(
284    client: &dyn HttpClient,
285    api_url: &str,
286    low_speed_timeout: Option<Duration>,
287) -> Result<Vec<LocalModelListing>> {
288    let uri = format!("{api_url}/api/tags");
289    let mut request_builder = HttpRequest::builder()
290        .method(Method::GET)
291        .uri(uri)
292        .header("Accept", "application/json");
293
294    if let Some(low_speed_timeout) = low_speed_timeout {
295        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
296    };
297
298    let request = request_builder.body(AsyncBody::default())?;
299
300    let mut response = client.send(request).await?;
301
302    let mut body = String::new();
303    response.body_mut().read_to_string(&mut body).await?;
304
305    if response.status().is_success() {
306        let response: LocalModelsResponse =
307            serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
308
309        Ok(response.models)
310    } else {
311        Err(anyhow!(
312            "Failed to connect to Ollama API: {} {}",
313            response.status(),
314            body,
315        ))
316    }
317}
318
319/// Sends an empty request to Ollama to trigger loading the model
320pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
321    let uri = format!("{api_url}/api/generate");
322    let request = HttpRequest::builder()
323        .method(Method::POST)
324        .uri(uri)
325        .header("Content-Type", "application/json")
326        .body(AsyncBody::from(serde_json::to_string(
327            &serde_json::json!({
328                "model": model,
329                "keep_alive": "15m",
330            }),
331        )?))?;
332
333    let mut response = match client.send(request).await {
334        Ok(response) => response,
335        Err(err) => {
336            // Be ok with a timeout during preload of the model
337            if err.is_timeout() {
338                return Ok(());
339            } else {
340                return Err(err.into());
341            }
342        }
343    };
344
345    if response.status().is_success() {
346        Ok(())
347    } else {
348        let mut body = String::new();
349        response.body_mut().read_to_string(&mut body).await?;
350
351        Err(anyhow!(
352            "Failed to connect to Ollama API: {} {}",
353            response.status(),
354            body,
355        ))
356    }
357}