lmstudio.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::{convert::TryFrom, time::Duration};
  7
/// Base URL of the LM Studio `/api/v0` REST endpoint on `localhost:1234`.
///
/// NOTE(review): presumably used as the default `api_url` for the request
/// helpers in this file — confirm against callers.
pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";
  9
/// Author role attached to a chat message.
///
/// Serialized by serde as the lowercase variant name (`"user"`,
/// `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 18
 19impl TryFrom<String> for Role {
 20    type Error = anyhow::Error;
 21
 22    fn try_from(value: String) -> Result<Self> {
 23        match value.as_str() {
 24            "user" => Ok(Self::User),
 25            "assistant" => Ok(Self::Assistant),
 26            "system" => Ok(Self::System),
 27            "tool" => Ok(Self::Tool),
 28            _ => anyhow::bail!("invalid role '{value}'"),
 29        }
 30    }
 31}
 32
 33impl From<Role> for String {
 34    fn from(val: Role) -> Self {
 35        match val {
 36            Role::User => "user".to_owned(),
 37            Role::Assistant => "assistant".to_owned(),
 38            Role::System => "system".to_owned(),
 39            Role::Tool => "tool".to_owned(),
 40        }
 41    }
 42}
 43
/// Description of a model: identifier, optional display label, token limit
/// and capability flags.
///
/// Also derives `schemars::JsonSchema` when the `schemars` feature is enabled.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    /// Model name; also returned by `Model::id`.
    pub name: String,
    /// Optional human-readable label; `Model::display_name` falls back to
    /// `name` when this is `None`.
    pub display_name: Option<String>,
    /// Token limit reported by `Model::max_token_count`.
    pub max_tokens: u64,
    /// Whether the model is flagged as supporting tool calls.
    pub supports_tool_calls: bool,
    /// Whether the model is flagged as supporting image input.
    pub supports_images: bool,
}
 53
 54impl Model {
 55    pub fn new(
 56        name: &str,
 57        display_name: Option<&str>,
 58        max_tokens: Option<u64>,
 59        supports_tool_calls: bool,
 60        supports_images: bool,
 61    ) -> Self {
 62        Self {
 63            name: name.to_owned(),
 64            display_name: display_name.map(|s| s.to_owned()),
 65            max_tokens: max_tokens.unwrap_or(2048),
 66            supports_tool_calls,
 67            supports_images,
 68        }
 69    }
 70
 71    pub fn id(&self) -> &str {
 72        &self.name
 73    }
 74
 75    pub fn display_name(&self) -> &str {
 76        self.display_name.as_ref().unwrap_or(&self.name)
 77    }
 78
 79    pub fn max_token_count(&self) -> u64 {
 80        self.max_tokens
 81    }
 82
 83    pub fn supports_tool_calls(&self) -> bool {
 84        self.supports_tool_calls
 85    }
 86}
 87
 88#[derive(Debug, Serialize, Deserialize)]
 89#[serde(untagged)]
 90pub enum ToolChoice {
 91    Auto,
 92    Required,
 93    None,
 94    Other(ToolDefinition),
 95}
 96
