lmstudio.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::{convert::TryFrom, time::Duration};
  7
  8pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";
  9
 10#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 11#[serde(rename_all = "lowercase")]
 12pub enum Role {
 13    User,
 14    Assistant,
 15    System,
 16    Tool,
 17}
 18
 19impl TryFrom<String> for Role {
 20    type Error = anyhow::Error;
 21
 22    fn try_from(value: String) -> Result<Self> {
 23        match value.as_str() {
 24            "user" => Ok(Self::User),
 25            "assistant" => Ok(Self::Assistant),
 26            "system" => Ok(Self::System),
 27            "tool" => Ok(Self::Tool),
 28            _ => anyhow::bail!("invalid role '{value}'"),
 29        }
 30    }
 31}
 32
 33impl From<Role> for String {
 34    fn from(val: Role) -> Self {
 35        match val {
 36            Role::User => "user".to_owned(),
 37            Role::Assistant => "assistant".to_owned(),
 38            Role::System => "system".to_owned(),
 39            Role::Tool => "tool".to_owned(),
 40        }
 41    }
 42}
 43
 44#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 45#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 46pub struct Model {
 47    pub name: String,
 48    pub display_name: Option<String>,
 49    pub max_tokens: u64,
 50    pub supports_tool_calls: bool,
 51    pub supports_images: bool,
 52}
 53
 54impl Model {
 55    pub fn new(
 56        name: &str,
 57        display_name: Option<&str>,
 58        max_tokens: Option<u64>,
 59        supports_tool_calls: bool,
 60        supports_images: bool,
 61    ) -> Self {
 62        Self {
 63            name: name.to_owned(),
 64            display_name: display_name.map(|s| s.to_owned()),
 65            max_tokens: max_tokens.unwrap_or(2048),
 66            supports_tool_calls,
 67            supports_images,
 68        }
 69    }
 70
 71    pub fn id(&self) -> &str {
 72        &self.name
 73    }
 74
 75    pub fn display_name(&self) -> &str {
 76        self.display_name.as_ref().unwrap_or(&self.name)
 77    }
 78
 79    pub const fn max_token_count(&self) -> u64 {
 80        self.max_tokens
 81    }
 82
 83    pub const fn supports_tool_calls(&self) -> bool {
 84        self.supports_tool_calls
 85    }
 86}
 87
 88#[derive(Debug, Serialize, Deserialize)]
 89#[serde(rename_all = "lowercase")]
 90pub enum ToolChoice {
 91    Auto,
 92    Required,
 93    None,
 94    #[serde(untagged)]
 95    Other(ToolDefinition),
 96}
 97
 98#[derive(Clone, Deserialize, Serialize, Debug)]
 99#[serde(tag = "type", rename_all = "snake_case")]
