use anyhow::{Result, anyhow};
use collections::HashMap;
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
    StopReason, TokenUsage,
};
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use lmstudio::{ModelType, get_models};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr;
use std::{collections::BTreeMap, sync::Arc};
use ui::{ButtonLike, Indicator, List, prelude::*};
use util::ResultExt;

use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;

const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("lmstudio");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("LM Studio");

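/// Settings for the LM Studio provider: the server's base API URL and any
/// models declared manually in the user's settings.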
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

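/// A model entry as configured in settings, used to override or extend the
/// list reported by the LM Studio API.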
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub supports_tool_calls: bool,
    pub supports_images: bool,
}

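/// The LM Studio language model provider, serving models from a locally
/// running LM Studio server.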
pub struct LmStudioLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

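/// Provider state: the models fetched from the server, the in-flight fetch
/// task, and a subscription that refetches when settings change.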
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<lmstudio::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
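    /// LM Studio has no authentication of its own; a non-empty model list is
    /// our proxy for the server being reachable.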
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

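    /// Fetches the model list from the LM Studio server, filters out
    /// embedding models, and stores the sorted result on this `State`.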
    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we check whether it's up by fetching the models.
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<lmstudio::Model> = models
                .into_iter()
                .filter(|model| model.r#type != ModelType::Embeddings)
                .map(|model| {
                    lmstudio::Model::new(
                        &model.id,
                        None,
                        model.loaded_context_length.or(model.max_context_length),
                        model.capabilities.supports_tool_calls(),
                        model.capabilities.supports_images() || model.r#type == ModelType::Vlm,
                    )
                })
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

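    /// Treats a successful model fetch as "authentication", since LM Studio
    /// exposes no credentials of its own.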
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}

impl LmStudioLanguageModelProvider {
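    /// Creates the provider and subscribes to settings changes so the model
    /// list is refetched whenever the LM Studio settings change.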
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for LmStudioLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for LmStudioLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiLmStudio
    }

    fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
        // We don't select a default model, because doing so could trigger a load
        // of an unloaded model. On resource-constrained machines it would be bad
        // UX to start loading something by default.
        None
    }

    fn default_fast_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
        // See the explanation for default_model.
        None
    }

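    /// Returns the union of models reported by the LM Studio API and models
    /// declared in settings; settings entries override API entries by name.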
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();

        // Add models from the LM Studio API.
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings.
        for model in AllLanguageModelSettings::get_global(cx)
            .lmstudio
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                lmstudio::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    supports_tool_calls: model.supports_tool_calls,
                    supports_images: model.supports_images,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(LmStudioLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model,
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        _window: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl LmStudioLanguageModel {
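    /// Converts a `LanguageModelRequest` into LM Studio's OpenAI-compatible
    /// chat completion request, flattening message content into text, image,
    /// tool-call, and tool-result parts.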
    fn to_lmstudio_request(
        &self,
        request: LanguageModelRequest,
    ) -> lmstudio::ChatCompletionRequest {
        let mut messages = Vec::new();

        for message in request.messages {
            for content in message.content {
                match content {
                    MessageContent::Text(text) => add_message_content_part(
                        lmstudio::MessagePart::Text { text },
                        message.role,
                        &mut messages,
                    ),
                    MessageContent::Thinking { .. } => {}
                    MessageContent::RedactedThinking(_) => {}
                    MessageContent::Image(image) => {
                        add_message_content_part(
                            lmstudio::MessagePart::Image {
                                image_url: lmstudio::ImageUrl {
                                    url: image.to_base64_url(),
                                    detail: None,
                                },
                            },
                            message.role,
                            &mut messages,
                        );
                    }
                    MessageContent::ToolUse(tool_use) => {
                        let tool_call = lmstudio::ToolCall {
                            id: tool_use.id.to_string(),
                            content: lmstudio::ToolCallContent::Function {
                                function: lmstudio::FunctionContent {
                                    name: tool_use.name.to_string(),
                                    arguments: serde_json::to_string(&tool_use.input)
                                        .unwrap_or_default(),
                                },
                            },
                        };

                        if let Some(lmstudio::ChatMessage::Assistant { tool_calls, .. }) =
                            messages.last_mut()
                        {
                            tool_calls.push(tool_call);
                        } else {
                            messages.push(lmstudio::ChatMessage::Assistant {
                                content: None,
                                tool_calls: vec![tool_call],
                            });
                        }
                    }
                    MessageContent::ToolResult(tool_result) => {
                        let content = match &tool_result.content {
                            LanguageModelToolResultContent::Text(text) => {
                                vec![lmstudio::MessagePart::Text {
                                    text: text.to_string(),
                                }]
                            }
                            LanguageModelToolResultContent::Image(image) => {
                                vec![lmstudio::MessagePart::Image {
                                    image_url: lmstudio::ImageUrl {
                                        url: image.to_base64_url(),
                                        detail: None,
                                    },
                                }]
                            }
                        };

                        messages.push(lmstudio::ChatMessage::Tool {
                            content: content.into(),
                            tool_call_id: tool_result.tool_use_id.to_string(),
                        });
                    }
                }
            }
        }

        lmstudio::ChatCompletionRequest {
            model: self.model.name.clone(),
            messages,
            stream: true,
            max_tokens: Some(-1),
            stop: Some(request.stop),
            // In LM Studio you can configure per-model settings; for example,
            // Qwen3 is recommended to run at a temperature of 0.7. It would be
            // bad UX to silently override those settings from Zed, so by default
            // we pass no temperature at all.
            temperature: request.temperature,
            tools: request
                .tools
                .into_iter()
                .map(|tool| lmstudio::ToolDefinition::Function {
                    function: lmstudio::FunctionDefinition {
                        name: tool.name,
                        description: Some(tool.description),
                        parameters: Some(tool.input_schema),
                    },
                })
                .collect(),
            tool_choice: request.tool_choice.map(|choice| match choice {
                LanguageModelToolChoice::Auto => lmstudio::ToolChoice::Auto,
                LanguageModelToolChoice::Any => lmstudio::ToolChoice::Required,
                LanguageModelToolChoice::None => lmstudio::ToolChoice::None,
            }),
        }
    }

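    /// Streams a chat completion from the LM Studio server, going through the
    /// request limiter to bound the number of concurrent requests.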
    fn stream_completion(
        &self,
        request: lmstudio::ChatCompletionRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<futures::stream::BoxStream<'static, Result<lmstudio::ResponseStreamEvent>>>,
    > {
        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let request =
                lmstudio::stream_chat_completion(http_client.as_ref(), &api_url, request);
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}

impl LanguageModel for LmStudioLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tool_calls()
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        self.supports_tools()
            && match choice {
                LanguageModelToolChoice::Auto => true,
                LanguageModelToolChoice::Any => true,
                LanguageModelToolChoice::None => true,
            }
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn telemetry_id(&self) -> String {
        format!("lmstudio/{}", self.model.id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

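    /// LM Studio does not yet expose a tokenization endpoint, so this returns
    /// a rough local estimate rather than an exact count.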
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        // An endpoint for this is coming soon. In the meantime, use a rough
        // estimate derived from the whitespace-delimited word count.
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().split_whitespace().count())
            .sum::<usize>();

        let estimated_tokens = (token_count as f64 * 0.75) as u64;
        async move { Ok(estimated_tokens) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let request = self.to_lmstudio_request(request);
        let completions = self.stream_completion(request, cx);
        async move {
            let mapper = LmStudioEventMapper::new();
            Ok(mapper.map_stream(completions.await?).boxed())
        }
        .boxed()
    }
}

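/// Maps LM Studio's OpenAI-style stream events to `LanguageModelCompletionEvent`s,
/// accumulating partial tool calls by index until the stream finishes.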
struct LmStudioEventMapper {
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}

impl LmStudioEventMapper {
    fn new() -> Self {
        Self {
            tool_calls_by_index: HashMap::default(),
        }
    }

    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<lmstudio::ResponseStreamEvent>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
            })
        })
    }

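    /// Converts a single stream event into zero or more completion events:
    /// text and thinking deltas, accumulated tool calls, usage updates, and a
    /// stop event once a finish reason arrives.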
    pub fn map_event(
        &mut self,
        event: lmstudio::ResponseStreamEvent,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let Some(choice) = event.choices.into_iter().next() else {
            return vec![Err(LanguageModelCompletionError::from(anyhow!(
                "Response contained no choices"
            )))];
        };

        let mut events = Vec::new();
        if let Some(content) = choice.delta.content {
            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
        }

        if let Some(reasoning_content) = choice.delta.reasoning_content {
            events.push(Ok(LanguageModelCompletionEvent::Thinking {
                text: reasoning_content,
                signature: None,
            }));
        }

        if let Some(tool_calls) = choice.delta.tool_calls {
            for tool_call in tool_calls {
                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();

                if let Some(tool_id) = tool_call.id {
                    entry.id = tool_id;
                }

                if let Some(function) = tool_call.function {
                    if let Some(name) = function.name {
                        // At the time of writing, LM Studio (0.3.15) deviates from the OpenAI API:
                        // 1. It sends the function name in the first chunk.
                        // 2. It sends an empty string in the function name field of every
                        //    subsequent argument chunk.
                        // According to https://platform.openai.com/docs/guides/function-calling?api-mode=responses#streaming
                        // the function name field should only be sent in the first chunk.
                        if !name.is_empty() {
                            entry.name = name;
                        }
                    }

                    if let Some(arguments) = function.arguments {
                        entry.arguments.push_str(&arguments);
                    }
                }
            }
        }

        if let Some(usage) = event.usage {
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        match choice.finish_reason.as_deref() {
            Some("stop") => {
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            Some("tool_calls") => {
                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
                    match serde_json::Value::from_str(&tool_call.arguments) {
                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: tool_call.id.into(),
                                name: tool_call.name.into(),
                                is_input_complete: true,
                                input,
                                raw_input: tool_call.arguments,
                            },
                        )),
                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                            id: tool_call.id.into(),
                            tool_name: tool_call.name.into(),
                            raw_input: tool_call.arguments.into(),
                            json_parse_error: error.to_string(),
                        }),
                    }
                }));

                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
            }
            Some(stop_reason) => {
                log::error!("Unexpected LM Studio stop_reason: {stop_reason:?}");
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            None => {}
        }

        events
    }
}

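/// A tool call accumulated across streamed chunks; `arguments` concatenates
/// the raw JSON fragments until the call is complete.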
#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}

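/// Appends `new_part` to the trailing message when it has the same role;
/// otherwise starts a new message of the appropriate kind.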
fn add_message_content_part(
    new_part: lmstudio::MessagePart,
    role: Role,
    messages: &mut Vec<lmstudio::ChatMessage>,
) {
    match (role, messages.last_mut()) {
        (Role::User, Some(lmstudio::ChatMessage::User { content }))
        | (
            Role::Assistant,
            Some(lmstudio::ChatMessage::Assistant {
                content: Some(content),
                ..
            }),
        )
        | (Role::System, Some(lmstudio::ChatMessage::System { content })) => {
            content.push_part(new_part);
        }
        _ => {
            messages.push(match role {
                Role::User => lmstudio::ChatMessage::User {
                    content: lmstudio::MessageContent::from(vec![new_part]),
                },
                Role::Assistant => lmstudio::ChatMessage::Assistant {
                    content: Some(lmstudio::MessageContent::from(vec![new_part])),
                    tool_calls: Vec::new(),
                },
                Role::System => lmstudio::ChatMessage::System {
                    content: lmstudio::MessageContent::from(vec![new_part]),
                },
            });
        }
    }
}

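/// The provider's configuration UI: kicks off an initial model fetch and
/// shows either setup instructions or the connected state.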
struct ConfigurationView {
    state: gpui::Entity<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .gap_2()
                .child(
                    v_flex().gap_1().child(Label::new(lmstudio_intro)).child(
                        List::new()
                            .child(InstructionListItem::text_only(
                                "LM Studio needs to be running with at least one model downloaded.",
                            ))
                            .child(InstructionListItem::text_only(
                                "To get your first model, try running `lms get qwen2.5-coder-7b`",
                            )),
                    ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("lmstudio-site", "LM Studio")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::Small)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _window, cx| {
                                                    cx.open_url(LMSTUDIO_SITE)
                                                })
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_lmstudio_button",
                                                "Download LM Studio",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::Small)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "Model Catalog")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::Small)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_CATALOG_URL)
                                        }),
                                ),
                        )
                        .map(|this| {
                            if is_authenticated {
                                this.child(
                                    ButtonLike::new("connected")
                                        .disabled(true)
                                        .cursor_style(gpui::CursorStyle::Arrow)
                                        .child(
                                            h_flex()
                                                .gap_2()
                                                .child(Indicator::dot().color(Color::Success))
                                                .child(Label::new("Connected"))
                                                .into_any_element(),
                                        ),
                                )
                            } else {
                                this.child(
                                    Button::new("retry_lmstudio_models", "Connect")
                                        .icon_position(IconPosition::Start)
                                        .icon_size(IconSize::XSmall)
                                        .icon(IconName::PlayFilled)
                                        .on_click(cx.listener(move |this, _, _window, cx| {
                                            this.retry_connection(cx)
                                        })),
                                )
                            }
                        }),
                )
                .into_any()
        }
    }
}