use anyhow::{Context as _, Result, anyhow};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{convert::TryFrom, time::Duration};

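/// Default base URL of a locally running LM Studio server's REST API (v0).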
pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            "tool" => Ok(Self::Tool),
            _ => anyhow::bail!("invalid role '{value}'"),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
            Role::Tool => "tool".to_owned(),
        }
    }
}

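/// A chat model and its advertised limits. `max_tokens` falls back to 2048
/// when no value is supplied to [`Model::new`].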
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub supports_tool_calls: bool,
    pub supports_images: bool,
}

impl Model {
    pub fn new(
        name: &str,
        display_name: Option<&str>,
        max_tokens: Option<u64>,
        supports_tool_calls: bool,
        supports_images: bool,
    ) -> Self {
        Self {
            name: name.to_owned(),
            display_name: display_name.map(|s| s.to_owned()),
            max_tokens: max_tokens.unwrap_or(2048),
            supports_tool_calls,
            supports_images,
        }
    }

    pub fn id(&self) -> &str {
        &self.name
    }

    pub fn display_name(&self) -> &str {
        self.display_name.as_ref().unwrap_or(&self.name)
    }

    pub fn max_token_count(&self) -> u64 {
        self.max_tokens
    }

    pub fn supports_tool_calls(&self) -> bool {
        self.supports_tool_calls
    }
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    Other(ToolDefinition),
}

#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        #[serde(default)]
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}

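/// Message content in OpenAI-compatible form: either a plain string or a list
/// of typed parts (text and image URLs). [`MessageContent::push_part`] keeps a
/// lone text part in the plain form and switches to multipart otherwise.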
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}

impl MessageContent {
    pub fn empty() -> Self {
        MessageContent::Multipart(vec![])
    }

    pub fn push_part(&mut self, part: MessagePart) {
        match self {
            MessageContent::Plain(text) => {
                *self =
                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
            }
            MessageContent::Multipart(parts) if parts.is_empty() => match part {
                MessagePart::Text { text } => *self = MessageContent::Plain(text),
                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
            },
            MessageContent::Multipart(parts) => parts.push(part),
        }
    }
}

impl From<Vec<MessagePart>> for MessageContent {
    fn from(mut parts: Vec<MessagePart>) -> Self {
        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
            MessageContent::Plain(std::mem::take(text))
        } else {
            MessageContent::Multipart(parts)
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
    Text {
        text: String,
    },
    #[serde(rename = "image_url")]
    Image {
        image_url: ImageUrl,
    },
}

#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}

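/// Request body for `POST {api_url}/chat/completions`. With `stream: true` the
/// response is expected as server-sent events (see [`stream_chat_completion`]);
/// otherwise [`complete`] parses a single [`ChatResponse`] from the body.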
#[derive(Serialize, Debug)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // The payload also carries an optional `type` field indicating whether a
    // function call is present, but it sometimes arrives after the `function`
    // field in the stream, so we don't rely on it here.
    pub function: Option<FunctionChunk>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}

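/// Capability strings reported per model by the `/models` endpoint, such as
/// `"tool_use"` and `"vision"`.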
#[derive(Debug, Default, Clone, Deserialize, PartialEq)]
#[serde(transparent)]
pub struct Capabilities(Vec<String>);

impl Capabilities {
    pub fn supports_tool_calls(&self) -> bool {
        self.0.iter().any(|cap| cap == "tool_use")
    }

    pub fn supports_images(&self) -> bool {
        self.0.iter().any(|cap| cap == "vision")
    }
}

#[derive(Serialize, Deserialize, Debug)]
pub struct LmStudioError {
    pub message: String,
}

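/// One deserialized payload from the streaming chat endpoint: either a regular
/// event or an error object reported by the server.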
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: LmStudioError },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub object: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}

#[derive(Deserialize)]
pub struct ListModelsResponse {
    pub data: Vec<ModelEntry>,
}

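/// A single entry from the `/models` listing, including load state,
/// quantization, and context-length information.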
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct ModelEntry {
    pub id: String,
    pub object: String,
    pub r#type: ModelType,
    pub publisher: String,
    pub arch: Option<String>,
    pub compatibility_type: CompatibilityType,
    pub quantization: Option<String>,
    pub state: ModelState,
    pub max_context_length: Option<u64>,
    pub loaded_context_length: Option<u64>,
    #[serde(default)]
    pub capabilities: Capabilities,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ModelType {
    Llm,
    Embeddings,
    Vlm,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum ModelState {
    Loaded,
    Loading,
    NotLoaded,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum CompatibilityType {
    Gguf,
    Mlx,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

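/// Sends a non-streaming chat completion request and parses the full response
/// body into a [`ChatResponse`].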
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatCompletionRequest,
) -> Result<ChatResponse> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let serialized_request = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(serialized_request))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let response_message: ChatResponse = serde_json::from_slice(&body)?;
        Ok(response_message)
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to LM Studio API: {} {}",
            response.status(),
            body_str
        );
    }
}

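/// Sends a streaming chat completion request and returns the parsed stream of
/// server-sent events. Lines prefixed with `data: ` are deserialized into
/// [`ResponseStreamEvent`]s; the `[DONE]` sentinel terminates the stream.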
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatCompletionRequest,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error, .. }) => {
                                    Some(Err(anyhow!(error.message)))
                                }
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to LM Studio API: {} {}",
            response.status(),
            body,
        );
    }
}

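/// Fetches the list of models known to the LM Studio server from `/models`.
/// The final timeout parameter is currently unused.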
pub async fn get_models(
    client: &dyn HttpClient,
    api_url: &str,
    _: Option<Duration>,
) -> Result<Vec<ModelEntry>> {
    let uri = format!("{api_url}/models");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;

    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to connect to LM Studio API: {} {}",
        response.status(),
        body,
    );
    let response: ListModelsResponse =
        serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
    Ok(response.data)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_image_message_part_serialization() {
        let image_part = MessagePart::Image {
            image_url: ImageUrl {
                url: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==".to_string(),
                detail: None,
            },
        };

        let json = serde_json::to_string(&image_part).unwrap();
        println!("Serialized image part: {}", json);

        // Verify the structure matches what LM Studio expects
        let expected_structure = r#"{"type":"image_url","image_url":{"url":"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="}}"#;
        assert_eq!(json, expected_structure);
    }

    #[test]
    fn test_text_message_part_serialization() {
        let text_part = MessagePart::Text {
            text: "Hello, world!".to_string(),
        };

        let json = serde_json::to_string(&text_part).unwrap();
        println!("Serialized text part: {}", json);

        let expected_structure = r#"{"type":"text","text":"Hello, world!"}"#;
        assert_eq!(json, expected_structure);
    }
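
    // The following tests exercise the pure, in-crate conversion helpers; they
    // do not talk to an LM Studio server.
    #[test]
    fn test_message_content_push_part_promotion() {
        // Pushing a lone text part onto empty content keeps the plain form.
        let mut content = MessageContent::empty();
        content.push_part(MessagePart::Text {
            text: "Hello".to_string(),
        });
        assert_eq!(content, MessageContent::Plain("Hello".to_string()));

        // Pushing a second part promotes the content to multipart.
        content.push_part(MessagePart::Text {
            text: "world".to_string(),
        });
        assert_eq!(
            content,
            MessageContent::Multipart(vec![
                MessagePart::Text {
                    text: "Hello".to_string()
                },
                MessagePart::Text {
                    text: "world".to_string()
                },
            ])
        );

        // Converting a single-element part list yields the plain form again.
        let from_parts: MessageContent = vec![MessagePart::Text {
            text: "solo".to_string(),
        }]
        .into();
        assert_eq!(from_parts, MessageContent::Plain("solo".to_string()));
    }

    #[test]
    fn test_role_string_conversions() {
        assert_eq!(Role::try_from("assistant".to_string()).unwrap(), Role::Assistant);
        assert_eq!(String::from(Role::Tool), "tool");
        assert!(Role::try_from("invalid".to_string()).is_err());
    }

    #[test]
    fn test_capabilities_from_model_listing() {
        // Capabilities deserialize transparently from the JSON array returned
        // in the `/models` listing.
        let capabilities: Capabilities =
            serde_json::from_str(r#"["tool_use", "vision"]"#).unwrap();
        assert!(capabilities.supports_tool_calls());
        assert!(capabilities.supports_images());

        let default_capabilities = Capabilities::default();
        assert!(!default_capabilities.supports_tool_calls());
        assert!(!default_capabilities.supports_images());
    }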
}