1use anyhow::{anyhow, Context, Result};
2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use isahc::config::Configurable;
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use serde_json::{value::RawValue, Value};
8use std::{convert::TryFrom, sync::Arc, time::Duration};
9
/// Default base URL of a locally running Ollama server.
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
11
/// Author of a chat message. Serialized in lowercase
/// ("user" / "assistant" / "system") to match Ollama's wire format.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
19
20impl TryFrom<String> for Role {
21 type Error = anyhow::Error;
22
23 fn try_from(value: String) -> Result<Self> {
24 match value.as_str() {
25 "user" => Ok(Self::User),
26 "assistant" => Ok(Self::Assistant),
27 "system" => Ok(Self::System),
28 _ => Err(anyhow!("invalid role '{value}'")),
29 }
30 }
31}
32
33impl From<Role> for String {
34 fn from(val: Role) -> Self {
35 match val {
36 Role::User => "user".to_owned(),
37 Role::Assistant => "assistant".to_owned(),
38 Role::System => "system".to_owned(),
39 }
40 }
41}
42
/// How long Ollama should keep a model loaded after serving a request.
/// Serialized untagged: either a bare integer (seconds) or a duration string.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}
51
impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    // A negative keep-alive value signals "never unload" to Ollama.
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}
58
impl Default for KeepAlive {
    /// Defaults to keeping the model loaded indefinitely.
    fn default() -> Self {
        Self::indefinite()
    }
}
64
/// An Ollama model together with the settings used when talking to it.
// NOTE(review): `JsonSchema` is derived unconditionally on `KeepAlive` but
// feature-gated here — confirm whether the `schemars` feature gate is intentional.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    // Model identifier as known to Ollama, e.g. "llama3.1:latest".
    pub name: String,
    // Context window size, in tokens.
    pub max_tokens: usize,
    // How long Ollama should keep the model loaded after a request.
    pub keep_alive: Option<KeepAlive>,
}
72
// This could be dynamically retrieved via the API (1 call per model)
// curl -s http://localhost:11434/api/show -d '{"model": "llama3.1:latest"}' | jq '.model_info."llama.context_length"'
/// Returns the context window size (in tokens) for a named Ollama model.
///
/// One specific tag carries its own window; every other name is keyed on the
/// model family — the part before the ':' tag separator.
fn get_max_tokens(name: &str) -> usize {
    if name == "dolphin-llama3:8b-256k" {
        return 262144; // 256K
    }
    // `split` always yields at least one item, so `unwrap` cannot panic here.
    match name.split(':').next().unwrap() {
        "mistral-nemo" => 1024000,                                      // 1M
        "deepseek-coder-v2" => 163840,                                  // 160K
        "llama3.1" | "phi3" | "command-r" | "command-r-plus" => 131072, // 128K
        "codeqwen" => 65536,                                            // 64K
        "mistral" | "mistral-large" | "dolphin-mistral" | "codestral"
        | "mistral-openorca" | "dolphin-mixtral" | "mixstral" | "llava"
        | "qwen" | "qwen2" | "wizardlm2" | "wizard-math" => 32768, // 32K
        "codellama" | "stable-code" | "deepseek-coder" | "starcoder2"
        | "wizardcoder" => 16384, // 16K
        "llama3" | "gemma2" | "gemma" | "codegemma" | "dolphin-llama3"
        | "llava-llama3" | "starcoder" | "openchat" | "aya" => 8192, // 8K
        "llama2" | "yi" | "llama2-chinese" | "vicuna" | "nous-hermes2"
        | "stablelm2" => 4096, // 4K
        // "phi", "orca-mini", "tinyllama", and "granite-code" share the
        // conservative 2K default used for any unrecognized model.
        _ => 2048, // 2K
    }
}
97
98impl Model {
99 pub fn new(name: &str) -> Self {
100 Self {
101 name: name.to_owned(),
102 max_tokens: get_max_tokens(name),
103 keep_alive: Some(KeepAlive::indefinite()),
104 }
105 }
106
107 pub fn id(&self) -> &str {
108 &self.name
109 }
110
111 pub fn display_name(&self) -> &str {
112 &self.name
113 }
114
115 pub fn max_token_count(&self) -> usize {
116 self.max_tokens
117 }
118}
119
/// A single message in a chat exchange, tagged on the wire by its "role" field.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        // Tool invocations requested by the model, if any.
        tool_calls: Option<Vec<OllamaToolCall>>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
}
134
/// A tool invocation emitted by the model; currently only function calls exist.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}
140
/// A function call requested by the model.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    // Name of the function the model wants invoked.
    pub name: String,
    // Arguments kept as raw JSON so the caller can deserialize them
    // against the matching tool's parameter schema.
    pub arguments: Box<RawValue>,
}
146
/// Definition of a function the model may call, advertised in [`ChatRequest::tools`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    // JSON Schema describing the function's parameters — presumably; verify
    // against the Ollama tool-calling docs.
    pub parameters: Option<Value>,
}
153
/// A tool the model may use, tagged on the wire by its "type" field.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}
159
/// Request payload for Ollama's `/api/chat` endpoint.
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    // When true, the response arrives as newline-delimited JSON chunks.
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
}
169
170impl ChatRequest {
171 pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
172 self.stream = false;
173 self.tools = tools;
174 self
175 }
176}
177
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
/// Per-request sampling and runtime options; `None` fields defer to Ollama's defaults.
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    // Context window size (tokens) to use for this request.
    pub num_ctx: Option<usize>,
    // Maximum number of tokens to generate.
    pub num_predict: Option<isize>,
    // Sequences that stop generation when produced.
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}
187
/// One chunk of `/api/chat` output: a message (fragment, when streaming)
/// plus completion metadata.
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    // The message content carried by this chunk.
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
}
200
/// Response body of `/api/tags`: the list of locally available models.
#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}
205
/// One entry in the `/api/tags` model listing.
#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    // Size on disk, in bytes — presumably; confirm against the Ollama API docs.
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}
214
/// Detailed model information, as returned by Ollama's `/api/show` endpoint.
#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}
222
/// Metadata describing a model's file format, family, and quantization.
#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    // e.g. "8B" — parameter count as a display string.
    pub parameter_size: String,
    pub quantization_level: String,
}
231
232pub async fn complete(
233 client: &dyn HttpClient,
234 api_url: &str,
235 request: ChatRequest,
236) -> Result<ChatResponseDelta> {
237 let uri = format!("{api_url}/api/chat");
238 let request_builder = HttpRequest::builder()
239 .method(Method::POST)
240 .uri(uri)
241 .header("Content-Type", "application/json");
242
243 let serialized_request = serde_json::to_string(&request)?;
244 let request = request_builder.body(AsyncBody::from(serialized_request))?;
245
246 let mut response = client.send(request).await?;
247 if response.status().is_success() {
248 let mut body = Vec::new();
249 response.body_mut().read_to_end(&mut body).await?;
250 let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
251 Ok(response_message)
252 } else {
253 let mut body = Vec::new();
254 response.body_mut().read_to_end(&mut body).await?;
255 let body_str = std::str::from_utf8(&body)?;
256 Err(anyhow!(
257 "Failed to connect to API: {} {}",
258 response.status(),
259 body_str
260 ))
261 }
262}
263
/// Starts a streaming chat completion against Ollama's `/api/chat` endpoint.
///
/// On success, returns a stream yielding one parsed [`ChatResponseDelta`] per
/// line of the newline-delimited JSON response body. A non-success HTTP status
/// is reported as an error containing the status and response body.
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
    low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    // Abort the transfer if throughput stays below 100 bytes/sec for the given duration.
    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    };

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        // Each line of the body is an independent JSON-encoded delta, parsed
        // lazily as the stream is consumed. I/O errors are forwarded as `Err` items.
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        Some(serde_json::from_str(&line).context("Unable to parse chat response"))
                    }
                    Err(e) => Some(Err(e.into())),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}
