1use anyhow::{Context as _, Result};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
4use serde::{Deserialize, Serialize};
5use serde_json::{Value, value::RawValue};
6use std::{convert::TryFrom, sync::Arc, time::Duration};
7
/// Default base URL for a locally running LM Studio server's REST API (v0).
pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";
9
/// Chat participant role, serialized in lowercase to match the
/// OpenAI-compatible wire format LM Studio exposes.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
18
19impl TryFrom<String> for Role {
20 type Error = anyhow::Error;
21
22 fn try_from(value: String) -> Result<Self> {
23 match value.as_str() {
24 "user" => Ok(Self::User),
25 "assistant" => Ok(Self::Assistant),
26 "system" => Ok(Self::System),
27 "tool" => Ok(Self::Tool),
28 _ => anyhow::bail!("invalid role '{value}'"),
29 }
30 }
31}
32
33impl From<Role> for String {
34 fn from(val: Role) -> Self {
35 match val {
36 Role::User => "user".to_owned(),
37 Role::Assistant => "assistant".to_owned(),
38 Role::System => "system".to_owned(),
39 Role::Tool => "tool".to_owned(),
40 }
41 }
42}
43
/// A locally available LM Studio model as referenced in configuration.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    /// Identifier sent to the server (see `ChatCompletionRequest::model`).
    pub name: String,
    /// Optional human-readable name; `name` is used when absent.
    pub display_name: Option<String>,
    /// Context window size in tokens.
    pub max_tokens: usize,
}
51
52impl Model {
53 pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
54 Self {
55 name: name.to_owned(),
56 display_name: display_name.map(|s| s.to_owned()),
57 max_tokens: max_tokens.unwrap_or(2048),
58 }
59 }
60
61 pub fn id(&self) -> &str {
62 &self.name
63 }
64
65 pub fn display_name(&self) -> &str {
66 self.display_name.as_ref().unwrap_or(&self.name)
67 }
68
69 pub fn max_token_count(&self) -> usize {
70 self.max_tokens
71 }
72}
/// A single message in a chat conversation, tagged by `role` on the wire.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        /// Text content; may be absent when the assistant only calls tools.
        #[serde(default)]
        content: Option<String>,
        /// Tool invocations requested by the model, if any.
        #[serde(default)]
        tool_calls: Option<Vec<LmStudioToolCall>>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
}
89
/// A tool call emitted by the model. Only function calls exist today;
/// the externally tagged representation serializes as `{"function": {...}}`.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum LmStudioToolCall {
    Function(LmStudioFunctionCall),
}
95
/// The function name and raw JSON arguments of a single tool call.
#[derive(Serialize, Deserialize, Debug)]
pub struct LmStudioFunctionCall {
    pub name: String,
    /// Arguments kept as unparsed JSON so callers can decode them lazily
    /// against the tool's own schema.
    pub arguments: Box<RawValue>,
}
101
/// Declaration of a function the model is allowed to call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LmStudioFunctionTool {
    pub name: String,
    pub description: Option<String>,
    /// JSON schema describing the function's parameters.
    pub parameters: Option<Value>,
}
108
/// A tool made available to the model, tagged by `type` on the wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum LmStudioTool {
    Function { function: LmStudioFunctionTool },
}
114
/// Request payload for `POST {api_url}/chat/completions`.
#[derive(Serialize, Debug)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    /// When true the server answers with server-sent events.
    pub stream: bool,
    // NOTE(review): the `Option` fields serialize as JSON `null` rather than
    // being omitted (no `skip_serializing_if`); LM Studio apparently accepts
    // this — confirm before pointing other servers at this client.
    pub max_tokens: Option<i32>,
    pub stop: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub tools: Vec<LmStudioTool>,
}
125
/// A chat-completion response (also used for individual streamed chunks —
/// see `stream_chat_completion`).
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub id: String,
    pub object: String,
    /// Unix timestamp of creation, per the server.
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
}
134
/// One choice within a response chunk.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    /// Kept as raw JSON here; callers decide how to decode it
    /// (e.g. into `ResponseMessageDelta`). Defaults to `null` when missing.
    #[serde(default)]
    pub delta: serde_json::Value,
    pub finish_reason: Option<String>,
}
142
/// A streamed fragment of a tool call; fields arrive incrementally
/// across chunks and are correlated by `index`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
153
/// Incremental function-call data within a `ToolCallChunk`; the
/// `arguments` string accumulates across chunks.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
159
/// Token accounting reported by the server for a completed request.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
166
/// Either a stream event or an error object from the server.
///
/// `untagged` deserialization tries the variants in declaration order, so
/// `Ok` must stay first: an `{"error": ...}` payload only matches `Err`
/// because it fails to parse as `ResponseStreamEvent`.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}
173
/// One server-sent event from a streaming chat completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// Present only on the final event, when the server reports usage.
    pub usage: Option<Usage>,
}
181
/// Envelope returned by `GET {api_url}/models`.
#[derive(Serialize, Deserialize)]
pub struct ListModelsResponse {
    pub data: Vec<ModelEntry>,
}
186
/// One model as described by the LM Studio `/models` endpoint.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ModelEntry {
    pub id: String,
    pub object: String,
    pub r#type: ModelType,
    pub publisher: String,
    /// Model architecture (e.g. reported by the server); optional.
    pub arch: Option<String>,
    pub compatibility_type: CompatibilityType,
    pub quantization: Option<String>,
    /// Whether the model is currently loaded in server memory.
    pub state: ModelState,
    pub max_context_length: Option<u32>,
    /// Context length the model was actually loaded with, when loaded.
    pub loaded_context_length: Option<u32>,
}
200
/// Broad model category as reported by the server
/// (serialized lowercase: "llm", "embeddings", "vlm").
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ModelType {
    Llm,
    Embeddings,
    Vlm,
}
208
/// Load state of a model on the server
/// (kebab-case on the wire, e.g. `NotLoaded` -> "not-loaded").
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum ModelState {
    Loaded,
    Loading,
    NotLoaded,
}
216
/// On-disk model format reported by the server ("gguf" or "mlx").
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum CompatibilityType {
    Gguf,
    Mlx,
}
223
/// The decoded shape of `ChoiceDelta::delta` for streamed chat chunks.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    /// Usually present only on the first chunk of a message.
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
231
232pub async fn complete(
233 client: &dyn HttpClient,
234 api_url: &str,
235 request: ChatCompletionRequest,
236) -> Result<ChatResponse> {
237 let uri = format!("{api_url}/chat/completions");
238 let request_builder = HttpRequest::builder()
239 .method(Method::POST)
240 .uri(uri)
241 .header("Content-Type", "application/json");
242
243 let serialized_request = serde_json::to_string(&request)?;
244 let request = request_builder.body(AsyncBody::from(serialized_request))?;
245
246 let mut response = client.send(request).await?;
247 if response.status().is_success() {
248 let mut body = Vec::new();
249 response.body_mut().read_to_end(&mut body).await?;
250 let response_message: ChatResponse = serde_json::from_slice(&body)?;
251 Ok(response_message)
252 } else {
253 let mut body = Vec::new();
254 response.body_mut().read_to_end(&mut body).await?;
255 let body_str = std::str::from_utf8(&body)?;
256 anyhow::bail!(
257 "Failed to connect to API: {} {}",
258 response.status(),
259 body_str
260 );
261 }
262}
263
/// Sends a chat-completion request and returns a stream of incremental
/// responses parsed from the server-sent-event lines of the reply body.
///
/// The caller is expected to set `request.stream = true`; this function
/// forwards the request as-is — TODO confirm whether it should enforce that.
///
/// # Errors
/// Returns an error if the request cannot be sent or the server responds
/// with a non-success status; individual malformed stream lines surface as
/// `Err` items on the returned stream.
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatCompletionRequest,
) -> Result<BoxStream<'static, Result<ChatResponse>>> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = http::Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE payload lines carry a "data: " prefix; anything
                        // else (blank keep-alives, comment lines) is dropped
                        // here via the `?` on `strip_prefix`.
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            // Terminal sentinel: end the stream with no item.
                            None
                        } else {
                            let result = serde_json::from_str(&line)
                                .context("Unable to parse chat completions response");
                            if let Err(ref e) = result {
                                eprintln!("Error parsing line: {e}\nLine content: '{line}'");
                            }
                            Some(result)
                        }
                    }
                    Err(e) => {
                        // I/O failure while reading the body: forward it to the
                        // consumer as a stream item rather than ending silently.
                        eprintln!("Error reading line: {e}");
                        Some(Err(e.into()))
                    }
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to LM Studio API: {} {}",
            response.status(),
            body,
        );
    }
}
314
315pub async fn get_models(
316 client: &dyn HttpClient,
317 api_url: &str,
318 _: Option<Duration>,
319) -> Result<Vec<ModelEntry>> {
320 let uri = format!("{api_url}/models");
321 let request_builder = HttpRequest::builder()
322 .method(Method::GET)
323 .uri(uri)
324 .header("Accept", "application/json");
325
326 let request = request_builder.body(AsyncBody::default())?;
327
328 let mut response = client.send(request).await?;
329
330 let mut body = String::new();
331 response.body_mut().read_to_string(&mut body).await?;
332
333 anyhow::ensure!(
334 response.status().is_success(),
335 "Failed to connect to LM Studio API: {} {}",
336 response.status(),
337 body,
338 );
339 let response: ListModelsResponse =
340 serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
341 Ok(response.data)
342}
343
344/// Sends an empty request to LM Studio to trigger loading the model
345pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
346 let uri = format!("{api_url}/completions");
347 let request = HttpRequest::builder()
348 .method(Method::POST)
349 .uri(uri)
350 .header("Content-Type", "application/json")
351 .body(AsyncBody::from(serde_json::to_string(
352 &serde_json::json!({
353 "model": model,
354 "messages": [],
355 "stream": false,
356 "max_tokens": 0,
357 }),
358 )?))?;
359
360 let mut response = client.send(request).await?;
361
362 if response.status().is_success() {
363 Ok(())
364 } else {
365 let mut body = String::new();
366 response.body_mut().read_to_string(&mut body).await?;
367 anyhow::bail!(
368 "Failed to connect to LM Studio API: {} {}",
369 response.status(),
370 body,
371 );
372 }
373}