100pub enum ToolDefinition {
101    #[allow(dead_code)]
102    Function { function: FunctionDefinition },
103}
104
105#[derive(Clone, Debug, Serialize, Deserialize)]
106pub struct FunctionDefinition {
107    pub name: String,
108    pub description: Option<String>,
109    pub parameters: Option<Value>,
110}
111
112#[derive(Serialize, Deserialize, Debug)]
113#[serde(tag = "role", rename_all = "lowercase")]
114pub enum ChatMessage {
115    Assistant {
116        #[serde(default)]
117        content: Option<MessageContent>,
118        #[serde(default, skip_serializing_if = "Vec::is_empty")]
119        tool_calls: Vec<ToolCall>,
120    },
121    User {
122        content: MessageContent,
123    },
124    System {
125        content: MessageContent,
126    },
127    Tool {
128        content: MessageContent,
129        tool_call_id: String,
130    },
131}
132
133#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
134#[serde(untagged)]
135pub enum MessageContent {
136    Plain(String),
137    Multipart(Vec<MessagePart>),
138}
139
140impl MessageContent {
141    pub const fn empty() -> Self {
142        MessageContent::Multipart(vec![])
143    }
144
145    pub fn push_part(&mut self, part: MessagePart) {
146        match self {
147            MessageContent::Plain(text) => {
148                *self =
149                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
150            }
151            MessageContent::Multipart(parts) if parts.is_empty() => match part {
152                MessagePart::Text { text } => *self = MessageContent::Plain(text),
153                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
154            },
155            MessageContent::Multipart(parts) => parts.push(part),
156        }
157    }
158}
159
160impl From<Vec<MessagePart>> for MessageContent {
161    fn from(mut parts: Vec<MessagePart>) -> Self {
162        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
163            MessageContent::Plain(std::mem::take(text))
164        } else {
165            MessageContent::Multipart(parts)
166        }
167    }
168}
169
170#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
171#[serde(tag = "type", rename_all = "snake_case")]
172pub enum MessagePart {
173    Text {
174        text: String,
175    },
176    #[serde(rename = "image_url")]
177    Image {
178        image_url: ImageUrl,
179    },
180}
181
182#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
183pub struct ImageUrl {
184    pub url: String,
185    #[serde(skip_serializing_if = "Option::is_none")]
186    pub detail: Option<String>,
187}
188
189#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
190pub struct ToolCall {
191    pub id: String,
192    #[serde(flatten)]
193    pub content: ToolCallContent,
194}
195
196#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
197#[serde(tag = "type", rename_all = "lowercase")]
198pub enum ToolCallContent {
199    Function { function: FunctionContent },
200}
201
202#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
203pub struct FunctionContent {
204    pub name: String,
205    pub arguments: String,
206}
207
208#[derive(Serialize, Debug)]
209pub struct ChatCompletionRequest {
210    pub model: String,
211    pub messages: Vec<ChatMessage>,
212    pub stream: bool,
213    #[serde(skip_serializing_if = "Option::is_none")]
214    pub max_tokens: Option<i32>,
215    #[serde(skip_serializing_if = "Option::is_none")]
216    pub stop: Option<Vec<String>>,
217    #[serde(skip_serializing_if = "Option::is_none")]
218    pub temperature: Option<f32>,
219    #[serde(skip_serializing_if = "Vec::is_empty")]
220    pub tools: Vec<ToolDefinition>,
221    #[serde(skip_serializing_if = "Option::is_none")]
222    pub tool_choice: Option<ToolChoice>,
223}
224
225#[derive(Serialize, Deserialize, Debug)]
226pub struct ChatResponse {
227    pub id: String,
228    pub object: String,
229    pub created: u64,
230    pub model: String,
231    pub choices: Vec<ChoiceDelta>,
232}
233
234#[derive(Serialize, Deserialize, Debug)]
235pub struct ChoiceDelta {
236    pub index: u32,
237    pub delta: ResponseMessageDelta,
238    pub finish_reason: Option<String>,
239}
240
241#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
242pub struct ToolCallChunk {
243    pub index: usize,
244    pub id: Option<String>,
245
246    // There is also an optional `type` field that would determine if a
247    // function is there. Sometimes this streams in with the `function` before
248    // it streams in the `type`
249    pub function: Option<FunctionChunk>,
250}
251
252#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
253pub struct FunctionChunk {
254    pub name: Option<String>,
255    pub arguments: Option<String>,
256}
257
258#[derive(Serialize, Deserialize, Debug)]
259pub struct Usage {
260    pub prompt_tokens: u64,
261    pub completion_tokens: u64,
262    pub total_tokens: u64,
263}
264
265#[derive(Debug, Default, Clone, Deserialize, PartialEq)]
266#[serde(transparent)]
267pub struct Capabilities(Vec<String>);
268
269impl Capabilities {
270    pub fn supports_tool_calls(&self) -> bool {
271        self.0.iter().any(|cap| cap == "tool_use")
272    }
273
274    pub fn supports_images(&self) -> bool {
275        self.0.iter().any(|cap| cap == "vision")
276    }
277}
278
279#[derive(Serialize, Deserialize, Debug)]
280pub struct LmStudioError {
281    pub message: String,
282}
283
284#[derive(Serialize, Deserialize, Debug)]
285#[serde(untagged)]
286pub enum ResponseStreamResult {
287    Ok(ResponseStreamEvent),
288    Err { error: LmStudioError },
289}
290
291#[derive(Serialize, Deserialize, Debug)]
292pub struct ResponseStreamEvent {
293    pub created: u32,
294    pub model: String,
295    pub object: String,
296    pub choices: Vec<ChoiceDelta>,
297    pub usage: Option<Usage>,
298}
299
300#[derive(Deserialize)]
301pub struct ListModelsResponse {
302    pub data: Vec<ModelEntry>,
303}
304
305#[derive(Clone, Debug, Deserialize, PartialEq)]
306pub struct ModelEntry {
307    pub id: String,
308    pub object: String,
309    pub r#type: ModelType,
310    pub publisher: String,
311    pub arch: Option<String>,
312    pub compatibility_type: CompatibilityType,
313    pub quantization: Option<String>,
314    pub state: ModelState,
315    pub max_context_length: Option<u64>,
316    pub loaded_context_length: Option<u64>,
317    #[serde(default)]
318    pub capabilities: Capabilities,
319}
320
321#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
322#[serde(rename_all = "lowercase")]
323pub enum ModelType {
324    Llm,
325    Embeddings,
326    Vlm,
327}
328
329#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
330#[serde(rename_all = "kebab-case")]
331pub enum ModelState {
332    Loaded,
333    Loading,
334    NotLoaded,
335}
336
337#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
338#[serde(rename_all = "lowercase")]
339pub enum CompatibilityType {
340    Gguf,
341    Mlx,
342}
343
344#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
345pub struct ResponseMessageDelta {
346    pub role: Option<Role>,
347    pub content: Option<String>,
348    #[serde(default, skip_serializing_if = "Option::is_none")]
349    pub reasoning_content: Option<String>,
350    #[serde(default, skip_serializing_if = "Option::is_none")]
351    pub tool_calls: Option<Vec<ToolCallChunk>>,
352}
353
354pub async fn complete(
355    client: &dyn HttpClient,
356    api_url: &str,
357    request: ChatCompletionRequest,
358) -> Result<ChatResponse> {
359    let uri = format!("{api_url}/chat/completions");
360    let request_builder = HttpRequest::builder()
361        .method(Method::POST)
362        .uri(uri)
363        .header("Content-Type", "application/json");
364
365    let serialized_request = serde_json::to_string(&request)?;
366    let request = request_builder.body(AsyncBody::from(serialized_request))?;
367
368    let mut response = client.send(request).await?;
369    if response.status().is_success() {
370        let mut body = Vec::new();
371        response.body_mut().read_to_end(&mut body).await?;
372        let response_message: ChatResponse = serde_json::from_slice(&body)?;
373        Ok(response_message)
374    } else {
375        let mut body = Vec::new();
376        response.body_mut().read_to_end(&mut body).await?;
377        let body_str = std::str::from_utf8(&body)?;
378        anyhow::bail!(
379            "Failed to connect to API: {} {}",
380            response.status(),
381            body_str
382        );
383    }
384}
385
386pub async fn stream_chat_completion(
387    client: &dyn HttpClient,
388    api_url: &str,
389    request: ChatCompletionRequest,
390) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
391    let uri = format!("{api_url}/chat/completions");
392    let request_builder = http::Request::builder()
393        .method(Method::POST)
394        .uri(uri)
395        .header("Content-Type", "application/json");
396
397    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
398    let mut response = client.send(request).await?;
399    if response.status().is_success() {
400        let reader = BufReader::new(response.into_body());
401        Ok(reader
402            .lines()
403            .filter_map(|line| async move {
404                match line {
405                    Ok(line) => {
406                        let line = line.strip_prefix("data: ")?;
407                        if line == "[DONE]" {
408                            None
409                        } else {
410                            match serde_json::from_str(line) {
411                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
412                                Ok(ResponseStreamResult::Err { error, .. }) => {
413                                    Some(Err(anyhow!(error.message)))
414                                }
415                                Err(error) => Some(Err(anyhow!(error))),
416                            }
417                        }
418                    }
419                    Err(error) => Some(Err(anyhow!(error))),
420                }
421            })
422            .boxed())
423    } else {
424        let mut body = String::new();
425        response.body_mut().read_to_string(&mut body).await?;
426        anyhow::bail!(
427            "Failed to connect to LM Studio API: {} {}",
428            response.status(),
429            body,
430        );
431    }
432}
433
434pub async fn get_models(
435    client: &dyn HttpClient,
436    api_url: &str,
437    _: Option<Duration>,
438) -> Result<Vec<ModelEntry>> {
439    let uri = format!("{api_url}/models");
440    let request_builder = HttpRequest::builder()
441        .method(Method::GET)
442        .uri(uri)
443        .header("Accept", "application/json");
444
445    let request = request_builder.body(AsyncBody::default())?;
446
447    let mut response = client.send(request).await?;
448
449    let mut body = String::new();
450    response.body_mut().read_to_string(&mut body).await?;
451
452    anyhow::ensure!(
453        response.status().is_success(),
454        "Failed to connect to LM Studio API: {} {}",
455        response.status(),
456        body,
457    );
458    let response: ListModelsResponse =
459        serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
460    Ok(response.data)
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_image_message_part_serialization() {
        let image_part = MessagePart::Image {
            image_url: ImageUrl {
                url: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==".to_string(),
                detail: None,
            },
        };

        let json = serde_json::to_string(&image_part).unwrap();

        // Verify the structure matches what LM Studio expects; `detail` is omitted when `None`.
        let expected_structure = r#"{"type":"image_url","image_url":{"url":"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="}}"#;
        assert_eq!(json, expected_structure);
    }

    #[test]
    fn test_text_message_part_serialization() {
        let text_part = MessagePart::Text {
            text: "Hello, world!".to_string(),
        };

        let json = serde_json::to_string(&text_part).unwrap();

        let expected_structure = r#"{"type":"text","text":"Hello, world!"}"#;
        assert_eq!(json, expected_structure);
    }

    #[test]
    fn test_single_text_part_collapses_to_plain() {
        let content = MessageContent::from(vec![MessagePart::Text {
            text: "hi".to_string(),
        }]);
        assert_eq!(content, MessageContent::Plain("hi".to_string()));
        // Untagged: plain content serializes as a bare JSON string.
        assert_eq!(serde_json::to_string(&content).unwrap(), r#""hi""#);
    }

    #[test]
    fn test_push_part_promotes_plain_to_multipart() {
        let mut content = MessageContent::empty();
        content.push_part(MessagePart::Text {
            text: "a".to_string(),
        });
        assert_eq!(content, MessageContent::Plain("a".to_string()));

        let image = ImageUrl {
            url: "data:,".to_string(),
            detail: None,
        };
        content.push_part(MessagePart::Image {
            image_url: image.clone(),
        });
        assert_eq!(
            content,
            MessageContent::Multipart(vec![
                MessagePart::Text {
                    text: "a".to_string()
                },
                MessagePart::Image { image_url: image },
            ])
        );
    }

    #[test]
    fn test_role_string_conversions() {
        assert_eq!(Role::try_from("tool".to_string()).unwrap(), Role::Tool);
        assert!(Role::try_from("Tool".to_string()).is_err());
        assert_eq!(String::from(Role::System), "system");
        assert_eq!(
            serde_json::to_string(&Role::Assistant).unwrap(),
            r#""assistant""#
        );
    }

    #[test]
    fn test_model_state_uses_kebab_case() {
        let state: ModelState = serde_json::from_str(r#""not-loaded""#).unwrap();
        assert_eq!(state, ModelState::NotLoaded);
    }
}