lmstudio.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::{convert::TryFrom, time::Duration};
  7
  8pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";
  9
 10#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 11#[serde(rename_all = "lowercase")]
 12pub enum Role {
 13    User,
 14    Assistant,
 15    System,
 16    Tool,
 17}
 18
 19impl TryFrom<String> for Role {
 20    type Error = anyhow::Error;
 21
 22    fn try_from(value: String) -> Result<Self> {
 23        match value.as_str() {
 24            "user" => Ok(Self::User),
 25            "assistant" => Ok(Self::Assistant),
 26            "system" => Ok(Self::System),
 27            "tool" => Ok(Self::Tool),
 28            _ => anyhow::bail!("invalid role '{value}'"),
 29        }
 30    }
 31}
 32
 33impl From<Role> for String {
 34    fn from(val: Role) -> Self {
 35        match val {
 36            Role::User => "user".to_owned(),
 37            Role::Assistant => "assistant".to_owned(),
 38            Role::System => "system".to_owned(),
 39            Role::Tool => "tool".to_owned(),
 40        }
 41    }
 42}
 43
 44#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 45#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 46pub struct Model {
 47    pub name: String,
 48    pub display_name: Option<String>,
 49    pub max_tokens: u64,
 50    pub supports_tool_calls: bool,
 51    pub supports_images: bool,
 52}
 53
 54impl Model {
 55    pub fn new(
 56        name: &str,
 57        display_name: Option<&str>,
 58        max_tokens: Option<u64>,
 59        supports_tool_calls: bool,
 60        supports_images: bool,
 61    ) -> Self {
 62        Self {
 63            name: name.to_owned(),
 64            display_name: display_name.map(|s| s.to_owned()),
 65            max_tokens: max_tokens.unwrap_or(2048),
 66            supports_tool_calls,
 67            supports_images,
 68        }
 69    }
 70
 71    pub fn id(&self) -> &str {
 72        &self.name
 73    }
 74
 75    pub fn display_name(&self) -> &str {
 76        self.display_name.as_ref().unwrap_or(&self.name)
 77    }
 78
 79    pub fn max_token_count(&self) -> u64 {
 80        self.max_tokens
 81    }
 82
 83    pub fn supports_tool_calls(&self) -> bool {
 84        self.supports_tool_calls
 85    }
 86}
 87
 88#[derive(Debug, Serialize, Deserialize)]
 89#[serde(rename_all = "lowercase")]
 90pub enum ToolChoice {
 91    Auto,
 92    Required,
 93    None,
 94    #[serde(untagged)]
 95    Other(ToolDefinition),
 96}
 97
 98#[derive(Clone, Deserialize, Serialize, Debug)]
 99#[serde(tag = "type", rename_all = "snake_case")]
