use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{convert::TryFrom, time::Duration};

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            _ => Err(anyhow!("invalid role '{value}'")),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
        }
    }
}
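
// A quick sketch of a check that the `Role` string conversions above
// round-trip; the test module name is ours, not part of the upstream API.
#[cfg(test)]
mod role_conversion_tests {
    use super::Role;

    #[test]
    fn role_string_round_trip() {
        for role in [Role::User, Role::Assistant, Role::System] {
            let s: String = role.into();
            assert_eq!(Role::try_from(s).unwrap(), role);
        }
    }
}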

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}
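
// A minimal sketch showing how `#[serde(untagged)]` flattens `KeepAlive`:
// `Seconds(-1)` serializes as the bare JSON number -1, and `Duration("10m")`
// as the JSON string "10m" — the two forms Ollama accepts for `keep_alive`.
#[cfg(test)]
mod keep_alive_serde_tests {
    use super::KeepAlive;

    #[test]
    fn keep_alive_serializes_untagged() {
        assert_eq!(
            serde_json::to_string(&KeepAlive::indefinite()).unwrap(),
            "-1"
        );
        assert_eq!(
            serde_json::to_string(&KeepAlive::Duration("10m".to_owned())).unwrap(),
            "\"10m\""
        );
    }
}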

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct Model {
    pub name: String,
    pub max_tokens: usize,
    pub keep_alive: Option<KeepAlive>,
}

impl Model {
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_owned(),
            max_tokens: 2048,
            keep_alive: Some(KeepAlive::indefinite()),
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        &self.name
    }

    pub fn max_token_count(&self) -> usize {
        self.max_tokens
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant { content: String },
    User { content: String },
    System { content: String },
}

#[derive(Serialize)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
}

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default)]
pub struct ChatOptions {
    pub num_ctx: Option<usize>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}
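
// A minimal sketch of the request body this crate sends to Ollama's
// /api/chat endpoint; the model name and option values are illustrative,
// not defaults.
#[cfg(test)]
mod chat_request_serde_tests {
    use super::*;

    #[test]
    fn chat_request_serializes_as_expected() {
        let request = ChatRequest {
            model: "llama3".to_owned(),
            messages: vec![ChatMessage::User {
                content: "Hello!".to_owned(),
            }],
            stream: true,
            keep_alive: KeepAlive::default(),
            options: Some(ChatOptions {
                temperature: Some(1.0),
                ..Default::default()
            }),
        };
        let json = serde_json::to_value(&request).unwrap();
        // `ChatMessage` is internally tagged on "role".
        assert_eq!(json["messages"][0]["role"], "user");
        assert_eq!(json["messages"][0]["content"], "Hello!");
        // The default keep-alive is indefinite, i.e. the number -1.
        assert_eq!(json["keep_alive"], -1);
    }
}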

/// One newline-delimited JSON chunk of a streaming `/api/chat` response.
#[derive(Deserialize)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
}
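
// A sketch of parsing one streamed line; the payload below is modeled on
// Ollama's documented streaming format, not captured from a live server.
#[cfg(test)]
mod chat_response_delta_tests {
    use super::*;

    #[test]
    fn parses_streamed_delta_line() {
        let line = r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hi"},"done":false}"#;
        let delta: ChatResponseDelta = serde_json::from_str(line).unwrap();
        assert_eq!(
            delta.message,
            ChatMessage::Assistant {
                content: "Hi".to_owned()
            }
        );
        // `done_reason` is absent until the final chunk, so it defaults to None.
        assert!(!delta.done);
    }
}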

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

/// Streams a chat completion from Ollama's `/api/chat` endpoint, yielding one
/// `ChatResponseDelta` per newline-delimited JSON chunk.
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
    low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    }

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        Some(serde_json::from_str(&line).context("Unable to parse chat response"))
                    }
                    Err(e) => Some(Err(e.into())),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}
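
// A minimal usage sketch: drain the stream into a single reply string. The
// function name is ours for illustration; `client` is whatever `HttpClient`
// implementation the caller supplies.
#[allow(dead_code)]
async fn example_stream_chat(client: &dyn HttpClient, request: ChatRequest) -> Result<String> {
    let mut stream = stream_chat_completion(client, OLLAMA_API_URL, request, None).await?;
    let mut reply = String::new();
    while let Some(delta) = stream.next().await {
        // Only assistant chunks carry generated text we want to accumulate.
        if let ChatMessage::Assistant { content } = delta?.message {
            reply.push_str(&content);
        }
    }
    Ok(reply)
}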

/// Lists the models available locally via Ollama's `/api/tags` endpoint.
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    low_speed_timeout: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let mut request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    }

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    if response.status().is_success() {
        let response: LocalModelsResponse =
            serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;

        Ok(response.models)
    } else {
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}
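
// A minimal usage sketch that lists locally installed model names; the
// function name is ours for illustration.
#[allow(dead_code)]
async fn example_list_model_names(client: &dyn HttpClient) -> Result<Vec<String>> {
    let models = get_models(client, OLLAMA_API_URL, None).await?;
    Ok(models.into_iter().map(|model| model.name).collect())
}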

/// Sends an empty request to Ollama's `/api/generate` endpoint to trigger
/// loading the model, keeping it resident for 15 minutes.
pub async fn preload_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<()> {
    let uri = format!("{api_url}/api/generate");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(serde_json::to_string(
            &serde_json::json!({
                "model": model,
                "keep_alive": "15m",
            }),
        )?))?;

    let mut response = match client.send(request).await {
        Ok(response) => response,
        Err(err) => {
            // Tolerate a timeout during preload: the request has already
            // triggered the model load.
            if err.is_timeout() {
                return Ok(());
            } else {
                return Err(err.into());
            }
        }
    };

    if response.status().is_success() {
        Ok(())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}