use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{convert::TryFrom, time::Duration};

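/// Default base URL for a locally running LM Studio server's REST API.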
pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";

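/// Role of a chat message author, serialized in lowercase ("user", "assistant", "system", "tool").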
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            "tool" => Ok(Self::Tool),
            _ => anyhow::bail!("invalid role '{value}'"),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
            Role::Tool => "tool".to_owned(),
        }
    }
}

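/// A named model with its token limit and capability flags; `Model::new` defaults `max_tokens` to 2048.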
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub supports_tool_calls: bool,
    pub supports_images: bool,
}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tool_calls: bool,
        supports_images: bool,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name.map(|s| s.to_owned()),
            max_tokens: max_tokens.unwrap_or(2048),
            supports_tool_calls,
            supports_images,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }

    pub fn supports_tool_calls(&self) -> bool {
        self.supports_tool_calls
    }
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    Other(ToolDefinition),
}

#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

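/// A chat message, internally tagged by `role`, so each variant serializes with a `"role"` field alongside its content.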
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        #[serde(default)]
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}

impl MessageContent {
    pub fn empty() -> Self {
        MessageContent::Multipart(vec![])
    }

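    /// Appends a part, collapsing an empty `Multipart` plus a single text part into `Plain`
    /// and promoting `Plain` back to `Multipart` once a second part arrives.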
    pub fn push_part(&mut self, part: MessagePart) {
        match self {
            MessageContent::Plain(text) => {
                *self =
                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
            }
            MessageContent::Multipart(parts) if parts.is_empty() => match part {
                MessagePart::Text { text } => *self = MessageContent::Plain(text),
                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
            },
            MessageContent::Multipart(parts) => parts.push(part),
        }
    }
}

impl From<Vec<MessagePart>> for MessageContent {
    fn from(mut parts: Vec<MessagePart>) -> Self {
        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
            MessageContent::Plain(std::mem::take(text))
        } else {
            MessageContent::Multipart(parts)
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
    Text {
        text: String,
    },
    #[serde(rename = "image_url")]
    Image {
        image_url: ImageUrl,
    },
}

#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}

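/// Request body for `POST {api_url}/chat/completions`; `stream` selects between streaming and non-streaming responses.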
#[derive(Serialize, Debug)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that indicates whether a
    // `function` payload is present. Some chunks stream in the `function`
    // before the `type` arrives, so it is not modeled here.
    pub function: Option<FunctionChunk>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}

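/// Capability strings reported by LM Studio for a model, such as `"tool_use"` and `"vision"`.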
#[derive(Debug, Default, Clone, Deserialize, PartialEq)]
#[serde(transparent)]
pub struct Capabilities(Vec<String>);

impl Capabilities {
    pub fn supports_tool_calls(&self) -> bool {
        self.0.iter().any(|cap| cap == "tool_use")
    }

    pub fn supports_images(&self) -> bool {
        self.0.iter().any(|cap| cap == "vision")
    }
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub object: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}

#[derive(Deserialize)]
pub struct ListModelsResponse {
    pub data: Vec<ModelEntry>,
}

#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct ModelEntry {
    pub id: String,
    pub object: String,
    pub r#type: ModelType,
    pub publisher: String,
    pub arch: Option<String>,
    pub compatibility_type: CompatibilityType,
    pub quantization: Option<String>,
    pub state: ModelState,
    pub max_context_length: Option<u64>,
    pub loaded_context_length: Option<u64>,
    #[serde(default)]
    pub capabilities: Capabilities,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ModelType {
    Llm,
    Embeddings,
    Vlm,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum ModelState {
    Loaded,
    Loading,
    NotLoaded,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum CompatibilityType {
    Gguf,
    Mlx,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

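/// Sends a non-streaming chat completion request to `{api_url}/chat/completions` and parses the response body.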
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatCompletionRequest,
) -> Result<ChatResponse> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let response_message: ChatResponse = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to LM Studio API: {} {}",
            response.status(),
            body_str
        );
    }
}

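/// Sends a streaming chat completion request and returns the stream of parsed server-sent events.
/// Lines are expected to be prefixed with `data: `; unprefixed lines and the trailing `data: [DONE]` sentinel are filtered out.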
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatCompletionRequest,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            None
                        } else {
                            let result = serde_json::from_str(line)
                                .context("Unable to parse chat completions response");
                            if let Err(ref e) = result {
                                eprintln!("Error parsing line: {e}\nLine content: '{line}'");
                            }
                            Some(result)
                        }
                    }
                    Err(e) => {
                        eprintln!("Error reading line: {e}");
                        Some(Err(e.into()))
                    }
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to LM Studio API: {} {}",
            response.status(),
            body,
        );
    }
}

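/// Lists the models available from `GET {api_url}/models`; the `Option<Duration>` argument is currently unused.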
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<ModelEntry>> {
    let uri = format!("{api_url}/models");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to LM Studio API: {} {}",
        response.status(),
        body,
    );
    let response: ListModelsResponse =
        serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
    Ok(response.data)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_image_message_part_serialization() {
        let image_part = MessagePart::Image {
            image_url: ImageUrl {
                url: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==".to_string(),
                detail: None,
            },
        };

        let json = serde_json::to_string(&image_part).unwrap();
        println!("Serialized image part: {}", json);

        // Verify the structure matches what LM Studio expects
        let expected_structure = r#"{"type":"image_url","image_url":{"url":"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="}}"#;
        assert_eq!(json, expected_structure);
    }

    #[test]
    fn test_text_message_part_serialization() {
        let text_part = MessagePart::Text {
            text: "Hello, world!".to_string(),
        };

        let json = serde_json::to_string(&text_part).unwrap();
        println!("Serialized text part: {}", json);

        let expected_structure = r#"{"type":"text","text":"Hello, world!"}"#;
        assert_eq!(json, expected_structure);
    }
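
    // Additional sketch tests for the conversion and serde behavior defined above;
    // they rely only on the derives and impls in this file.
    #[test]
    fn test_role_string_round_trip() {
        // Role survives a round trip through its string form.
        for role in [Role::User, Role::Assistant, Role::System, Role::Tool] {
            let string: String = role.into();
            assert_eq!(Role::try_from(string).unwrap(), role);
        }
    }

    #[test]
    fn test_message_content_push_part_promotion() {
        // Pushing a single text part into an empty multipart collapses it to `Plain`.
        let mut content = MessageContent::empty();
        content.push_part(MessagePart::Text {
            text: "hello".to_string(),
        });
        assert_eq!(content, MessageContent::Plain("hello".to_string()));

        // Pushing a second part promotes `Plain` back to `Multipart`.
        content.push_part(MessagePart::Image {
            image_url: ImageUrl {
                url: "data:image/png;base64,".to_string(),
                detail: None,
            },
        });
        match content {
            MessageContent::Multipart(parts) => assert_eq!(parts.len(), 2),
            other => panic!("expected multipart content, got {other:?}"),
        }
    }

    #[test]
    fn test_capabilities_deserialization() {
        // `Capabilities` is a transparent wrapper over the capability string array.
        let capabilities: Capabilities =
            serde_json::from_str(r#"["tool_use", "vision"]"#).unwrap();
        assert!(capabilities.supports_tool_calls());
        assert!(capabilities.supports_images());
    }

    #[test]
    fn test_chat_message_role_tagging() {
        // `ChatMessage` is internally tagged, so the role appears as a `role` field.
        let message = ChatMessage::User {
            content: MessageContent::Plain("Hello".to_string()),
        };
        let value = serde_json::to_value(&message).unwrap();
        assert_eq!(value["role"], "user");
        assert_eq!(value["content"], "Hello");
    }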
}