use anyhow::{Context as _, Result, anyhow};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{convert::TryFrom, time::Duration};

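/// Default base URL for a locally running LM Studio server's REST API.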
pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            "tool" => Ok(Self::Tool),
            _ => anyhow::bail!("invalid role '{value}'"),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
            Role::Tool => "tool".to_owned(),
        }
    }
}

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub supports_tool_calls: bool,
    pub supports_images: bool,
}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tool_calls: bool,
        supports_images: bool,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name.map(|s| s.to_owned()),
            max_tokens: max_tokens.unwrap_or(2048),
            supports_tool_calls,
            supports_images,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }

    pub fn supports_tool_calls(&self) -> bool {
        self.supports_tool_calls
    }
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}

#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        #[serde(default)]
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}

impl MessageContent {
    pub fn empty() -> Self {
        MessageContent::Multipart(vec![])
    }

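    /// Appends `part`, promoting `Plain` text to `Multipart` when needed and
    /// collapsing a lone text part pushed onto empty content back into `Plain`.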
    pub fn push_part(&mut self, part: MessagePart) {
        match self {
            MessageContent::Plain(text) => {
                *self =
                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
            }
            MessageContent::Multipart(parts) if parts.is_empty() => match part {
                MessagePart::Text { text } => *self = MessageContent::Plain(text),
                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
            },
            MessageContent::Multipart(parts) => parts.push(part),
        }
    }
}

impl From<Vec<MessagePart>> for MessageContent {
    fn from(mut parts: Vec<MessagePart>) -> Self {
        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
            MessageContent::Plain(std::mem::take(text))
        } else {
            MessageContent::Multipart(parts)
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
    Text {
        text: String,
    },
    #[serde(rename = "image_url")]
    Image {
        image_url: ImageUrl,
    },
}

#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}

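/// Request body sent to the `/chat/completions` endpoint.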
#[derive(Serialize, Debug)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that indicates whether a
    // function is present. Chunks sometimes stream in with the `function`
    // field populated before the `type` field arrives.
    pub function: Option<FunctionChunk>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}

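/// Capability strings reported for a model (e.g. `"tool_use"`, `"vision"`).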
#[derive(Debug, Default, Clone, Deserialize, PartialEq)]
#[serde(transparent)]
pub struct Capabilities(Vec<String>);

impl Capabilities {
    pub fn supports_tool_calls(&self) -> bool {
        self.0.iter().any(|cap| cap == "tool_use")
    }

    pub fn supports_images(&self) -> bool {
        self.0.iter().any(|cap| cap == "vision")
    }
}

#[derive(Serialize, Deserialize, Debug)]
pub struct LmStudioError {
    pub message: String,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: LmStudioError },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub object: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}

#[derive(Deserialize)]
pub struct ListModelsResponse {
    pub data: Vec<ModelEntry>,
}

#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct ModelEntry {
    pub id: String,
    pub object: String,
    pub r#type: ModelType,
    pub publisher: String,
    pub arch: Option<String>,
    pub compatibility_type: CompatibilityType,
    pub quantization: Option<String>,
    pub state: ModelState,
    pub max_context_length: Option<u64>,
    pub loaded_context_length: Option<u64>,
    #[serde(default)]
    pub capabilities: Capabilities,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ModelType {
    Llm,
    Embeddings,
    Vlm,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum ModelState {
    Loaded,
    Loading,
    NotLoaded,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum CompatibilityType {
    Gguf,
    Mlx,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

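/// Sends a non-streaming chat completion request to `{api_url}/chat/completions`
/// and deserializes the JSON response.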
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatCompletionRequest,
) -> Result<ChatResponse> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let response_message: ChatResponse = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
379 "Failed to connect to API: {} {}",
            response.status(),
            body_str
        );
    }
}

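/// Sends a streaming chat completion request and returns the response as a
/// stream of events, parsing each `data:` line until the `[DONE]` sentinel.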
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatCompletionRequest,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error, .. }) => {
                                    Some(Err(anyhow!(error.message)))
                                }
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to LM Studio API: {} {}",
            response.status(),
            body,
        );
    }
}

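/// Fetches the list of models known to the LM Studio server from `{api_url}/models`.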
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<ModelEntry>> {
    let uri = format!("{api_url}/models");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to LM Studio API: {} {}",
        response.status(),
        body,
    );
    let response: ListModelsResponse =
        serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
    Ok(response.data)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_image_message_part_serialization() {
        let image_part = MessagePart::Image {
            image_url: ImageUrl {
                url: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==".to_string(),
                detail: None,
            },
        };

        let json = serde_json::to_string(&image_part).unwrap();
        println!("Serialized image part: {}", json);

        // Verify the structure matches what LM Studio expects
        let expected_structure = r#"{"type":"image_url","image_url":{"url":"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="}}"#;
        assert_eq!(json, expected_structure);
    }

    #[test]
    fn test_text_message_part_serialization() {
        let text_part = MessagePart::Text {
            text: "Hello, world!".to_string(),
        };

        let json = serde_json::to_string(&text_part).unwrap();
        println!("Serialized text part: {}", json);

        let expected_structure = r#"{"type":"text","text":"Hello, world!"}"#;
        assert_eq!(json, expected_structure);
    }
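
    // A small sketch exercising `MessageContent::push_part` and the
    // `From<Vec<MessagePart>>` conversion defined above; the literal values
    // here are illustrative only.
    #[test]
    fn test_message_content_plain_and_multipart() {
        // A single text part collapses back into plain string content.
        let content = MessageContent::from(vec![MessagePart::Text {
            text: "Hello".to_string(),
        }]);
        assert_eq!(content, MessageContent::Plain("Hello".to_string()));

        // Pushing an image part onto plain content promotes it to multipart,
        // keeping the original text as the first part.
        let mut content = MessageContent::Plain("Hello".to_string());
        content.push_part(MessagePart::Image {
            image_url: ImageUrl {
                url: "data:image/png;base64,AAAA".to_string(),
                detail: None,
            },
        });
        match content {
            MessageContent::Multipart(parts) => assert_eq!(parts.len(), 2),
            _ => panic!("expected multipart content"),
        }
    }

    // A sketch of how `ChatMessage` serializes with its internal `role` tag;
    // the message text is an arbitrary example.
    #[test]
    fn test_chat_message_role_tagging() {
        let message = ChatMessage::User {
            content: MessageContent::Plain("Hello".to_string()),
        };
        let json = serde_json::to_string(&message).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["role"], "user");
        assert_eq!(value["content"], "Hello");
    }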
}