307
308pub async fn get_models(
309 client: &dyn HttpClient,
310 api_url: &str,
311 low_speed_timeout: Option<Duration>,
312) -> Result<Vec<LocalModelListing>> {
313 let uri = format!("{api_url}/api/tags");
314 let mut request_builder = HttpRequest::builder()
315 .method(Method::GET)
316 .uri(uri)
317 .header("Accept", "application/json");
318
319 if let Some(low_speed_timeout) = low_speed_timeout {
320 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
321 };
322
323 let request = request_builder.body(AsyncBody::default())?;
324
325 let mut response = client.send(request).await?;
326
327 let mut body = String::new();
328 response.body_mut().read_to_string(&mut body).await?;
329
330 if response.status().is_success() {
331 let response: LocalModelsResponse =
332 serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
333
334 Ok(response.models)
335 } else {
336 Err(anyhow!(
337 "Failed to connect to Ollama API: {} {}",
338 response.status(),
339 body,
340 ))
341 }
342}
343
344/// Sends an empty request to Ollama to trigger loading the model
345pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
346 let uri = format!("{api_url}/api/generate");
347 let request = HttpRequest::builder()
348 .method(Method::POST)
349 .uri(uri)
350 .header("Content-Type", "application/json")
351 .body(AsyncBody::from(serde_json::to_string(
352 &serde_json::json!({
353 "model": model,
354 "keep_alive": "15m",
355 }),
356 )?))?;
357
358 let mut response = match client.send(request).await {
359 Ok(response) => response,
360 Err(err) => {
361 // Be ok with a timeout during preload of the model
362 if err.is_timeout() {
363 return Ok(());
364 } else {
365 return Err(err.into());
366 }
367 }
368 };
369
370 if response.status().is_success() {
371 Ok(())
372 } else {
373 let mut body = String::new();
374 response.body_mut().read_to_string(&mut body).await?;
375
376 Err(anyhow!(
377 "Failed to connect to Ollama API: {} {}",
378 response.status(),
379 body,
380 ))
381 }
382}