use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{convert::TryFrom, time::Duration};

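/// Default base URL for a locally running LM Studio server's REST API.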
pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";

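/// Role of a chat message author, serialized in lowercase ("user", "assistant", "system", "tool").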
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            "tool" => Ok(Self::Tool),
            _ => anyhow::bail!("invalid role '{value}'"),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
            Role::Tool => "tool".to_owned(),
        }
    }
}

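/// A named model with its token limit and capability flags; `Model::new` defaults `max_tokens` to 2048.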
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub supports_tool_calls: bool,
    pub supports_images: bool,
}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tool_calls: bool,
        supports_images: bool,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name.map(|s| s.to_owned()),
            max_tokens: max_tokens.unwrap_or(2048),
            supports_tool_calls,
            supports_images,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }

    pub fn supports_tool_calls(&self) -> bool {
        self.supports_tool_calls
    }
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    Other(ToolDefinition),
}

#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

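/// A chat message, internally tagged by `role`, so each variant serializes with a `"role"` field alongside its content.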
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        #[serde(default)]
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}

impl MessageContent {
    pub fn empty() -> Self {
        MessageContent::Multipart(vec![])
    }

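    /// Appends a part, collapsing an empty `Multipart` plus a single text part into `Plain`
    /// and promoting `Plain` back to `Multipart` once a second part arrives.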
    pub fn push_part(&mut self, part: MessagePart) {
        match self {
            MessageContent::Plain(text) => {
                *self =
                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
            }
            MessageContent::Multipart(parts) if parts.is_empty() => match part {
                MessagePart::Text { text } => *self = MessageContent::Plain(text),
                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
            },
            MessageContent::Multipart(parts) => parts.push(part),
        }
    }
}

impl From<Vec<MessagePart>> for MessageContent {
    fn from(mut parts: Vec<MessagePart>) -> Self {
        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
            MessageContent::Plain(std::mem::take(text))
        } else {
            MessageContent::Multipart(parts)
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
    Text {
        text: String,
    },
    #[serde(rename = "image_url")]
    Image {
        image_url: ImageUrl,
    },
}

#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}

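/// Request body for `POST {api_url}/chat/completions`; `stream` selects between streaming and non-streaming responses.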
#[derive(Serialize, Debug)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that indicates whether a
    // `function` payload is present. Some chunks stream in the `function`
    // before the `type` arrives, so it is not modeled here.
    pub function: Option<FunctionChunk>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}

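/// Capability strings reported by LM Studio for a model, such as `"tool_use"` and `"vision"`.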
#[derive(Debug, Default, Clone, Deserialize, PartialEq)]
#[serde(transparent)]
pub struct Capabilities(Vec<String>);

impl Capabilities {
    pub fn supports_tool_calls(&self) -> bool {
        self.0.iter().any(|cap| cap == "tool_use")
    }

    pub fn supports_images(&self) -> bool {
        self.0.iter().any(|cap| cap == "vision")
    }
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub object: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}

#[derive(Deserialize)]
pub struct ListModelsResponse {
    pub data: Vec<ModelEntry>,
}

#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct ModelEntry {
    pub id: String,
    pub object: String,
    pub r#type: ModelType,
    pub publisher: String,
    pub arch: Option<String>,
    pub compatibility_type: CompatibilityType,
    pub quantization: Option<String>,
    pub state: ModelState,
    pub max_context_length: Option<u64>,
    pub loaded_context_length: Option<u64>,
    #[serde(default)]
    pub capabilities: Capabilities,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ModelType {
    Llm,
    Embeddings,
    Vlm,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum ModelState {
    Loaded,
    Loading,
    NotLoaded,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum CompatibilityType {
    Gguf,
    Mlx,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

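/// Sends a non-streaming chat completion request to `{api_url}/chat/completions` and parses the response body.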
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatCompletionRequest,
) -> Result<ChatResponse> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let response_message: ChatResponse = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to LM Studio API: {} {}",
            response.status(),
            body_str
        );
    }
}

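/// Sends a streaming chat completion request and returns the stream of parsed server-sent events.
/// Lines are expected to be prefixed with `data: `; unprefixed lines and the trailing `data: [DONE]` sentinel are filtered out.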
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatCompletionRequest,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            None
                        } else {
                            let result = serde_json::from_str(line)
                                .context("Unable to parse chat completions response");
                            if let Err(ref e) = result {
                                eprintln!("Error parsing line: {e}\nLine content: '{line}'");
                            }
                            Some(result)
                        }
                    }
                    Err(e) => {
                        eprintln!("Error reading line: {e}");
                        Some(Err(e.into()))
                    }
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to LM Studio API: {} {}",
            response.status(),
            body,
        );
    }
}

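/// Lists the models available from `GET {api_url}/models`; the `Option<Duration>` argument is currently unused.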
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<ModelEntry>> {
    let uri = format!("{api_url}/models");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to LM Studio API: {} {}",
        response.status(),
        body,
    );
    let response: ListModelsResponse =
        serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
    Ok(response.data)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_image_message_part_serialization() {
        let image_part = MessagePart::Image {
            image_url: ImageUrl {
                url: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==".to_string(),
                detail: None,
            },
        };

        let json = serde_json::to_string(&image_part).unwrap();
        println!("Serialized image part: {}", json);

        // Verify the structure matches what LM Studio expects
        let expected_structure = r#"{"type":"image_url","image_url":{"url":"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="}}"#;
        assert_eq!(json, expected_structure);
    }

    #[test]
    fn test_text_message_part_serialization() {
        let text_part = MessagePart::Text {
            text: "Hello, world!".to_string(),
        };

        let json = serde_json::to_string(&text_part).unwrap();
        println!("Serialized text part: {}", json);

        let expected_structure = r#"{"type":"text","text":"Hello, world!"}"#;
        assert_eq!(json, expected_structure);
    }
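
    // Additional sketch tests for the conversion and serde behavior defined above;
    // they rely only on the derives and impls in this file.
    #[test]
    fn test_role_string_round_trip() {
        // Role survives a round trip through its string form.
        for role in [Role::User, Role::Assistant, Role::System, Role::Tool] {
            let string: String = role.into();
            assert_eq!(Role::try_from(string).unwrap(), role);
        }
    }

    #[test]
    fn test_message_content_push_part_promotion() {
        // Pushing a single text part into an empty multipart collapses it to `Plain`.
        let mut content = MessageContent::empty();
        content.push_part(MessagePart::Text {
            text: "hello".to_string(),
        });
        assert_eq!(content, MessageContent::Plain("hello".to_string()));

        // Pushing a second part promotes `Plain` back to `Multipart`.
        content.push_part(MessagePart::Image {
            image_url: ImageUrl {
                url: "data:image/png;base64,".to_string(),
                detail: None,
            },
        });
        match content {
            MessageContent::Multipart(parts) => assert_eq!(parts.len(), 2),
            other => panic!("expected multipart content, got {other:?}"),
        }
    }

    #[test]
    fn test_capabilities_deserialization() {
        // `Capabilities` is a transparent wrapper over the capability string array.
        let capabilities: Capabilities =
            serde_json::from_str(r#"["tool_use", "vision"]"#).unwrap();
        assert!(capabilities.supports_tool_calls());
        assert!(capabilities.supports_images());
    }

    #[test]
    fn test_chat_message_role_tagging() {
        // `ChatMessage` is internally tagged, so the role appears as a `role` field.
        let message = ChatMessage::User {
            content: MessageContent::Plain("Hello".to_string()),
        };
        let value = serde_json::to_value(&message).unwrap();
        assert_eq!(value["role"], "user");
        assert_eq!(value["content"], "Hello");
    }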
}