use anyhow::{Result, anyhow};
use collections::HashMap;
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
    StopReason, WrappedTextContent,
};
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use lmstudio::{
    ChatCompletionRequest, ChatMessage, ModelType, ResponseStreamEvent, get_models, preload_model,
    stream_chat_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr;
use std::{collections::BTreeMap, sync::Arc};
use ui::{ButtonLike, Indicator, List, prelude::*};
use util::ResultExt;

use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;

const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

const PROVIDER_ID: &str = "lmstudio";
const PROVIDER_NAME: &str = "LM Studio";

#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub supports_tool_calls: bool,
}

pub struct LmStudioLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

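/// Holds the models fetched from the running LM Studio server, plus the
/// settings subscription that re-triggers fetching when settings change.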
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<lmstudio::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we check whether it's up by fetching the models.
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<lmstudio::Model> = models
                .into_iter()
                .filter(|model| model.r#type != ModelType::Embeddings)
                .map(|model| {
                    lmstudio::Model::new(
                        &model.id,
                        None,
                        None,
                        model.capabilities.supports_tool_calls(),
                    )
                })
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

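    /// LM Studio has no credentials, so a successful model fetch is treated as
    /// "authenticated"; this simply reuses `fetch_models` as a health check.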
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}

impl LmStudioLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for LmStudioLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for LmStudioLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiLmStudio
    }

    fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
        // Don't pick a default model: doing so could trigger loading an unloaded
        // model, which is poor UX on machines without the resources to spare.
        None
    }

    fn default_fast_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
        // See explanation for default_model.
        None
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();

        // Add models from the LM Studio API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .lmstudio
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                lmstudio::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    supports_tool_calls: model.supports_tool_calls,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(LmStudioLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl LmStudioLanguageModel {
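    /// Flattens Zed's role/content message structure into LM Studio's
    /// OpenAI-style chat messages, folding tool calls into the preceding
    /// assistant message and dropping unsupported content (images, redacted
    /// thinking).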
    fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
        let mut messages = Vec::new();

        for message in request.messages {
            for content in message.content {
                match content {
                    MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
                        .push(match message.role {
                            Role::User => ChatMessage::User { content: text },
                            Role::Assistant => ChatMessage::Assistant {
                                content: Some(text),
                                tool_calls: Vec::new(),
                            },
                            Role::System => ChatMessage::System { content: text },
                        }),
                    MessageContent::RedactedThinking(_) => {}
                    MessageContent::Image(_) => {}
                    MessageContent::ToolUse(tool_use) => {
                        let tool_call = lmstudio::ToolCall {
                            id: tool_use.id.to_string(),
                            content: lmstudio::ToolCallContent::Function {
                                function: lmstudio::FunctionContent {
                                    name: tool_use.name.to_string(),
                                    arguments: serde_json::to_string(&tool_use.input)
                                        .unwrap_or_default(),
                                },
                            },
                        };

                        if let Some(lmstudio::ChatMessage::Assistant { tool_calls, .. }) =
                            messages.last_mut()
                        {
                            tool_calls.push(tool_call);
                        } else {
                            messages.push(lmstudio::ChatMessage::Assistant {
                                content: None,
                                tool_calls: vec![tool_call],
                            });
                        }
                    }
                    MessageContent::ToolResult(tool_result) => {
                        match &tool_result.content {
                            LanguageModelToolResultContent::Text(text)
                            | LanguageModelToolResultContent::WrappedText(WrappedTextContent {
                                text,
                                ..
                            }) => {
                                messages.push(lmstudio::ChatMessage::Tool {
                                    content: text.to_string(),
                                    tool_call_id: tool_result.tool_use_id.to_string(),
                                });
                            }
                            LanguageModelToolResultContent::Image(_) => {
                                // no support for images for now
                            }
                        };
                    }
                }
            }
        }

        ChatCompletionRequest {
            model: self.model.name.clone(),
            messages,
            stream: true,
            max_tokens: Some(-1),
            stop: Some(request.stop),
            // In LM Studio you can configure per-model inference settings (for
            // example, Qwen3 is recommended to run at temperature 0.7). Silently
            // overriding those settings from Zed would be bad UX, so we pass no
            // temperature unless the request specifies one.
            temperature: request.temperature,
            tools: request
                .tools
                .into_iter()
                .map(|tool| lmstudio::ToolDefinition::Function {
                    function: lmstudio::FunctionDefinition {
                        name: tool.name,
                        description: Some(tool.description),
                        parameters: Some(tool.input_schema),
                    },
                })
                .collect(),
            tool_choice: request.tool_choice.map(|choice| match choice {
                LanguageModelToolChoice::Auto => lmstudio::ToolChoice::Auto,
                LanguageModelToolChoice::Any => lmstudio::ToolChoice::Required,
                LanguageModelToolChoice::None => lmstudio::ToolChoice::None,
            }),
        }
    }

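    /// Opens a rate-limited streaming chat completion against the API URL
    /// currently configured in settings.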
    fn stream_completion(
        &self,
        request: ChatCompletionRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
    {
        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let request = stream_chat_completion(http_client.as_ref(), &api_url, request);
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}

impl LanguageModel for LmStudioLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tool_calls()
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        self.supports_tools()
            && match choice {
                LanguageModelToolChoice::Auto => true,
                LanguageModelToolChoice::Any => true,
                LanguageModelToolChoice::None => true,
            }
    }

    fn supports_images(&self) -> bool {
        false
    }

    fn telemetry_id(&self) -> String {
        format!("lmstudio/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        // No token-counting endpoint yet; in the meantime, use a crude estimate.
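        // Assumption: roughly 0.75 tokens per whitespace-delimited word. This is
        // only a ballpark figure and can be well off for code-heavy prompts.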
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().split_whitespace().count())
            .sum::<usize>();

        let estimated_tokens = (token_count as f64 * 0.75) as usize;
        async move { Ok(estimated_tokens) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    > {
        let request = self.to_lmstudio_request(request);
        let completions = self.stream_completion(request, cx);
        async move {
            let mapper = LmStudioEventMapper::new();
            Ok(mapper.map_stream(completions.await?).boxed())
        }
        .boxed()
    }
}

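/// Accumulates tool-call deltas by stream index and translates LM Studio's
/// OpenAI-style stream events into `LanguageModelCompletionEvent`s.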
struct LmStudioEventMapper {
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}

impl LmStudioEventMapper {
    fn new() -> Self {
        Self {
            tool_calls_by_index: HashMap::default(),
        }
    }

    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
            })
        })
    }

    pub fn map_event(
        &mut self,
        event: ResponseStreamEvent,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let Some(choice) = event.choices.into_iter().next() else {
            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
                "Response contained no choices"
            )))];
        };

        let mut events = Vec::new();
        if let Some(content) = choice.delta.content {
            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
        }

        if let Some(tool_calls) = choice.delta.tool_calls {
            for tool_call in tool_calls {
                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();

                if let Some(tool_id) = tool_call.id {
                    entry.id = tool_id;
                }

                if let Some(function) = tool_call.function {
                    if let Some(name) = function.name {
                        // As of LM Studio 0.3.15, this deviates from the OpenAI streaming API:
                        // 1. the function name arrives in the first chunk, and
                        // 2. every subsequent argument chunk carries an empty string in the name field.
                        // Per https://platform.openai.com/docs/guides/function-calling?api-mode=responses#streaming
                        // the function name should be sent only in the first chunk.
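                        // A hypothetical delta sequence illustrating the quirk
                        // (field names follow the OpenAI streaming shape; values are made up):
                        //   {"index": 0, "id": "call_1", "function": {"name": "get_weather", "arguments": ""}}
                        //   {"index": 0, "function": {"name": "", "arguments": "{\"city\":"}}
                        //   {"index": 0, "function": {"name": "", "arguments": "\"Berlin\"}"}}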
                        if !name.is_empty() {
                            entry.name = name;
                        }
                    }

                    if let Some(arguments) = function.arguments {
                        entry.arguments.push_str(&arguments);
                    }
                }
            }
        }

        match choice.finish_reason.as_deref() {
            Some("stop") => {
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            Some("tool_calls") => {
                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
                    match serde_json::Value::from_str(&tool_call.arguments) {
                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: tool_call.id.into(),
                                name: tool_call.name.into(),
                                is_input_complete: true,
                                input,
                                raw_input: tool_call.arguments,
                            },
                        )),
                        Err(error) => Err(LanguageModelCompletionError::BadInputJson {
                            id: tool_call.id.into(),
                            tool_name: tool_call.name.into(),
                            raw_input: tool_call.arguments.into(),
                            json_parse_error: error.to_string(),
                        }),
                    }
                }));

                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
            }
            Some(stop_reason) => {
                log::error!("Unexpected LM Studio stop_reason: {stop_reason:?}");
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            None => {}
        }

        events
    }
}

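/// A tool call assembled incrementally from streamed deltas; `arguments`
/// accumulates JSON fragments until the stream reports `tool_calls` as the
/// finish reason.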
#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}

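/// Configuration UI for the provider: shows setup instructions, download and
/// model-catalog links, and a connect/retry button that re-fetches models.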
struct ConfigurationView {
    state: gpui::Entity<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .gap_2()
                .child(
                    v_flex().gap_1().child(Label::new(lmstudio_intro)).child(
                        List::new()
                            .child(InstructionListItem::text_only(
                                "LM Studio needs to be running with at least one model downloaded.",
                            ))
                            .child(InstructionListItem::text_only(
                                "To get your first model, try running `lms get qwen2.5-coder-7b`",
                            )),
                    ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("lmstudio-site", "LM Studio")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _window, cx| {
                                                    cx.open_url(LMSTUDIO_SITE)
                                                })
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_lmstudio_button",
                                                "Download LM Studio",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "Model Catalog")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_CATALOG_URL)
                                        }),
                                ),
                        )
                        .map(|this| {
                            if is_authenticated {
                                this.child(
                                    ButtonLike::new("connected")
                                        .disabled(true)
                                        .cursor_style(gpui::CursorStyle::Arrow)
                                        .child(
                                            h_flex()
                                                .gap_2()
                                                .child(Indicator::dot().color(Color::Success))
                                                .child(Label::new("Connected"))
                                                .into_any_element(),
                                        ),
                                )
                            } else {
                                this.child(
                                    Button::new("retry_lmstudio_models", "Connect")
                                        .icon_position(IconPosition::Start)
                                        .icon_size(IconSize::XSmall)
                                        .icon(IconName::Play)
                                        .on_click(cx.listener(move |this, _, _window, cx| {
                                            this.retry_connection(cx)
                                        })),
                                )
                            }
                        }),
                )
                .into_any()
        }
    }
}