100pub enum ToolDefinition {
101    #[allow(dead_code)]
102    Function { function: FunctionDefinition },
103}
104
105#[derive(Clone, Debug, Serialize, Deserialize)]
106pub struct FunctionDefinition {
107    pub name: String,
108    pub description: Option<String>,
109    pub parameters: Option<Value>,
110}
111
112#[derive(Serialize, Deserialize, Debug)]
113#[serde(tag = "role", rename_all = "lowercase")]
114pub enum ChatMessage {
115    Assistant {
116        #[serde(default)]
117        content: Option<MessageContent>,
118        #[serde(default, skip_serializing_if = "Vec::is_empty")]
119        tool_calls: Vec<ToolCall>,
120    },
121    User {
122        content: MessageContent,
123    },
124    System {
125        content: MessageContent,
126    },
127    Tool {
128        content: MessageContent,
129        tool_call_id: String,
130    },
131}
132
133#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
134#[serde(untagged)]
135pub enum MessageContent {
136    Plain(String),
137    Multipart(Vec<MessagePart>),
138}
139
140impl MessageContent {
141    pub fn empty() -> Self {
142        MessageContent::Multipart(vec![])
143    }
144
145    pub fn push_part(&mut self, part: MessagePart) {
146        match self {
147            MessageContent::Plain(text) => {
148                *self =
149                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
150            }
151            MessageContent::Multipart(parts) if parts.is_empty() => match part {
152                MessagePart::Text { text } => *self = MessageContent::Plain(text),
153                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
154            },
155            MessageContent::Multipart(parts) => parts.push(part),
156        }
157    }
158}
159
160impl From<Vec<MessagePart>> for MessageContent {
161    fn from(mut parts: Vec<MessagePart>) -> Self {
162        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
163            MessageContent::Plain(std::mem::take(text))
164        } else {
165            MessageContent::Multipart(parts)
166        }
167    }
168}
169
170#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
171#[serde(tag = "type", rename_all = "snake_case")]
172pub enum MessagePart {
173    Text {
174        text: String,
175    },
176    #[serde(rename = "image_url")]
177    Image {
178        image_url: ImageUrl,
179    },
180}
181
182#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
183pub struct ImageUrl {
184    pub url: String,
185    #[serde(skip_serializing_if = "Option::is_none")]
186    pub detail: Option<String>,
187}
188
189#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
190pub struct ToolCall {
191    pub id: String,
192    #[serde(flatten)]
193    pub content: ToolCallContent,
194}
195
196#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
197#[serde(tag = "type", rename_all = "lowercase")]
198pub enum ToolCallContent {
199    Function { function: FunctionContent },
200}
201
202#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
203pub struct FunctionContent {
204    pub name: String,
205    pub arguments: String,
206}
207
208#[derive(Serialize, Debug)]
209pub struct ChatCompletionRequest {
210    pub model: String,
211    pub messages: Vec<ChatMessage>,
212    pub stream: bool,
213    #[serde(skip_serializing_if = "Option::is_none")]
214    pub max_tokens: Option<i32>,
215    #[serde(skip_serializing_if = "Option::is_none")]
216    pub stop: Option<Vec<String>>,
217    #[serde(skip_serializing_if = "Option::is_none")]
218    pub temperature: Option<f32>,
219    #[serde(skip_serializing_if = "Vec::is_empty")]
220    pub tools: Vec<ToolDefinition>,
221    #[serde(skip_serializing_if = "Option::is_none")]
222    pub tool_choice: Option<ToolChoice>,
223}
224
225#[derive(Serialize, Deserialize, Debug)]
226pub struct ChatResponse {
227    pub id: String,
228    pub object: String,
229    pub created: u64,
230    pub model: String,
231    pub choices: Vec<ChoiceDelta>,
232}
233
234#[derive(Serialize, Deserialize, Debug)]
235pub struct ChoiceDelta {
236    pub index: u32,
237    pub delta: ResponseMessageDelta,
238    pub finish_reason: Option<String>,
239}
240
241#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
242pub struct ToolCallChunk {
243    pub index: usize,
244    pub id: Option<String>,
245
246    // There is also an optional `type` field that would determine if a
247    // function is there. Sometimes this streams in with the `function` before
248    // it streams in the `type`
249    pub function: Option<FunctionChunk>,
250}
251
252#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
253pub struct FunctionChunk {
254    pub name: Option<String>,
255    pub arguments: Option<String>,
256}
257
258#[derive(Serialize, Deserialize, Debug)]
259pub struct Usage {
260    pub prompt_tokens: u64,
261    pub completion_tokens: u64,
262    pub total_tokens: u64,
263}
264
265#[derive(Debug, Default, Clone, Deserialize, PartialEq)]
266#[serde(transparent)]
267pub struct Capabilities(Vec<String>);
268
269impl Capabilities {
270    pub fn supports_tool_calls(&self) -> bool {
271        self.0.iter().any(|cap| cap == "tool_use")
272    }
273
274    pub fn supports_images(&self) -> bool {
275        self.0.iter().any(|cap| cap == "vision")
276    }
277}
278
279#[derive(Serialize, Deserialize, Debug)]
280pub struct LmStudioError {
281    pub message: String,
282}
283
284#[derive(Serialize, Deserialize, Debug)]
285#[serde(untagged)]
286pub enum ResponseStreamResult {
287    Ok(ResponseStreamEvent),
288    Err { error: LmStudioError },
289}
290
291#[derive(Serialize, Deserialize, Debug)]
292pub struct ResponseStreamEvent {
293    pub created: u32,
294    pub model: String,
295    pub object: String,
296    pub choices: Vec<ChoiceDelta>,
297    pub usage: Option<Usage>,
298}
299
300#[derive(Deserialize)]
301pub struct ListModelsResponse {
302    pub data: Vec<ModelEntry>,
303}
304
305#[derive(Clone, Debug, Deserialize, PartialEq)]
306pub struct ModelEntry {
307    pub id: String,
308    pub object: String,
309    pub r#type: ModelType,
310    pub publisher: String,
311    pub arch: Option<String>,
312    pub compatibility_type: CompatibilityType,
313    pub quantization: Option<String>,
314    pub state: ModelState,
315    pub max_context_length: Option<u64>,
316    pub loaded_context_length: Option<u64>,
317    #[serde(default)]
318    pub capabilities: Capabilities,
319}
320
321#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
322#[serde(rename_all = "lowercase")]
323pub enum ModelType {
324    Llm,
325    Embeddings,
326    Vlm,
327}
328
329#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
330#[serde(rename_all = "kebab-case")]
331pub enum ModelState {
332    Loaded,
333    Loading,
334    NotLoaded,
335}
336
337#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
338#[serde(rename_all = "lowercase")]
339pub enum CompatibilityType {
340    Gguf,
341    Mlx,
342}
343
344#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
345pub struct ResponseMessageDelta {
346    pub role: Option<Role>,
347    pub content: Option<String>,
348    #[serde(default, skip_serializing_if = "Option::is_none")]
349    pub reasoning_content: Option<String>,
350    #[serde(default, skip_serializing_if = "Option::is_none")]
351    pub tool_calls: Option<Vec<ToolCallChunk>>,
352}
353
354pub async fn complete(
355    client: &dyn HttpClient,
356    api_url: &str,
357    api_key: Option<&str>,
358    request: ChatCompletionRequest,
359) -> Result<ChatResponse> {
360    let uri = format!("{api_url}/chat/completions");
361    let mut request_builder = HttpRequest::builder()
362        .method(Method::POST)
363        .uri(uri)
364        .header("Content-Type", "application/json");
365
366    if let Some(api_key) = api_key {
367        request_builder = request_builder.header("Authorization", format!("Bearer {}", api_key));
368    }
369
370    let serialized_request = serde_json::to_string(&request)?;
371    let request = request_builder.body(AsyncBody::from(serialized_request))?;
372
373    let mut response = client.send(request).await?;
374    if response.status().is_success() {
375        let mut body = Vec::new();
376        response.body_mut().read_to_end(&mut body).await?;
377        let response_message: ChatResponse = serde_json::from_slice(&body)?;
378        Ok(response_message)
379    } else {
380        let mut body = Vec::new();
381        response.body_mut().read_to_end(&mut body).await?;
382        let body_str = std::str::from_utf8(&body)?;
383        anyhow::bail!(
384            "Failed to connect to API: {} {}",
385            response.status(),
386            body_str
387        );
388    }
389}
390
391pub async fn stream_chat_completion(
392    client: &dyn HttpClient,
393    api_url: &str,
394    api_key: Option<&str>,
395    request: ChatCompletionRequest,
396) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
397    let uri = format!("{api_url}/chat/completions");
398    let mut request_builder = http::Request::builder()
399        .method(Method::POST)
400        .uri(uri)
401        .header("Content-Type", "application/json");
402
403    if let Some(api_key) = api_key {
404        request_builder = request_builder.header("Authorization", format!("Bearer {}", api_key));
405    }
406
407    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
408    let mut response = client.send(request).await?;
409    if response.status().is_success() {
410        let reader = BufReader::new(response.into_body());
411        Ok(reader
412            .lines()
413            .filter_map(|line| async move {
414                match line {
415                    Ok(line) => {
416                        let line = line.strip_prefix("data: ")?;
417                        if line == "[DONE]" {
418                            None
419                        } else {
420                            match serde_json::from_str(line) {
421                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
422                                Ok(ResponseStreamResult::Err { error, .. }) => {
423                                    Some(Err(anyhow!(error.message)))
424                                }
425                                Err(error) => Some(Err(anyhow!(error))),
426                            }
427                        }
428                    }
429                    Err(error) => Some(Err(anyhow!(error))),
430                }
431            })
432            .boxed())
433    } else {
434        let mut body = String::new();
435        response.body_mut().read_to_string(&mut body).await?;
436        anyhow::bail!(
437            "Failed to connect to LM Studio API: {} {}",
438            response.status(),
439            body,
440        );
441    }
442}
443
444pub async fn get_models(
445    client: &dyn HttpClient,
446    api_url: &str,
447    api_key: Option<&str>,
448    _: Option<Duration>,
449) -> Result<Vec<ModelEntry>> {
450    let uri = format!("{api_url}/models");
451    let mut request_builder = HttpRequest::builder()
452        .method(Method::GET)
453        .uri(uri)
454        .header("Accept", "application/json");
455
456    if let Some(api_key) = api_key {
457        request_builder = request_builder.header("Authorization", format!("Bearer {}", api_key));
458    }
459
460    let request = request_builder.body(AsyncBody::default())?;
461
462    let mut response = client.send(request).await?;
463
464    let mut body = String::new();
465    response.body_mut().read_to_string(&mut body).await?;
466
467    anyhow::ensure!(
468        response.status().is_success(),
469        "Failed to connect to LM Studio API: {} {}",
470        response.status(),
471        body,
472    );
473    let response: ListModelsResponse =
474        serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
475    Ok(response.data)
476}
477
478#[cfg(test)]
479mod tests {
480    use super::*;
481
482    #[test]
483    fn test_image_message_part_serialization() {
484        let image_part = MessagePart::Image {
485            image_url: ImageUrl {
486                url: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==".to_string(),
487                detail: None,
488            },
489        };
490
491        let json = serde_json::to_string(&image_part).unwrap();
492        println!("Serialized image part: {}", json);
493
494        // Verify the structure matches what LM Studio expects
495        let expected_structure = r#"{"type":"image_url","image_url":{"url":"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="}}"#;
496        assert_eq!(json, expected_structure);
497    }
498
499    #[test]
500    fn test_text_message_part_serialization() {
501        let text_part = MessagePart::Text {
502            text: "Hello, world!".to_string(),
503        };
504
505        let json = serde_json::to_string(&text_part).unwrap();
506        println!("Serialized text part: {}", json);
507
508        let expected_structure = r#"{"type":"text","text":"Hello, world!"}"#;
509        assert_eq!(json, expected_structure);
510    }
511}