/// A tool offered to the model.
///
/// Internally tagged by a `type` field in snake_case, so the single variant
/// is represented as `{"type": "function", "function": {...}}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    /// A function tool described by a [`FunctionDefinition`].
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
103
/// Name, optional description and optional parameters of a function tool.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function name.
    pub name: String,
    /// Optional free-text description of the function.
    pub description: Option<String>,
    /// Optional arbitrary JSON value describing the parameters.
    /// NOTE(review): presumably a JSON Schema object — confirm against callers.
    pub parameters: Option<Value>,
}
110
/// A single chat message.
///
/// Internally tagged by a lowercase `role` field (`"assistant"`, `"user"`,
/// `"system"`, `"tool"`); the remaining fields depend on the role.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    /// Message authored by the assistant.
    Assistant {
        /// Optional content; deserializes to `None` when the field is absent.
        #[serde(default)]
        content: Option<MessageContent>,
        /// Tool calls; defaults to empty when absent and is omitted from the
        /// serialized output when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    /// Message authored by the user.
    User {
        content: MessageContent,
    },
    /// System message.
    System {
        content: MessageContent,
    },
    /// Tool message carrying a `tool_call_id` alongside its content.
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}
131
/// Content of a chat message.
///
/// Untagged: `Plain` is represented as a bare JSON string and `Multipart` as
/// a JSON array of [`MessagePart`] objects.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    /// A single text string.
    Plain(String),
    /// An ordered list of parts.
    Multipart(Vec<MessagePart>),
}
138
139impl MessageContent {
140    pub fn empty() -> Self {
141        MessageContent::Multipart(vec![])
142    }
143
144    pub fn push_part(&mut self, part: MessagePart) {
145        match self {
146            MessageContent::Plain(text) => {
147                *self =
148                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
149            }
150            MessageContent::Multipart(parts) if parts.is_empty() => match part {
151                MessagePart::Text { text } => *self = MessageContent::Plain(text),
152                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
153            },
154            MessageContent::Multipart(parts) => parts.push(part),
155        }
156    }
157}
158
159impl From<Vec<MessagePart>> for MessageContent {
160    fn from(mut parts: Vec<MessagePart>) -> Self {
161        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
162            MessageContent::Plain(std::mem::take(text))
163        } else {
164            MessageContent::Multipart(parts)
165        }
166    }
167}
168
/// One element of multipart message content.
///
/// Internally tagged by a `type` field: `"text"` for `Text` and `"image_url"`
/// for `Image`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
    /// A text fragment.
    Text {
        text: String,
    },
    /// An image reference; explicitly renamed so its tag is `"image_url"`.
    #[serde(rename = "image_url")]
    Image {
        image_url: ImageUrl,
    },
}
180
/// Image reference carried by [`MessagePart::Image`].
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    /// Image location as a string (the tests in this file use a base64
    /// `data:` URL).
    pub url: String,
    /// Optional detail setting; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
