1use anyhow::{Context as _, Result};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::{convert::TryFrom, sync::Arc, time::Duration};
7
8pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";
9
/// Conversation role attached to a chat message.
///
/// Serialized in lowercase ("user", "assistant", "system", "tool") to match
/// the OpenAI-compatible chat schema LM Studio exposes.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
18
19impl TryFrom<String> for Role {
20 type Error = anyhow::Error;
21
22 fn try_from(value: String) -> Result<Self> {
23 match value.as_str() {
24 "user" => Ok(Self::User),
25 "assistant" => Ok(Self::Assistant),
26 "system" => Ok(Self::System),
27 "tool" => Ok(Self::Tool),
28 _ => anyhow::bail!("invalid role '{value}'"),
29 }
30 }
31}
32
33impl From<Role> for String {
34 fn from(val: Role) -> Self {
35 match val {
36 Role::User => "user".to_owned(),
37 Role::Assistant => "assistant".to_owned(),
38 Role::System => "system".to_owned(),
39 Role::Tool => "tool".to_owned(),
40 }
41 }
42}
43
/// A locally available LM Studio model, as configured/selected on our side.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    // Model identifier as reported by the LM Studio server.
    pub name: String,
    // Optional human-friendly label; `name` is used when absent.
    pub display_name: Option<String>,
    // Context window size in tokens.
    pub max_tokens: usize,
    // Whether the model advertises tool/function calling support.
    pub supports_tool_calls: bool,
}
52
53impl Model {
54 pub fn new(
55 name: &str,
56 display_name: Option<&str>,
57 max_tokens: Option<usize>,
58 supports_tool_calls: bool,
59 ) -> Self {
60 Self {
61 name: name.to_owned(),
62 display_name: display_name.map(|s| s.to_owned()),
63 max_tokens: max_tokens.unwrap_or(2048),
64 supports_tool_calls,
65 }
66 }
67
68 pub fn id(&self) -> &str {
69 &self.name
70 }
71
72 pub fn display_name(&self) -> &str {
73 self.display_name.as_ref().unwrap_or(&self.name)
74 }
75
76 pub fn max_token_count(&self) -> usize {
77 self.max_tokens
78 }
79
80 pub fn supports_tool_calls(&self) -> bool {
81 self.supports_tool_calls
82 }
83}
84
85#[derive(Debug, Serialize, Deserialize)]
86#[serde(untagged)]
87pub enum ToolChoice {
88 Auto,
89 Required,
90 None,
91 Other(ToolDefinition),
92}
93
/// A tool offered to the model. Only function tools exist today, matching
/// the `{"type": "function", "function": {...}}` wire shape.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
100
/// Declaration of a callable function exposed to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    // Function name the model will reference in tool calls.
    pub name: String,
    // Natural-language description shown to the model.
    pub description: Option<String>,
    // JSON Schema for the arguments; `None` means no declared parameters.
    pub parameters: Option<Value>,
}
107
/// A single message in a chat conversation, tagged by its `role` field on
/// the wire.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        // May be absent when the assistant responds only with tool calls.
        #[serde(default)]
        content: Option<String>,
        // Omitted from serialization when empty to keep requests minimal.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        // Result payload produced by running the tool.
        content: String,
        // Links this result back to the assistant's originating tool call.
        tool_call_id: String,
    },
}
128
/// A tool invocation requested by the assistant.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    // Unique id used to correlate the eventual `ChatMessage::Tool` reply.
    pub id: String,
    // Flattened so `type`/`function` appear at the same JSON level as `id`.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
135
/// The payload of a tool call, tagged by `type` (only functions today).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
141
/// A concrete function invocation: name plus JSON-encoded arguments.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    // Arguments as a raw JSON string, exactly as the model emitted them.
    pub arguments: String,
}
147
/// Request body for `POST {api_url}/chat/completions`.
#[derive(Serialize, Debug)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    // When true the server responds with SSE deltas instead of one object.
    pub stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i32>,
    // Stop sequences that end generation early.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    // Tool declarations; omitted entirely when no tools are offered.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
