ollama.rs

use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{value::RawValue, Value};
use std::{convert::TryFrom, sync::Arc, time::Duration};

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            _ => Err(anyhow!("invalid role '{value}'")),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
        }
    }
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}
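
// A minimal illustration of the untagged wire format above: with serde's
// `untagged` representation, `Seconds(-1)` serializes to the bare number `-1`
// and `Duration("5m")` to the string `"5m"`, both of which Ollama accepts for
// `keep_alive`.
#[cfg(test)]
mod keep_alive_tests {
    use super::*;

    #[test]
    fn serializes_untagged() {
        assert_eq!(serde_json::to_string(&KeepAlive::Seconds(-1)).unwrap(), "-1");
        assert_eq!(
            serde_json::to_string(&KeepAlive::Duration("5m".to_string())).unwrap(),
            "\"5m\""
        );
    }
}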

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub keep_alive: Option<KeepAlive>,
}

/// Looks up the context length for a model by its base name (any tag after
/// ':' is ignored), clamped to `MAXIMUM_TOKENS`.
fn get_max_tokens(name: &str) -> usize {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: usize = 2048;
    /// Ceiling applied to every model; this lets many Ollama models work on
    /// machines with ~16GB of RAM.
    const MAXIMUM_TOKENS: usize = 16384;

    match name.split(':').next().unwrap() {
        "phi" | "tinyllama" | "granite-code" => 2048,
        "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
        "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
        "codellama" | "starcoder2" => 16384,
        "mistral" | "codestral" | "mixtral" | "llava" | "qwen2" | "dolphin-mixtral" => 32768,
        "llama3.1" | "phi3" | "phi3.5" | "command-r" | "deepseek-coder-v2" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}
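
// A quick sanity check of the lookup and clamp above: large context windows
// (e.g. llama3.1's 128k) are clamped down to `MAXIMUM_TOKENS`, and unknown
// models fall back to `DEFAULT_TOKENS`.
#[cfg(test)]
mod max_tokens_tests {
    use super::*;

    #[test]
    fn clamps_and_defaults() {
        // The tag after ':' is ignored when matching the base model name.
        assert_eq!(get_max_tokens("llama3:latest"), 8192);
        // 128000 exceeds MAXIMUM_TOKENS, so it is clamped.
        assert_eq!(get_max_tokens("llama3.1"), 16384);
        // Unknown models get the default context length.
        assert_eq!(get_max_tokens("some-unknown-model"), 2048);
    }
}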

impl Model {
    pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> usize {
        self.max_tokens
    }
}
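
// A short sketch of `Model::new`'s defaulting behavior: a trailing ":latest"
// tag is stripped when deriving a display name, and `max_tokens` falls back
// to the `get_max_tokens` lookup when not supplied.
#[cfg(test)]
mod model_tests {
    use super::*;

    #[test]
    fn new_defaults_display_name_and_max_tokens() {
        let model = Model::new("llama3:latest", None, None);
        assert_eq!(model.id(), "llama3:latest");
        assert_eq!(model.display_name(), "llama3");
        assert_eq!(model.max_token_count(), 8192);
    }
}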

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
}
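
// A minimal illustration of the internally tagged representation above: the
// variant name becomes the "role" field, matching the message shape of
// Ollama's chat API.
#[cfg(test)]
mod chat_message_tests {
    use super::*;

    #[test]
    fn serializes_with_role_tag() {
        let message = ChatMessage::User {
            content: "Why is the sky blue?".to_string(),
        };
        assert_eq!(
            serde_json::to_value(&message).unwrap(),
            serde_json::json!({ "role": "user", "content": "Why is the sky blue?" })
        );
    }
}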

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Box<RawValue>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
}

impl ChatRequest {
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        // As of this writing, Ollama does not support streaming responses for
        // requests that include tools, so attaching tools forces the request
        // into non-streaming mode.
        self.stream = false;
        self.tools = tools;
        self
    }
}
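
// A small sketch of building a request with a tool; the `get_weather` tool is
// purely illustrative. Note that `with_tools` switches the request to
// non-streaming as a side effect.
#[cfg(test)]
mod chat_request_tests {
    use super::*;

    #[test]
    fn with_tools_disables_streaming() {
        let request = ChatRequest {
            model: "llama3".to_string(),
            messages: vec![ChatMessage::User {
                content: "What's the weather?".to_string(),
            }],
            stream: true,
            keep_alive: KeepAlive::default(),
            options: None,
            tools: Vec::new(),
        }
        .with_tools(vec![OllamaTool::Function {
            function: OllamaFunctionTool {
                name: "get_weather".to_string(),
                description: Some("Look up the current weather".to_string()),
                parameters: None,
            },
        }]);

        assert!(!request.stream);
        assert_eq!(request.tools.len(), 1);
    }
}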

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<usize>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

/// Sends a non-streaming chat request to Ollama and returns the complete
/// response.
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        ))
    }
}
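
// A minimal usage sketch for `complete`, assuming the caller already has an
// `HttpClient` implementation at hand; the model name and prompt are
// illustrative.
#[allow(dead_code)]
async fn complete_example(client: &dyn HttpClient) -> Result<()> {
    let request = ChatRequest {
        model: "llama3".to_string(),
        messages: vec![ChatMessage::User {
            content: "Why is the sky blue?".to_string(),
        }],
        stream: false,
        keep_alive: KeepAlive::default(),
        options: None,
        tools: Vec::new(),
    };

    let response = complete(client, OLLAMA_API_URL, request).await?;
    if let ChatMessage::Assistant { content, .. } = response.message {
        println!("{content}");
    }
    Ok(())
}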

/// Sends a streaming chat request to Ollama and returns a stream of response
/// deltas, one per line of newline-delimited JSON.
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
    low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    }

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        Some(serde_json::from_str(&line).context("Unable to parse chat response"))
                    }
                    Err(e) => Some(Err(e.into())),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}
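
// A minimal usage sketch for the streaming path, again assuming an available
// `HttpClient`; deltas arrive incrementally until one reports `done`.
#[allow(dead_code)]
async fn stream_example(client: &dyn HttpClient) -> Result<()> {
    let request = ChatRequest {
        model: "llama3".to_string(),
        messages: vec![ChatMessage::User {
            content: "Tell me a short story.".to_string(),
        }],
        stream: true,
        keep_alive: KeepAlive::default(),
        options: None,
        tools: Vec::new(),
    };

    let mut stream = stream_chat_completion(client, OLLAMA_API_URL, request, None).await?;
    while let Some(delta) = stream.next().await {
        let delta = delta?;
        if let ChatMessage::Assistant { content, .. } = &delta.message {
            print!("{content}");
        }
        if delta.done {
            break;
        }
    }
    Ok(())
}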

/// Fetches the list of locally available models from Ollama's tag endpoint.
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    low_speed_timeout: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let mut request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    }

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    if response.status().is_success() {
        let response: LocalModelsResponse =
            serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;

        Ok(response.models)
    } else {
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}
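
// A minimal usage sketch: list the locally installed models along with a
// couple of the details Ollama reports for each.
#[allow(dead_code)]
async fn list_models_example(client: &dyn HttpClient) -> Result<()> {
    for model in get_models(client, OLLAMA_API_URL, None).await? {
        println!(
            "{} ({}, {})",
            model.name, model.details.parameter_size, model.details.quantization_level
        );
    }
    Ok(())
}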

/// Sends a prompt-less generate request to Ollama, which causes it to load
/// the model into memory without producing a completion.
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
    let uri = format!("{api_url}/api/generate");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(serde_json::to_string(
            &serde_json::json!({
                "model": model,
                "keep_alive": "15m",
            }),
        )?))?;

    let mut response = match client.send(request).await {
        Ok(response) => response,
        Err(err) => {
            // Tolerate a timeout here: loading a large model can take longer
            // than the request timeout allows.
            if err.is_timeout() {
                return Ok(());
            } else {
                return Err(err.into());
            }
        }
    };

    if response.status().is_success() {
        Ok(())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}
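
// A short usage sketch: preloading is effectively fire-and-forget, so a
// caller might run it in the background and only log failures. The model name
// is illustrative.
#[allow(dead_code)]
async fn preload_example(client: Arc<dyn HttpClient>) {
    if let Err(err) = preload_model(client, OLLAMA_API_URL, "llama3").await {
        eprintln!("failed to preload model: {err}");
    }
}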