187
/// A tool call attached to an assistant message.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Call payload; flattened so its fields sit in the same JSON object as `id`.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
194
/// Payload of a [`ToolCall`], internally tagged by a lowercase `type` field
/// (`"function"` for the only variant).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    /// A function invocation.
    Function { function: FunctionContent },
}
200
/// Name and arguments of a function invocation.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    /// Name of the function being called.
    pub name: String,
    /// Arguments as a string.
    /// NOTE(review): presumably JSON-encoded text — confirm against callers.
    pub arguments: String,
}
206
/// Request body serialized to JSON and POSTed to `{api_url}/chat/completions`
/// by `complete` and `stream_chat_completion`.
///
/// Optional fields are left out of the JSON when `None`, and `tools` is left
/// out when empty.
#[derive(Serialize, Debug)]
pub struct ChatCompletionRequest {
    /// Name of the model to use.
    pub model: String,
    /// Messages sent with the request.
    pub messages: Vec<ChatMessage>,
    /// Value of the `stream` flag in the request body.
    pub stream: bool,
    /// Optional `max_tokens` value.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i32>,
    /// Optional list of stop strings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Optional sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Tool definitions offered to the model.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Optional tool-choice setting.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
223
/// Response body that `complete` deserializes from a successful
/// `/chat/completions` call.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub id: String,
    pub object: String,
    /// NOTE(review): presumably a Unix timestamp — confirm.
    pub created: u64,
    pub model: String,
    /// Choices returned for the request.
    pub choices: Vec<ChoiceDelta>,
}
232
/// One entry of a response's `choices` list.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    /// Index of this choice.
    pub index: u32,
    /// Message delta carried by this choice.
    pub delta: ResponseMessageDelta,
    /// Optional finish reason string.
    pub finish_reason: Option<String>,
}
239
/// A partial tool call as it appears in a [`ResponseMessageDelta`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Index of the tool call this chunk belongs to.
    pub index: usize,
    /// Optional tool call identifier.
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
250
/// Function portion of a [`ToolCallChunk`]; both fields are optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    /// Function name, when present in this chunk.
    pub name: Option<String>,
    /// Arguments text, when present in this chunk.
    pub arguments: Option<String>,
}
256
/// Token counts optionally attached to a [`ResponseStreamEvent`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
263
/// List of capability names for a model.
///
/// `#[serde(transparent)]` makes this deserialize directly from a JSON array
/// of strings; `Default` yields an empty list.
#[derive(Debug, Default, Clone, Deserialize, PartialEq)]
#[serde(transparent)]
pub struct Capabilities(Vec<String>);
267
268impl Capabilities {
269    pub fn supports_tool_calls(&self) -> bool {
270        self.0.iter().any(|cap| cap == "tool_use")
271    }
272
273    pub fn supports_images(&self) -> bool {
274        self.0.iter().any(|cap| cap == "vision")
275    }
276}
277
/// Error payload carried by [`ResponseStreamResult::Err`].
#[derive(Serialize, Deserialize, Debug)]
pub struct LmStudioError {
    /// Error message; `stream_chat_completion` turns it into the error text
    /// of the stream item.
    pub message: String,
}
282
/// Parsed form of one streamed `data:` payload.
///
/// Untagged, so serde tries the variants in order: a payload that parses as a
/// [`ResponseStreamEvent`] becomes `Ok`, otherwise an object with an `error`
/// field becomes `Err`.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    /// A regular stream event.
    Ok(ResponseStreamEvent),
    /// An error object of the form `{"error": {"message": ...}}`.
    Err { error: LmStudioError },
}
289
/// One event yielded by the stream returned from `stream_chat_completion`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    /// NOTE(review): presumably a Unix timestamp; note this is `u32` whereas
    /// `ChatResponse::created` is `u64` — confirm which width is intended.
    pub created: u32,
    pub model: String,
    pub object: String,
    /// Choices carried by this event.
    pub choices: Vec<ChoiceDelta>,
    /// Optional token usage counts.
    pub usage: Option<Usage>,
}
298
/// Body of the `{api_url}/models` response as parsed by `get_models`.
#[derive(Deserialize)]
pub struct ListModelsResponse {
    /// Model entries; this is what `get_models` returns.
    pub data: Vec<ModelEntry>,
}
303
/// One model entry from the `data` list of a [`ListModelsResponse`].
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct ModelEntry {
    pub id: String,
    pub object: String,
    /// Kind of model; read from the JSON field `type`.
    pub r#type: ModelType,
    pub publisher: String,
    pub arch: Option<String>,
    pub compatibility_type: CompatibilityType,
    pub quantization: Option<String>,
    /// Load state of the model.
    pub state: ModelState,
    pub max_context_length: Option<u64>,
    pub loaded_context_length: Option<u64>,
    /// Capability list; defaults to empty when the field is absent.
    #[serde(default)]
    pub capabilities: Capabilities,
}
319
/// Kind of a listed model, (de)serialized in lowercase: `"llm"`,
/// `"embeddings"`, `"vlm"`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ModelType {
    Llm,
    Embeddings,
    Vlm,
}
327
/// Load state of a listed model, (de)serialized in kebab-case: `"loaded"`,
/// `"loading"`, `"not-loaded"`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum ModelState {
    Loaded,
    Loading,
    NotLoaded,
}
335
/// Compatibility type of a listed model, (de)serialized in lowercase:
/// `"gguf"`, `"mlx"`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum CompatibilityType {
    Gguf,
    Mlx,
}
342
/// Message delta carried by a [`ChoiceDelta`]; every field is optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    /// Role of the message author, when present.
    pub role: Option<Role>,
    /// Text content, when present.
    pub content: Option<String>,
    /// Reasoning text; `None` when absent and omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    /// Tool call chunks; `None` when absent and omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
