use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{value::RawValue, Value};
use std::{convert::TryFrom, sync::Arc, time::Duration};

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            _ => Err(anyhow!("invalid role '{value}'")),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
        }
    }
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep the model alive for N seconds; a negative value keeps it loaded indefinitely.
    Seconds(isize),
    /// Keep the model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep the model alive until a new model is loaded or until Ollama shuts down.
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}
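
#[cfg(test)]
mod keep_alive_tests {
    use super::*;

    // A minimal sketch of the `#[serde(untagged)]` wire format: the `Seconds`
    // variant serializes as a bare JSON number and `Duration` as a plain
    // string, matching what Ollama expects for `keep_alive`.
    #[test]
    fn serializes_untagged() {
        assert_eq!(serde_json::to_string(&KeepAlive::indefinite()).unwrap(), "-1");
        assert_eq!(
            serde_json::to_string(&KeepAlive::Duration("5m".to_string())).unwrap(),
            "\"5m\""
        );
    }
}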

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub keep_alive: Option<KeepAlive>,
}

fn get_max_tokens(name: &str) -> usize {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: usize = 2048;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    const MAXIMUM_TOKENS: usize = 16384;

    match name.split(':').next().unwrap() {
        "phi" | "tinyllama" | "granite-code" => 2048,
        "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
        "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
        "codellama" | "starcoder2" => 16384,
        "mistral" | "codestral" | "mixtral" | "llava" | "qwen2" | "dolphin-mixtral" => 32768,
        "llama3.1" | "phi3" | "phi3.5" | "command-r" | "deepseek-coder-v2" | "yi-coder"
        | "qwen2.5-coder" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}
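
#[cfg(test)]
mod max_tokens_tests {
    use super::*;

    // Sketch of the clamping behavior above: even models whose advertised
    // context window is 128k are capped at `MAXIMUM_TOKENS`, and any tag
    // after the first ':' is ignored when looking up the base model name.
    #[test]
    fn known_and_unknown_models() {
        assert_eq!(get_max_tokens("llama3:8b"), 8192);
        assert_eq!(get_max_tokens("llama3.1:latest"), 16384); // 128000, clamped
        assert_eq!(get_max_tokens("some-unknown-model"), 2048);
    }
}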

impl Model {
    pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> usize {
        self.max_tokens
    }
}
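
#[cfg(test)]
mod model_tests {
    use super::*;

    // Sketch: `Model::new` derives a display name by stripping the ":latest"
    // tag when no explicit display name is given, and falls back to
    // `get_max_tokens` for the context length.
    #[test]
    fn new_derives_defaults() {
        let model = Model::new("llama3:latest", None, None);
        assert_eq!(model.display_name(), "llama3");
        assert_eq!(model.max_token_count(), 8192);
        assert_eq!(model.keep_alive, Some(KeepAlive::indefinite()));
    }
}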

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
}
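
#[cfg(test)]
mod chat_message_tests {
    use super::*;

    // Sketch of the internally tagged representation: the variant name is
    // emitted as a lowercase "role" field alongside the message content.
    #[test]
    fn serializes_with_role_tag() {
        let message = ChatMessage::User {
            content: "Hello".to_string(),
        };
        assert_eq!(
            serde_json::to_value(&message).unwrap(),
            serde_json::json!({"role": "user", "content": "Hello"})
        );
    }
}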

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Box<RawValue>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}
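
#[cfg(test)]
mod tool_schema_tests {
    use super::*;

    // Sketch of the tool JSON shape sent to Ollama; `get_weather` is a
    // hypothetical tool name used purely for illustration.
    #[test]
    fn serializes_function_tool() {
        let tool = OllamaTool::Function {
            function: OllamaFunctionTool {
                name: "get_weather".to_string(),
                description: Some("Get the current weather".to_string()),
                parameters: None,
            },
        };
        assert_eq!(
            serde_json::to_value(&tool).unwrap(),
            serde_json::json!({
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the current weather",
                    "parameters": null,
                }
            })
        );
    }
}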

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
}

impl ChatRequest {
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        // Ollama only reports tool calls on non-streaming responses, so
        // attaching tools forces `stream: false`.
        self.stream = false;
        self.tools = tools;
        self
    }
}
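
#[cfg(test)]
mod chat_request_tests {
    use super::*;

    // Sketch confirming that attaching tools forces a non-streaming request.
    #[test]
    fn with_tools_disables_streaming() {
        let request = ChatRequest {
            model: "llama3.1".to_string(),
            messages: vec![ChatMessage::User {
                content: "What is the weather in Paris?".to_string(),
            }],
            stream: true,
            keep_alive: KeepAlive::default(),
            options: None,
            tools: Vec::new(),
        };
        assert!(!request.with_tools(Vec::new()).stream);
    }
}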

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<usize>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
}
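
#[cfg(test)]
mod response_parsing_tests {
    use super::*;

    // Sketch parsing one line of the newline-delimited stream; the sample
    // payload shape is an assumption based on Ollama's /api/chat responses.
    #[test]
    fn parses_streaming_delta() {
        let line = r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hi"},"done":false}"#;
        let delta: ChatResponseDelta = serde_json::from_str(line).unwrap();
        assert!(!delta.done);
        match delta.message {
            ChatMessage::Assistant { content, .. } => assert_eq!(content, "Hi"),
            _ => panic!("expected an assistant message"),
        }
    }
}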

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;
    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    if response.status().is_success() {
        let response_message: ChatResponseDelta =
            serde_json::from_slice(&body).context("Unable to parse chat response")?;
        Ok(response_message)
    } else {
        let body_str = std::str::from_utf8(&body)?;
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        ))
    }
}
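
// A minimal usage sketch for the non-streaming tool-call path, assuming a
// `client: Arc<dyn HttpClient>` is in scope (names here are illustrative):
//
//     let request = ChatRequest { /* ... */ }.with_tools(tools);
//     let delta = complete(client.as_ref(), OLLAMA_API_URL, request).await?;
//     if let ChatMessage::Assistant { tool_calls: Some(calls), .. } = delta.message {
//         for OllamaToolCall::Function(call) in &calls {
//             // `call.arguments` is raw JSON; parsing is left to the caller.
//         }
//     }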

pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
    low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    }

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        Some(serde_json::from_str(&line).context("Unable to parse chat response"))
                    }
                    Err(e) => Some(Err(e.into())),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}
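
// A minimal consumption sketch for the stream above, again assuming a
// `client: Arc<dyn HttpClient>` and an async context (illustrative only):
//
//     let mut stream =
//         stream_chat_completion(client.as_ref(), OLLAMA_API_URL, request, None).await?;
//     while let Some(delta) = stream.next().await {
//         if let ChatMessage::Assistant { content, .. } = delta?.message {
//             print!("{content}");
//         }
//     }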

pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    if response.status().is_success() {
        let response: LocalModelsResponse =
            serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;

        Ok(response.models)
    } else {
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}

/// Sends an empty request to Ollama to trigger loading the model.
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
    let uri = format!("{api_url}/api/generate");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(serde_json::to_string(
            &serde_json::json!({
                "model": model,
                "keep_alive": "15m",
            }),
        )?))?;

    let mut response = client.send(request).await?;

    if response.status().is_success() {
        Ok(())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}
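
// Preloading is typically fired off in the background when the user picks a
// model, e.g. (sketch; `executor`, its `spawn` method, and `detach` are
// assumptions about the caller's async runtime, not part of this crate):
//
//     let client = client.clone();
//     executor
//         .spawn(async move { preload_model(client, OLLAMA_API_URL, "llama3").await })
//         .detach();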