// ollama.rs

  1use anyhow::{anyhow, Context, Result};
  2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
  3use http_client::{http, AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
  4use schemars::JsonSchema;
  5use serde::{Deserialize, Serialize};
  6use serde_json::{value::RawValue, Value};
  7use std::{convert::TryFrom, sync::Arc, time::Duration};
  8
/// Default base URL of a locally-running Ollama server.
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
 10
/// Role of a chat participant; serialized in lowercase ("user", "assistant", "system")
/// to match the Ollama wire format.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
 18
 19impl TryFrom<String> for Role {
 20    type Error = anyhow::Error;
 21
 22    fn try_from(value: String) -> Result<Self> {
 23        match value.as_str() {
 24            "user" => Ok(Self::User),
 25            "assistant" => Ok(Self::Assistant),
 26            "system" => Ok(Self::System),
 27            _ => Err(anyhow!("invalid role '{value}'")),
 28        }
 29    }
 30}
 31
 32impl From<Role> for String {
 33    fn from(val: Role) -> Self {
 34        match val {
 35            Role::User => "user".to_owned(),
 36            Role::Assistant => "assistant".to_owned(),
 37            Role::System => "system".to_owned(),
 38        }
 39    }
 40}
 41
/// Controls how long Ollama keeps a model loaded after a request.
/// Serialized untagged: either a bare integer (seconds) or a duration string.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}
 50
impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    // -1 is Ollama's sentinel value for "never unload".
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}
 57
impl Default for KeepAlive {
    /// By default, models are kept alive indefinitely.
    fn default() -> Self {
        Self::indefinite()
    }
}
 63
