//! lmstudio.rs — LM Studio API client: request/response types and HTTP helpers.

  1use anyhow::{Context as _, Result};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
  4use serde::{Deserialize, Serialize};
  5use serde_json::{Value, value::RawValue};
  6use std::{convert::TryFrom, sync::Arc, time::Duration};
  7
/// Default base URL of a locally running LM Studio server's native REST API (`/api/v0`).
pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";
  9
/// Speaker role attached to a chat message; serialized lowercase
/// ("user", "assistant", "system", "tool").
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 18
 19impl TryFrom<String> for Role {
 20    type Error = anyhow::Error;
 21
 22    fn try_from(value: String) -> Result<Self> {
 23        match value.as_str() {
 24            "user" => Ok(Self::User),
 25            "assistant" => Ok(Self::Assistant),
 26            "system" => Ok(Self::System),
 27            "tool" => Ok(Self::Tool),
 28            _ => anyhow::bail!("invalid role '{value}'"),
 29        }
 30    }
 31}
 32
 33impl From<Role> for String {
 34    fn from(val: Role) -> Self {
 35        match val {
 36            Role::User => "user".to_owned(),
 37            Role::Assistant => "assistant".to_owned(),
 38            Role::System => "system".to_owned(),
 39            Role::Tool => "tool".to_owned(),
 40        }
 41    }
 42}
 43
/// A chat-capable model entry: its identifier, an optional human-facing name,
/// and the context-window size in tokens.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
}
 51
 52impl Model {
 53    pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
 54        Self {
 55            name: name.to_owned(),
 56            display_name: display_name.map(|s| s.to_owned()),
 57            max_tokens: max_tokens.unwrap_or(2048),
 58        }
 59    }
 60
 61    pub fn id(&self) -> &str {
 62        &self.name
 63    }
 64
 65    pub fn display_name(&self) -> &str {
 66        self.display_name.as_ref().unwrap_or(&self.name)
 67    }
 68
 69    pub fn max_token_count(&self) -> usize {
 70        self.max_tokens
 71    }
 72}
/// One message in a chat conversation; the JSON `role` field selects the variant.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    /// Model-produced message; may carry text, tool calls, or both.
    Assistant {
        #[serde(default)]
        content: Option<String>,
        #[serde(default)]
        tool_calls: Option<Vec<LmStudioToolCall>>,
    },
    /// End-user message.
    User {
        content: String,
    },
    /// System / instruction message.
    System {
        content: String,
    },
}
 89
/// A tool invocation requested by the model; only function calls are supported.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum LmStudioToolCall {
    Function(LmStudioFunctionCall),
}
 95
