lmstudio.rs

use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{convert::TryFrom, sync::Arc, time::Duration};

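/// Default base URL of a locally running LM Studio server's REST API
/// (the `/api/v0` endpoints on LM Studio's default port, 1234).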
pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            "tool" => Ok(Self::Tool),
            _ => anyhow::bail!("invalid role '{value}'"),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
            Role::Tool => "tool".to_owned(),
        }
    }
}

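/// A configured model: its identifier, an optional display name, the context
/// window size in tokens, and whether it supports tool calls.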
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub supports_tool_calls: bool,
}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<usize>,
        supports_tool_calls: bool,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name.map(|s| s.to_owned()),
            max_tokens: max_tokens.unwrap_or(2048),
            supports_tool_calls,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> usize {
        self.max_tokens
    }

    pub fn supports_tool_calls(&self) -> bool {
        self.supports_tool_calls
    }
}

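/// Controls whether and how the model may call the tools supplied with a
/// request (the `tool_choice` field of a chat completion request).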
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    Other(ToolDefinition),
}

#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

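/// A single message in a chat completion request, tagged by its `role` when
/// serialized to JSON.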
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        #[serde(default)]
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        tool_call_id: String,
    },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}

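/// Request body for `POST {api_url}/chat/completions`.
///
/// A minimal construction sketch; the model name and field values are
/// illustrative, not defaults required by LM Studio:
///
/// ```ignore
/// let request = ChatCompletionRequest {
///     model: "qwen2.5-7b-instruct".into(),
///     messages: vec![ChatMessage::User { content: "Hello!".into() }],
///     stream: false,
///     max_tokens: Some(256),
///     stop: None,
///     temperature: Some(0.7),
///     tools: Vec::new(),
///     tool_choice: None,
/// };
/// ```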
#[derive(Serialize, Debug)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that indicates whether a
    // `function` payload is present. It sometimes streams in after the
    // `function` field rather than before it.
    pub function: Option<FunctionChunk>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

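/// Capability flags reported by LM Studio for a model, such as `"tool_use"`.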
#[derive(Debug, Default, Clone, Deserialize, PartialEq)]
#[serde(transparent)]
pub struct Capabilities(Vec<String>);

impl Capabilities {
    pub fn supports_tool_calls(&self) -> bool {
        self.0.iter().any(|cap| cap == "tool_use")
    }
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub object: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}

#[derive(Deserialize)]
pub struct ListModelsResponse {
    pub data: Vec<ModelEntry>,
}

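/// A single model record as returned by LM Studio's `GET {api_url}/models`
/// endpoint.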
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct ModelEntry {
    pub id: String,
    pub object: String,
    pub r#type: ModelType,
    pub publisher: String,
    pub arch: Option<String>,
    pub compatibility_type: CompatibilityType,
    pub quantization: Option<String>,
    pub state: ModelState,
    pub max_context_length: Option<u32>,
    pub loaded_context_length: Option<u32>,
    #[serde(default)]
    pub capabilities: Capabilities,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ModelType {
    Llm,
    Embeddings,
    Vlm,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum ModelState {
    Loaded,
    Loading,
    NotLoaded,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum CompatibilityType {
    Gguf,
    Mlx,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

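/// Sends a non-streaming request to `{api_url}/chat/completions` and parses
/// the response body into a [`ChatResponse`].
///
/// A usage sketch, assuming an `HttpClient` implementation (`client` here) is
/// supplied by the surrounding application:
///
/// ```ignore
/// let response = complete(client.as_ref(), LMSTUDIO_API_URL, request).await?;
/// for choice in &response.choices {
///     if let Some(content) = &choice.delta.content {
///         println!("{content}");
///     }
/// }
/// ```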
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatCompletionRequest,
) -> Result<ChatResponse> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let response_message: ChatResponse = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to LM Studio API: {} {}",
            response.status(),
            body_str
        );
    }
}

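/// Sends a streaming request to `{api_url}/chat/completions` and returns a
/// stream of server-sent events, one [`ResponseStreamEvent`] per `data:` line.
///
/// A usage sketch, assuming an `HttpClient` implementation (`client` here) is
/// supplied by the surrounding application and `futures::StreamExt` is in
/// scope:
///
/// ```ignore
/// let mut events = stream_chat_completion(client.as_ref(), LMSTUDIO_API_URL, request).await?;
/// while let Some(event) = events.next().await {
///     for choice in event?.choices {
///         if let Some(chunk) = choice.delta.content {
///             print!("{chunk}");
///         }
///     }
/// }
/// ```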
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatCompletionRequest,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // Each event payload line is prefixed with "data: ";
                        // the stream ends with a literal "[DONE]" marker.
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            None
                        } else {
                            let result = serde_json::from_str(line)
                                .context("Unable to parse chat completions response");
                            if let Err(ref e) = result {
                                eprintln!("Error parsing line: {e}\nLine content: '{line}'");
                            }
                            Some(result)
                        }
                    }
                    Err(e) => {
                        eprintln!("Error reading line: {e}");
                        Some(Err(e.into()))
                    }
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to LM Studio API: {} {}",
            response.status(),
            body,
        );
    }
}

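/// Fetches the list of models known to the LM Studio server from
/// `{api_url}/models`. The timeout parameter is currently unused.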
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<ModelEntry>> {
    let uri = format!("{api_url}/models");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to LM Studio API: {} {}",
        response.status(),
        body,
    );
    let response: ListModelsResponse =
        serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
    Ok(response.data)
}

/// Sends an empty request to LM Studio to trigger loading the model
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
    let uri = format!("{api_url}/completions");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(serde_json::to_string(
            &serde_json::json!({
                "model": model,
                "messages": [],
                "stream": false,
                "max_tokens": 0,
            }),
        )?))?;

    let mut response = client.send(request).await?;

    if response.status().is_success() {
        Ok(())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to LM Studio API: {} {}",
            response.status(),
            body,
        );
    }
}