352
353pub async fn complete(
354    client: &dyn HttpClient,
355    api_url: &str,
356    request: ChatCompletionRequest,
357) -> Result<ChatResponse> {
358    let uri = format!("{api_url}/chat/completions");
359    let request_builder = HttpRequest::builder()
360        .method(Method::POST)
361        .uri(uri)
362        .header("Content-Type", "application/json");
363
364    let serialized_request = serde_json::to_string(&request)?;
365    let request = request_builder.body(AsyncBody::from(serialized_request))?;
366
367    let mut response = client.send(request).await?;
368    if response.status().is_success() {
369        let mut body = Vec::new();
370        response.body_mut().read_to_end(&mut body).await?;
371        let response_message: ChatResponse = serde_json::from_slice(&body)?;
372        Ok(response_message)
373    } else {
374        let mut body = Vec::new();
375        response.body_mut().read_to_end(&mut body).await?;
376        let body_str = std::str::from_utf8(&body)?;
377        anyhow::bail!(
378            "Failed to connect to API: {} {}",
379            response.status(),
380            body_str
381        );
382    }
383}
384
385pub async fn stream_chat_completion(
386    client: &dyn HttpClient,
387    api_url: &str,
388    request: ChatCompletionRequest,
389) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
390    let uri = format!("{api_url}/chat/completions");
391    let request_builder = http::Request::builder()
392        .method(Method::POST)
393        .uri(uri)
394        .header("Content-Type", "application/json");
395
396    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
397    let mut response = client.send(request).await?;
398    if response.status().is_success() {
399        let reader = BufReader::new(response.into_body());
400        Ok(reader
401            .lines()
402            .filter_map(|line| async move {
403                match line {
404                    Ok(line) => {
405                        let line = line.strip_prefix("data: ")?;
406                        if line == "[DONE]" {
407                            None
408                        } else {
409                            match serde_json::from_str(line) {
410                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
411                                Ok(ResponseStreamResult::Err { error, .. }) => {
412                                    Some(Err(anyhow!(error.message)))
413                                }
414                                Err(error) => Some(Err(anyhow!(error))),
415                            }
416                        }
417                    }
418                    Err(error) => Some(Err(anyhow!(error))),
419                }
420            })
421            .boxed())
422    } else {
423        let mut body = String::new();
424        response.body_mut().read_to_string(&mut body).await?;
425        anyhow::bail!(
426            "Failed to connect to LM Studio API: {} {}",
427            response.status(),
428            body,
429        );
430    }
431}
432
433pub async fn get_models(
434    client: &dyn HttpClient,
435    api_url: &str,
436    _: Option<Duration>,
437) -> Result<Vec<ModelEntry>> {
438    let uri = format!("{api_url}/models");
439    let request_builder = HttpRequest::builder()
440        .method(Method::GET)
441        .uri(uri)
442        .header("Accept", "application/json");
443
444    let request = request_builder.body(AsyncBody::default())?;
445
446    let mut response = client.send(request).await?;
447
448    let mut body = String::new();
449    response.body_mut().read_to_string(&mut body).await?;
450
451    anyhow::ensure!(
452        response.status().is_success(),
453        "Failed to connect to LM Studio API: {} {}",
454        response.status(),
455        body,
456    );
457    let response: ListModelsResponse =
458        serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
459    Ok(response.data)
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    /// An image part serializes with `"type":"image_url"` and leaves out the
    /// `detail` field when it is `None`.
    #[test]
    fn test_image_message_part_serialization() {
        let url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==".to_string();
        let part = MessagePart::Image {
            image_url: ImageUrl { url, detail: None },
        };

        let serialized = serde_json::to_string(&part).unwrap();
        println!("Serialized image part: {}", serialized);

        // Verify the structure matches what LM Studio expects
        let expected = r#"{"type":"image_url","image_url":{"url":"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="}}"#;
        assert_eq!(serialized, expected);
    }

    /// A text part serializes with `"type":"text"` followed by its text.
    #[test]
    fn test_text_message_part_serialization() {
        let text = "Hello, world!".to_string();
        let part = MessagePart::Text { text };

        let serialized = serde_json::to_string(&part).unwrap();
        println!("Serialized text part: {}", serialized);

        let expected = r#"{"type":"text","text":"Hello, world!"}"#;
        assert_eq!(serialized, expected);
    }
}