// ollama.rs

use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use std::{convert::TryFrom, time::Duration};

/// Default base URL of a locally running Ollama server.
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
  9
/// Speaker role attached to a chat message, (de)serialized in Ollama's
/// lowercase wire format: "user" / "assistant" / "system".
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
 17
 18impl TryFrom<String> for Role {
 19    type Error = anyhow::Error;
 20
 21    fn try_from(value: String) -> Result<Self> {
 22        match value.as_str() {
 23            "user" => Ok(Self::User),
 24            "assistant" => Ok(Self::Assistant),
 25            "system" => Ok(Self::System),
 26            _ => Err(anyhow!("invalid role '{value}'")),
 27        }
 28    }
 29}
 30
 31impl From<Role> for String {
 32    fn from(val: Role) -> Self {
 33        match val {
 34            Role::User => "user".to_owned(),
 35            Role::Assistant => "assistant".to_owned(),
 36            Role::System => "system".to_owned(),
 37        }
 38    }
 39}
 40
/// Client-side configuration for a single Ollama model.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    /// Ollama model name (e.g. "llama3"); doubles as id and display name.
    pub name: String,
    /// Token budget reported by [`Model::max_token_count`].
    pub max_tokens: usize,
    /// How long the server should keep the model loaded (e.g. "10m");
    /// presumably forwarded as [`ChatRequest::keep_alive`] — verify at call site.
    pub keep_alive: Option<String>,
}
 48
 49impl Model {
 50    pub fn new(name: &str) -> Self {
 51        Self {
 52            name: name.to_owned(),
 53            max_tokens: 2048,
 54            keep_alive: Some("10m".to_owned()),
 55        }
 56    }
 57
 58    pub fn id(&self) -> &str {
 59        &self.name
 60    }
 61
 62    pub fn display_name(&self) -> &str {
 63        &self.name
 64    }
 65
 66    pub fn max_token_count(&self) -> usize {
 67        self.max_tokens
 68    }
 69}
 70
/// One chat turn; the `role` tag selects the variant and is serialized
/// lowercase ("assistant" / "user" / "system"), matching [`Role`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant { content: String },
    User { content: String },
    System { content: String },
}
 78
/// Request body for Ollama's `/api/chat` endpoint.
#[derive(Serialize)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    /// When true the server streams the reply as newline-delimited JSON.
    pub stream: bool,
    /// How long the server keeps the model loaded after this request.
    pub keep_alive: Option<String>,
    pub options: Option<ChatOptions>,
}
 87
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
/// Optional sampling/runtime parameters forwarded to Ollama; `None` fields
/// are left to the server's defaults.
#[derive(Serialize, Default)]
pub struct ChatOptions {
    /// Context window size, in tokens.
    pub num_ctx: Option<usize>,
    /// Maximum number of tokens to predict.
    pub num_predict: Option<isize>,
    /// Stop sequences that terminate generation.
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}
 97
/// One streamed chunk from `/api/chat`; only `message` is consumed here,
/// the remaining fields are kept for completeness of the wire format.
#[derive(Deserialize)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    /// Incremental message content for this chunk.
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    /// Presumably true on the final chunk of the stream — confirm against
    /// the Ollama API docs.
    #[allow(unused)]
    pub done: bool,
}
110
/// Response body of `/api/tags`: the models installed locally.
#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}
115
/// Summary entry for one locally installed model, as returned by `/api/tags`.
#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    /// Size on disk, in bytes — TODO confirm unit against the Ollama API docs.
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}
124
/// Full description of a single local model (modelfile, parameters, template).
#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}
132
/// Metadata block shared by model listings and full model descriptions.
#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}
141
142pub async fn stream_chat_completion(
143    client: &dyn HttpClient,
144    api_url: &str,
145    request: ChatRequest,
146    low_speed_timeout: Option<Duration>,
147) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
148    let uri = format!("{api_url}/api/chat");
149    let mut request_builder = HttpRequest::builder()
150        .method(Method::POST)
151        .uri(uri)
152        .header("Content-Type", "application/json");
153
154    if let Some(low_speed_timeout) = low_speed_timeout {
155        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
156    };
157
158    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
159    let mut response = client.send(request).await?;
160    if response.status().is_success() {
161        let reader = BufReader::new(response.into_body());
162
163        Ok(reader
164            .lines()
165            .filter_map(|line| async move {
166                match line {
167                    Ok(line) => {
168                        Some(serde_json::from_str(&line).context("Unable to parse chat response"))
169                    }
170                    Err(e) => Some(Err(e.into())),
171                }
172            })
173            .boxed())
174    } else {
175        let mut body = String::new();
176        response.body_mut().read_to_string(&mut body).await?;
177
178        Err(anyhow!(
179            "Failed to connect to Ollama API: {} {}",
180            response.status(),
181            body,
182        ))
183    }
184}
185
186pub async fn get_models(
187    client: &dyn HttpClient,
188    api_url: &str,
189    low_speed_timeout: Option<Duration>,
190) -> Result<Vec<LocalModelListing>> {
191    let uri = format!("{api_url}/api/tags");
192    let mut request_builder = HttpRequest::builder()
193        .method(Method::GET)
194        .uri(uri)
195        .header("Accept", "application/json");
196
197    if let Some(low_speed_timeout) = low_speed_timeout {
198        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
199    };
200
201    let request = request_builder.body(AsyncBody::default())?;
202
203    let mut response = client.send(request).await?;
204
205    let mut body = String::new();
206    response.body_mut().read_to_string(&mut body).await?;
207
208    if response.status().is_success() {
209        let response: LocalModelsResponse =
210            serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
211
212        Ok(response.models)
213    } else {
214        Err(anyhow!(
215            "Failed to connect to Ollama API: {} {}",
216            response.status(),
217            body,
218        ))
219    }
220}
221
222/// Sends an empty request to Ollama to trigger loading the model
223pub async fn preload_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<()> {
224    let uri = format!("{api_url}/api/generate");
225    let request = HttpRequest::builder()
226        .method(Method::POST)
227        .uri(uri)
228        .header("Content-Type", "application/json")
229        .body(AsyncBody::from(serde_json::to_string(
230            &serde_json::json!({
231                "model": model,
232                "keep_alive": "15m",
233            }),
234        )?))?;
235
236    let mut response = match client.send(request).await {
237        Ok(response) => response,
238        Err(err) => {
239            // Be ok with a timeout during preload of the model
240            if err.is_timeout() {
241                return Ok(());
242            } else {
243                return Err(err.into());
244            }
245        }
246    };
247
248    if response.status().is_success() {
249        Ok(())
250    } else {
251        let mut body = String::new();
252        response.body_mut().read_to_string(&mut body).await?;
253
254        Err(anyhow!(
255            "Failed to connect to Ollama API: {} {}",
256            response.status(),
257            body,
258        ))
259    }
260}