/// A concrete function call: the function name plus its arguments.
#[derive(Serialize, Deserialize, Debug)]
pub struct LmStudioFunctionCall {
    pub name: String,
    // Kept as raw (unparsed) JSON so callers decide how/when to deserialize.
    pub arguments: Box<RawValue>,
}
101
/// Declaration of a callable function exposed to the model.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LmStudioFunctionTool {
    pub name: String,
    pub description: Option<String>,
    // presumably a JSON Schema describing the arguments — TODO confirm against server docs
    pub parameters: Option<Value>,
}
108
/// A tool offered to the model, tagged by the JSON `type` field
/// (currently only `"function"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum LmStudioTool {
    Function { function: LmStudioFunctionTool },
}
114
/// Request body for `POST {api_url}/chat/completions`.
#[derive(Serialize, Debug)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    /// When true the server replies with server-sent events instead of one JSON body.
    pub stream: bool,
    pub max_tokens: Option<i32>,
    /// Stop sequences, if any.
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub tools: Vec<LmStudioTool>,
}
125
/// A complete (non-streamed) chat completion response.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub id: String,
    pub object: String,
    // presumably a Unix timestamp in seconds — TODO confirm
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
}
134
/// One choice in a response (used for both full and streamed replies).
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    // Deliberately untyped: callers deserialize this further themselves
    // (e.g. into `ResponseMessageDelta` for streamed chunks).
    #[serde(default)]
    pub delta: serde_json::Value,
    pub finish_reason: Option<String>,
}
142
/// A partial tool call, streamed incrementally across multiple events.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Position of this tool call within the assistant message's tool-call list.
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
153
/// Partial function-call data inside a [`ToolCallChunk`]; the name and the
/// argument text may arrive in separate stream chunks (hence both optional).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
159
/// Token accounting reported by the server for one request.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
166
/// Either a stream event or an error payload. `untagged` means serde tries the
/// `Ok` shape first and falls back to the `{ "error": … }` shape.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}
173
/// One server-sent event of a streamed chat completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// Present only on the final event, if the server reports usage at all.
    pub usage: Option<Usage>,
}
181
/// Response body for `GET {api_url}/models`.
#[derive(Serialize, Deserialize)]
pub struct ListModelsResponse {
    pub data: Vec<ModelEntry>,
}
186
/// One model as reported by the LM Studio `/models` endpoint.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ModelEntry {
    pub id: String,
    pub object: String,
    pub r#type: ModelType,
    pub publisher: String,
    /// Model architecture string; absent for some entries.
    pub arch: Option<String>,
    pub compatibility_type: CompatibilityType,
    /// Quantization label, if reported.
    pub quantization: Option<String>,
    pub state: ModelState,
    pub max_context_length: Option<u32>,
    /// Context length in effect while loaded, if the model is loaded.
    pub loaded_context_length: Option<u32>,
}
200
/// Kind of model, serialized lowercase ("llm", "embeddings", "vlm").
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ModelType {
    Llm,
    Embeddings,
    Vlm,
}
208
/// Load state of a model, serialized kebab-case ("loaded", "loading", "not-loaded").
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum ModelState {
    Loaded,
    Loading,
    NotLoaded,
}
216
/// Model file format, serialized lowercase ("gguf" or "mlx").
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum CompatibilityType {
    Gguf,
    Mlx,
}
223
/// Typed form of a streamed choice's `delta` payload: incremental role,
/// content, and tool-call data.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
231
232pub async fn complete(
233    client: &dyn HttpClient,
234    api_url: &str,
235    request: ChatCompletionRequest,
236) -> Result<ChatResponse> {
237    let uri = format!("{api_url}/chat/completions");
238    let request_builder = HttpRequest::builder()
239        .method(Method::POST)
240        .uri(uri)
241        .header("Content-Type", "application/json");
242
243    let serialized_request = serde_json::to_string(&request)?;
244    let request = request_builder.body(AsyncBody::from(serialized_request))?;
245
246    let mut response = client.send(request).await?;
247    if response.status().is_success() {
248        let mut body = Vec::new();
249        response.body_mut().read_to_end(&mut body).await?;
250        let response_message: ChatResponse = serde_json::from_slice(&body)?;
251        Ok(response_message)
252    } else {
253        let mut body = Vec::new();
254        response.body_mut().read_to_end(&mut body).await?;
255        let body_str = std::str::from_utf8(&body)?;
256        anyhow::bail!(
257            "Failed to connect to API: {} {}",
258            response.status(),
259            body_str
260        );
261    }
262}
263
264pub async fn stream_chat_completion(
265    client: &dyn HttpClient,
266    api_url: &str,
267    request: ChatCompletionRequest,
268) -> Result<BoxStream<'static, Result<ChatResponse>>> {
269    let uri = format!("{api_url}/chat/completions");
270    let request_builder = http::Request::builder()
271        .method(Method::POST)
272        .uri(uri)
273        .header("Content-Type", "application/json");
274
275    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
276    let mut response = client.send(request).await?;
277    if response.status().is_success() {
278        let reader = BufReader::new(response.into_body());
279
280        Ok(reader
281            .lines()
282            .filter_map(|line| async move {
283                match line {
284                    Ok(line) => {
285                        let line = line.strip_prefix("data: ")?;
286                        if line == "[DONE]" {
287                            None
288                        } else {
289                            let result = serde_json::from_str(&line)
290                                .context("Unable to parse chat completions response");
291                            if let Err(ref e) = result {
292                                eprintln!("Error parsing line: {e}\nLine content: '{line}'");
293                            }
294                            Some(result)
295                        }
296                    }
297                    Err(e) => {
298                        eprintln!("Error reading line: {e}");
299                        Some(Err(e.into()))
300                    }
301                }
302            })
303            .boxed())
304    } else {
305        let mut body = String::new();
306        response.body_mut().read_to_string(&mut body).await?;
307        anyhow::bail!(
308            "Failed to connect to LM Studio API: {} {}",
309            response.status(),
310            body,
311        );
312    }
313}
314
315pub async fn get_models(
316    client: &dyn HttpClient,
317    api_url: &str,
318    _: Option<Duration>,
319) -> Result<Vec<ModelEntry>> {
320    let uri = format!("{api_url}/models");
321    let request_builder = HttpRequest::builder()
322        .method(Method::GET)
323        .uri(uri)
324        .header("Accept", "application/json");
325
326    let request = request_builder.body(AsyncBody::default())?;
327
328    let mut response = client.send(request).await?;
329
330    let mut body = String::new();
331    response.body_mut().read_to_string(&mut body).await?;
332
333    anyhow::ensure!(
334        response.status().is_success(),
335        "Failed to connect to LM Studio API: {} {}",
336        response.status(),
337        body,
338    );
339    let response: ListModelsResponse =
340        serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
341    Ok(response.data)
342}
343
344/// Sends an empty request to LM Studio to trigger loading the model
345pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
346    let uri = format!("{api_url}/completions");
347    let request = HttpRequest::builder()
348        .method(Method::POST)
349        .uri(uri)
350        .header("Content-Type", "application/json")
351        .body(AsyncBody::from(serde_json::to_string(
352            &serde_json::json!({
353                "model": model,
354                "messages": [],
355                "stream": false,
356                "max_tokens": 0,
357            }),
358        )?))?;
359
360    let mut response = client.send(request).await?;
361
362    if response.status().is_success() {
363        Ok(())
364    } else {
365        let mut body = String::new();
366        response.body_mut().read_to_string(&mut body).await?;
367        anyhow::bail!(
368            "Failed to connect to LM Studio API: {} {}",
369            response.status(),
370            body,
371        );
372    }
373}