164
/// Non-streaming chat completion response.
///
/// NOTE(review): choices reuse the streaming `ChoiceDelta` shape here —
/// presumably the server's non-streaming payload is compatible; confirm
/// against the LM Studio REST docs.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub id: String,
    pub object: String,
    // Unix timestamp of creation.
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
}
173
/// One choice entry of a (streamed) completion response.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    // Incremental message content for this chunk.
    pub delta: ResponseMessageDelta,
    // Set on the final chunk (e.g. "stop", "tool_calls"); `None` mid-stream.
    pub finish_reason: Option<String>,
}
180
/// A partial tool call as it streams in; fields fill in across chunks and
/// must be accumulated by `index`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    // Only present on the first chunk of a given tool call.
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
191
/// Partial function-call data inside a `ToolCallChunk`; `arguments` arrives
/// as string fragments to be concatenated.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
197
/// Token accounting reported by the server for a completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
204
/// Capability strings advertised by a model (e.g. "tool_use"); transparent
/// newtype over the raw JSON string array.
#[derive(Debug, Default, Clone, Deserialize, PartialEq)]
#[serde(transparent)]
pub struct Capabilities(Vec<String>);
208
209impl Capabilities {
210 pub fn supports_tool_calls(&self) -> bool {
211 self.0.iter().any(|cap| cap == "tool_use")
212 }
213}
214
/// A streamed server payload: either a normal event or an error object.
/// Untagged so serde tries the event shape first, then `{"error": ...}`.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}
221
/// One SSE chunk of a streaming chat completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub object: String,
    pub choices: Vec<ChoiceDelta>,
    // Typically only present on the final chunk.
    pub usage: Option<Usage>,
}
230
/// Response envelope for `GET {api_url}/models`.
#[derive(Deserialize)]
pub struct ListModelsResponse {
    pub data: Vec<ModelEntry>,
}
235
/// A model as listed by the LM Studio server.
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct ModelEntry {
    pub id: String,
    pub object: String,
    // `type` is a Rust keyword, hence the raw identifier.
    pub r#type: ModelType,
    pub publisher: String,
    // Model architecture (e.g. family name); not always reported.
    pub arch: Option<String>,
    pub compatibility_type: CompatibilityType,
    // Quantization label; not always reported.
    pub quantization: Option<String>,
    pub state: ModelState,
    pub max_context_length: Option<u32>,
    // Context length actually allocated if the model is loaded.
    pub loaded_context_length: Option<u32>,
    // Defaults to empty when the server omits the field.
    #[serde(default)]
    pub capabilities: Capabilities,
}
251
/// Kind of model: text LLM, embeddings model, or vision-language model.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ModelType {
    Llm,
    Embeddings,
    Vlm,
}
259
/// Load state of a model on the server; kebab-case on the wire
/// ("loaded", "loading", "not-loaded").
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum ModelState {
    Loaded,
    Loading,
    NotLoaded,
}
267
/// On-disk model format LM Studio reports: GGUF (llama.cpp) or MLX.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum CompatibilityType {
    Gguf,
    Mlx,
}
274
/// Incremental message payload within a streamed choice; every field may be
/// absent on any given chunk.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    // Usually only set on the first chunk.
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
282
283pub async fn complete(
284 client: &dyn HttpClient,
285 api_url: &str,
286 request: ChatCompletionRequest,
287) -> Result<ChatResponse> {
288 let uri = format!("{api_url}/chat/completions");
289 let request_builder = HttpRequest::builder()
290 .method(Method::POST)
291 .uri(uri)
292 .header("Content-Type", "application/json");
293
294 let serialized_request = serde_json::to_string(&request)?;
295 let request = request_builder.body(AsyncBody::from(serialized_request))?;
296
297 let mut response = client.send(request).await?;
298 if response.status().is_success() {
299 let mut body = Vec::new();
300 response.body_mut().read_to_end(&mut body).await?;
301 let response_message: ChatResponse = serde_json::from_slice(&body)?;
302 Ok(response_message)
303 } else {
304 let mut body = Vec::new();
305 response.body_mut().read_to_end(&mut body).await?;
306 let body_str = std::str::from_utf8(&body)?;
307 anyhow::bail!(
308 "Failed to connect to API: {} {}",
309 response.status(),
310 body_str
311 );
312 }
313}
314
315pub async fn stream_chat_completion(
316 client: &dyn HttpClient,
317 api_url: &str,
318 request: ChatCompletionRequest,
319) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
320 let uri = format!("{api_url}/chat/completions");
321 let request_builder = http::Request::builder()
322 .method(Method::POST)
323 .uri(uri)
324 .header("Content-Type", "application/json");
325
326 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
327 let mut response = client.send(request).await?;
328 if response.status().is_success() {
329 let reader = BufReader::new(response.into_body());
330
331 Ok(reader
332 .lines()
333 .filter_map(|line| async move {
334 match line {
335 Ok(line) => {
336 let line = line.strip_prefix("data: ")?;
337 if line == "[DONE]" {
338 None
339 } else {
340 let result = serde_json::from_str(&line)
341 .context("Unable to parse chat completions response");
342 if let Err(ref e) = result {
343 eprintln!("Error parsing line: {e}\nLine content: '{line}'");
344 }
345 Some(result)
346 }
347 }
348 Err(e) => {
349 eprintln!("Error reading line: {e}");
350 Some(Err(e.into()))
351 }
352 }
353 })
354 .boxed())
355 } else {
356 let mut body = String::new();
357 response.body_mut().read_to_string(&mut body).await?;
358 anyhow::bail!(
359 "Failed to connect to LM Studio API: {} {}",
360 response.status(),
361 body,
362 );
363 }
364}
365
366pub async fn get_models(
367 client: &dyn HttpClient,
368 api_url: &str,
369 _: Option<Duration>,
370) -> Result<Vec<ModelEntry>> {
371 let uri = format!("{api_url}/models");
372 let request_builder = HttpRequest::builder()
373 .method(Method::GET)
374 .uri(uri)
375 .header("Accept", "application/json");
376
377 let request = request_builder.body(AsyncBody::default())?;
378
379 let mut response = client.send(request).await?;
380
381 let mut body = String::new();
382 response.body_mut().read_to_string(&mut body).await?;
383
384 anyhow::ensure!(
385 response.status().is_success(),
386 "Failed to connect to LM Studio API: {} {}",
387 response.status(),
388 body,
389 );
390 let response: ListModelsResponse =
391 serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
392 Ok(response.data)
393}
394
395/// Sends an empty request to LM Studio to trigger loading the model
396pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
397 let uri = format!("{api_url}/completions");
398 let request = HttpRequest::builder()
399 .method(Method::POST)
400 .uri(uri)
401 .header("Content-Type", "application/json")
402 .body(AsyncBody::from(serde_json::to_string(
403 &serde_json::json!({
404 "model": model,
405 "messages": [],
406 "stream": false,
407 "max_tokens": 0,
408 }),
409 )?))?;
410
411 let mut response = client.send(request).await?;
412
413 if response.status().is_success() {
414 Ok(())
415 } else {
416 let mut body = String::new();
417 response.body_mut().read_to_string(&mut body).await?;
418 anyhow::bail!(
419 "Failed to connect to LM Studio API: {} {}",
420 response.status(),
421 body,
422 );
423 }
424}