/// An Ollama model as configured/used by this crate.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    // Ollama model name, optionally including a tag (e.g. "llama3:latest").
    pub name: String,
    // Optional human-friendly name; `name` is used when this is `None`.
    pub display_name: Option<String>,
    // Maximum context length in tokens.
    pub max_tokens: usize,
    // How long the model stays loaded after a request; `None` defers to Ollama's default.
    pub keep_alive: Option<KeepAlive>,
}
 72
 73fn get_max_tokens(name: &str) -> usize {
 74    /// Default context length for unknown models.
 75    const DEFAULT_TOKENS: usize = 2048;
 76    /// Magic number. Lets many Ollama models work with ~16GB of ram.
 77    const MAXIMUM_TOKENS: usize = 16384;
 78
 79    match name.split(':').next().unwrap() {
 80        "phi" | "tinyllama" | "granite-code" => 2048,
 81        "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
 82        "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
 83        "codellama" | "starcoder2" => 16384,
 84        "mistral" | "codestral" | "mixstral" | "llava" | "qwen2" | "dolphin-mixtral" => 32768,
 85        "llama3.1" | "phi3" | "phi3.5" | "command-r" | "deepseek-coder-v2" | "yi-coder"
 86        | "llama3.2" | "qwen2.5-coder" => 128000,
 87        _ => DEFAULT_TOKENS,
 88    }
 89    .clamp(1, MAXIMUM_TOKENS)
 90}
 91
 92impl Model {
 93    pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
 94        Self {
 95            name: name.to_owned(),
 96            display_name: display_name
 97                .map(ToString::to_string)
 98                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
 99            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
100            keep_alive: Some(KeepAlive::indefinite()),
101        }
102    }
103
104    pub fn id(&self) -> &str {
105        &self.name
106    }
107
108    pub fn display_name(&self) -> &str {
109        self.display_name.as_ref().unwrap_or(&self.name)
110    }
111
112    pub fn max_token_count(&self) -> usize {
113        self.max_tokens
114    }
115}
116
/// A single message in a chat conversation, tagged on the wire by its "role" field.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    /// A model reply; may carry tool calls the model wants executed.
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
    },
    /// A message authored by the end user.
    User {
        content: String,
    },
    /// A system prompt / instruction message.
    System {
        content: String,
    },
}
131
/// A tool invocation requested by the model. Currently only function calls exist.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}
137
/// A function call emitted by the model as part of a tool-call response.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    // Name of the function to invoke.
    pub name: String,
    // Raw JSON arguments, kept unparsed for the caller to interpret.
    pub arguments: Box<RawValue>,
}
143
/// Declaration of a function the model is allowed to call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    // JSON value describing the function's parameters, if any.
    pub parameters: Option<Value>,
}
150
/// A tool made available to the model; tagged on the wire by its "type" field.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}
156
/// Request payload for Ollama's `/api/chat` endpoint.
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    // Whether the response should be streamed as newline-delimited JSON chunks.
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
}
166
impl ChatRequest {
    /// Attaches tools to the request. Note this also disables streaming —
    /// presumably because Ollama returns tool calls only on non-streamed
    /// responses; confirm against the Ollama API docs before changing.
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        self.stream = false;
        self.tools = tools;
        self
    }
}
174
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
/// Per-request generation options forwarded to Ollama; see the link above
/// for each parameter's semantics and valid values.
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    // Context window size, in tokens.
    pub num_ctx: Option<usize>,
    // Maximum number of tokens to predict.
    pub num_predict: Option<isize>,
    // Stop sequences that terminate generation.
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}
184
/// One chunk of a chat response. When streaming, Ollama sends a sequence of
/// these; for a non-streamed request a single one carries the whole message.
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    // The (partial or complete) message payload.
    pub message: ChatMessage,
    // Why generation finished; present once `done` is true.
    #[allow(unused)]
    pub done_reason: Option<String>,
    // True on the final chunk of a response.
    #[allow(unused)]
    pub done: bool,
}
197
/// Response shape of Ollama's `/api/tags` (installed models) endpoint.
#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}
202
/// Summary entry for one locally-installed model, as listed by `/api/tags`.
#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    // assumed to be the on-disk size in bytes — TODO confirm against API docs
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}
211
/// Detailed information about a single local model.
#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    // Contents of the model's Modelfile.
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}
219
/// Metadata describing a model's file format, family, size and quantization.
#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    // Human-readable parameter count as reported by Ollama.
    pub parameter_size: String,
    pub quantization_level: String,
}
228
229pub async fn complete(
230    client: &dyn HttpClient,
231    api_url: &str,
232    request: ChatRequest,
233) -> Result<ChatResponseDelta> {
234    let uri = format!("{api_url}/api/chat");
235    let request_builder = HttpRequest::builder()
236        .method(Method::POST)
237        .uri(uri)
238        .header("Content-Type", "application/json");
239
240    let serialized_request = serde_json::to_string(&request)?;
241    let request = request_builder.body(AsyncBody::from(serialized_request))?;
242
243    let mut response = client.send(request).await?;
244    if response.status().is_success() {
245        let mut body = Vec::new();
246        response.body_mut().read_to_end(&mut body).await?;
247        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
248        Ok(response_message)
249    } else {
250        let mut body = Vec::new();
251        response.body_mut().read_to_end(&mut body).await?;
252        let body_str = std::str::from_utf8(&body)?;
253        Err(anyhow!(
254            "Failed to connect to API: {} {}",
255            response.status(),
256            body_str
257        ))
258    }
259}
260
/// Sends a streaming chat request to Ollama's `/api/chat` endpoint and
/// returns a stream of response chunks, one per newline-delimited JSON line.
///
/// `low_speed_timeout`, when set, is applied as the HTTP read timeout so a
/// stalled server does not hang the stream indefinitely.
///
/// # Errors
/// Fails if the request cannot be serialized or sent, or if the server
/// responds with a non-success status. Individual stream items may also be
/// `Err` when a line fails to read or parse.
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
    low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let mut request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.read_timeout(low_speed_timeout);
    }

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        // Each line of the body is one JSON-encoded ChatResponseDelta.
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        Some(serde_json::from_str(&line).context("Unable to parse chat response"))
                    }
                    // Propagate read errors as stream items rather than ending the stream silently.
                    Err(e) => Some(Err(e.into())),
                }
            })
            .boxed())
    } else {
        // Non-success: surface the status and error body to the caller.
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}
304
305pub async fn get_models(
306    client: &dyn HttpClient,
307    api_url: &str,
308    _: Option<Duration>,
309) -> Result<Vec<LocalModelListing>> {
310    let uri = format!("{api_url}/api/tags");
311    let request_builder = HttpRequest::builder()
312        .method(Method::GET)
313        .uri(uri)
314        .header("Accept", "application/json");
315
316    let request = request_builder.body(AsyncBody::default())?;
317
318    let mut response = client.send(request).await?;
319
320    let mut body = String::new();
321    response.body_mut().read_to_string(&mut body).await?;
322
323    if response.status().is_success() {
324        let response: LocalModelsResponse =
325            serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
326
327        Ok(response.models)
328    } else {
329        Err(anyhow!(
330            "Failed to connect to Ollama API: {} {}",
331            response.status(),
332            body,
333        ))
334    }
335}
336
337/// Sends an empty request to Ollama to trigger loading the model
338pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
339    let uri = format!("{api_url}/api/generate");
340    let request = HttpRequest::builder()
341        .method(Method::POST)
342        .uri(uri)
343        .header("Content-Type", "application/json")
344        .body(AsyncBody::from(serde_json::to_string(
345            &serde_json::json!({
346                "model": model,
347                "keep_alive": "15m",
348            }),
349        )?))?;
350
351    let mut response = client.send(request).await?;
352
353    if response.status().is_success() {
354        Ok(())
355    } else {
356        let mut body = String::new();
357        response.body_mut().read_to_string(&mut body).await?;
358
359        Err(anyhow!(
360            "Failed to connect to Ollama API: {} {}",
361            response.status(),
362            body,
363        ))
364    }
365}