use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use http_client::{http, AsyncBody, HttpClient, Method, Request as HttpRequest};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{value::RawValue, Value};
use std::{convert::TryFrom, sync::Arc, time::Duration};

pub const OLLAMA_API_URL: &str = "http://localhost:11434";

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            _ => Err(anyhow!("invalid role '{value}'")),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
        }
    }
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum KeepAlive {
    /// Keep model alive for N seconds
    Seconds(isize),
    /// Keep model alive for a fixed duration. Accepts durations like "5m", "10m", "1h", "1d", etc.
    Duration(String),
}

impl KeepAlive {
    /// Keep model alive until a new model is loaded or until Ollama shuts down
    fn indefinite() -> Self {
        Self::Seconds(-1)
    }
}

impl Default for KeepAlive {
    fn default() -> Self {
        Self::indefinite()
    }
}
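
// Illustrative sketch (not in the original crate): because `KeepAlive` is
// `#[serde(untagged)]`, it serializes to either a bare integer or a plain
// string, which is the shape Ollama's `keep_alive` field expects.
#[cfg(test)]
mod keep_alive_tests {
    use super::*;

    #[test]
    fn serializes_untagged() {
        // -1 keeps the model loaded indefinitely.
        assert_eq!(serde_json::to_string(&KeepAlive::indefinite()).unwrap(), "-1");
        // Duration strings pass through verbatim.
        assert_eq!(
            serde_json::to_string(&KeepAlive::Duration("5m".to_string())).unwrap(),
            "\"5m\""
        );
    }
}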

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub keep_alive: Option<KeepAlive>,
}

fn get_max_tokens(name: &str) -> usize {
    /// Default context length for unknown models.
    const DEFAULT_TOKENS: usize = 2048;
    /// Magic number. Lets many Ollama models work with ~16GB of RAM.
    const MAXIMUM_TOKENS: usize = 16384;

    match name.split(':').next().unwrap() {
        "phi" | "tinyllama" | "granite-code" => 2048,
        "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
        "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
        "codellama" | "starcoder2" => 16384,
        "mistral" | "codestral" | "mixtral" | "llava" | "qwen2" | "dolphin-mixtral" => 32768,
        "llama3.1" | "phi3" | "phi3.5" | "command-r" | "deepseek-coder-v2" | "yi-coder" => 128000,
        _ => DEFAULT_TOKENS,
    }
    .clamp(1, MAXIMUM_TOKENS)
}
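
// Illustrative sketch (not in the original crate): context lengths above
// MAXIMUM_TOKENS are clamped, so even models with very large native context
// windows are capped at 16384 tokens here.
#[cfg(test)]
mod max_tokens_tests {
    use super::*;

    #[test]
    fn clamps_large_contexts() {
        // llama3.1 nominally supports 128k tokens, but the table is clamped.
        assert_eq!(get_max_tokens("llama3.1:8b"), 16384);
        // Unknown models fall back to the conservative default.
        assert_eq!(get_max_tokens("some-unknown-model"), 2048);
    }
}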

impl Model {
    pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name
                .map(ToString::to_string)
                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
            keep_alive: Some(KeepAlive::indefinite()),
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> usize {
        self.max_tokens
    }
}
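
// Illustrative sketch (not in the original crate): `Model::new` strips the
// implicit `:latest` tag when deriving a display name, and falls back to the
// context-length table when no explicit max_tokens is given.
#[cfg(test)]
mod model_tests {
    use super::*;

    #[test]
    fn derives_display_name_and_max_tokens() {
        let model = Model::new("llama3:latest", None, None);
        assert_eq!(model.display_name(), "llama3");
        assert_eq!(model.max_token_count(), 8192);

        // An explicit display name takes precedence.
        let model = Model::new("llama3:8b", Some("Llama 3 8B"), None);
        assert_eq!(model.display_name(), "Llama 3 8B");
    }
}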

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: String,
        tool_calls: Option<Vec<OllamaToolCall>>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
    Function(OllamaFunctionCall),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    pub arguments: Box<RawValue>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}

#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
}

impl ChatRequest {
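    /// Attaches tools to the request. Note: streaming is disabled here
    /// because, at the time of writing, Ollama's chat API does not stream
    /// responses for requests that include tools.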
    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
        self.stream = false;
        self.tools = tools;
        self
    }
}
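
// Illustrative sketch (not in the original crate): `ChatMessage` is internally
// tagged on `role`, and `with_tools` flips `stream` off, so a serialized
// request matches the wire shape expected by Ollama's /api/chat endpoint.
#[cfg(test)]
mod chat_request_tests {
    use super::*;

    #[test]
    fn serializes_chat_request() {
        let request = ChatRequest {
            model: "llama3".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello!".to_string(),
            }],
            stream: true,
            keep_alive: KeepAlive::default(),
            options: None,
            tools: Vec::new(),
        }
        .with_tools(Vec::new());

        // `with_tools` always disables streaming.
        assert!(!request.stream);

        let json: Value = serde_json::to_value(&request).unwrap();
        assert_eq!(json["messages"][0]["role"], "user");
        assert_eq!(json["keep_alive"], -1);
    }
}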

// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    pub num_ctx: Option<usize>,
    pub num_predict: Option<isize>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    #[allow(unused)]
    pub model: String,
    #[allow(unused)]
    pub created_at: String,
    pub message: ChatMessage,
    #[allow(unused)]
    pub done_reason: Option<String>,
    #[allow(unused)]
    pub done: bool,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}

#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}
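
// Illustrative sketch (not in the original crate): an abbreviated /api/tags
// payload, shaped after the struct fields above, deserializes into
// `LocalModelsResponse`. The concrete values here are made up.
#[cfg(test)]
mod tag_listing_tests {
    use super::*;

    #[test]
    fn parses_model_listing() {
        let body = r#"{
            "models": [{
                "name": "llama3:latest",
                "modified_at": "2024-05-01T00:00:00Z",
                "size": 4661224676,
                "digest": "sha256:abc123",
                "details": {
                    "format": "gguf",
                    "family": "llama",
                    "families": ["llama"],
                    "parameter_size": "8B",
                    "quantization_level": "Q4_0"
                }
            }]
        }"#;

        let response: LocalModelsResponse = serde_json::from_str(body).unwrap();
        assert_eq!(response.models.len(), 1);
        assert_eq!(response.models[0].name, "llama3:latest");
    }
}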

pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body_str
        ))
    }
}

pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
    _: Option<Duration>,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        Some(serde_json::from_str(&line).context("Unable to parse chat response"))
                    }
                    Err(e) => Some(Err(e.into())),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}
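
// Illustrative sketch (not in the original crate): the streaming endpoint
// returns newline-delimited JSON, one `ChatResponseDelta` per line. This test
// feeds a canned line through the same BufReader + lines() pipeline used
// above. It assumes the `futures` crate's default `executor` feature.
#[cfg(test)]
mod stream_tests {
    use super::*;
    use futures::io::Cursor;

    #[test]
    fn parses_newline_delimited_deltas() {
        let body = concat!(
            r#"{"model":"llama3","created_at":"2024-05-01T00:00:00Z","#,
            r#""message":{"role":"assistant","content":"Hi","tool_calls":null},"#,
            r#""done_reason":"stop","done":true}"#,
            "\n",
        );

        futures::executor::block_on(async {
            let reader = BufReader::new(Cursor::new(body.as_bytes()));
            let mut lines = reader.lines();
            let line = lines.next().await.unwrap().unwrap();
            let delta: ChatResponseDelta = serde_json::from_str(&line).unwrap();
            assert!(delta.done);
            assert_eq!(delta.done_reason.as_deref(), Some("stop"));
        });
    }
}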

pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
    let uri = format!("{api_url}/api/tags");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    if response.status().is_success() {
        let response: LocalModelsResponse =
            serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;

        Ok(response.models)
    } else {
        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}

/// Sends an empty request to Ollama to trigger loading the model
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
    let uri = format!("{api_url}/api/generate");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(serde_json::to_string(
            &serde_json::json!({
                "model": model,
                "keep_alive": "15m",
            }),
        )?))?;

    let mut response = client.send(request).await?;

    if response.status().is_success() {
        Ok(())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ))
    }
}