1use anyhow::{Result, anyhow};
2use collections::HashMap;
3use futures::Stream;
4use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
5use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
6use http_client::HttpClient;
7use language_model::{
8 AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
9 LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
10 StopReason,
11};
12use language_model::{
13 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
14 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
15 LanguageModelRequest, RateLimiter, Role,
16};
17use lmstudio::{
18 ChatCompletionRequest, ChatMessage, ModelType, ResponseStreamEvent, get_models,
19 stream_chat_completion,
20};
21use schemars::JsonSchema;
22use serde::{Deserialize, Serialize};
23use settings::{Settings, SettingsStore};
24use std::pin::Pin;
25use std::str::FromStr;
26use std::{collections::BTreeMap, sync::Arc};
27use ui::{ButtonLike, Indicator, List, prelude::*};
28use util::ResultExt;
29
30use crate::AllLanguageModelSettings;
31use crate::ui::InstructionListItem;
32
// Identity of this provider as reported to Zed.
const PROVIDER_ID: &str = "lmstudio";
const PROVIDER_NAME: &str = "LM Studio";

// External pages linked from the configuration view.
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";
const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
39
40#[derive(Default, Debug, Clone, PartialEq)]
41pub struct LmStudioSettings {
42 pub api_url: String,
43 pub available_models: Vec<AvailableModel>,
44}
45
46#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
47pub struct AvailableModel {
48 pub name: String,
49 pub display_name: Option<String>,
50 pub max_tokens: usize,
51 pub supports_tool_calls: bool,
52}
53
54pub struct LmStudioLanguageModelProvider {
55 http_client: Arc<dyn HttpClient>,
56 state: gpui::Entity<State>,
57}
58
59pub struct State {
60 http_client: Arc<dyn HttpClient>,
61 available_models: Vec<lmstudio::Model>,
62 fetch_model_task: Option<Task<Result<()>>>,
63 _subscription: Subscription,
64}
65
66impl State {
67 fn is_authenticated(&self) -> bool {
68 !self.available_models.is_empty()
69 }
70
71 fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
72 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
73 let http_client = self.http_client.clone();
74 let api_url = settings.api_url.clone();
75
76 // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
77 cx.spawn(async move |this, cx| {
78 let models = get_models(http_client.as_ref(), &api_url, None).await?;
79
80 let mut models: Vec<lmstudio::Model> = models
81 .into_iter()
82 .filter(|model| model.r#type != ModelType::Embeddings)
83 .map(|model| {
84 lmstudio::Model::new(
85 &model.id,
86 None,
87 model
88 .loaded_context_length
89 .or_else(|| model.max_context_length),
90 model.capabilities.supports_tool_calls(),
91 )
92 })
93 .collect();
94
95 models.sort_by(|a, b| a.name.cmp(&b.name));
96
97 this.update(cx, |this, cx| {
98 this.available_models = models;
99 cx.notify();
100 })
101 })
102 }
103
104 fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
105 let task = self.fetch_models(cx);
106 self.fetch_model_task.replace(task);
107 }
108
109 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
110 if self.is_authenticated() {
111 return Task::ready(Ok(()));
112 }
113
114 let fetch_models_task = self.fetch_models(cx);
115 cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
116 }
117}
118
119impl LmStudioLanguageModelProvider {
120 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
121 let this = Self {
122 http_client: http_client.clone(),
123 state: cx.new(|cx| {
124 let subscription = cx.observe_global::<SettingsStore>({
125 let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
126 move |this: &mut State, cx| {
127 let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
128 if &settings != new_settings {
129 settings = new_settings.clone();
130 this.restart_fetch_models_task(cx);
131 cx.notify();
132 }
133 }
134 });
135
136 State {
137 http_client,
138 available_models: Default::default(),
139 fetch_model_task: None,
140 _subscription: subscription,
141 }
142 }),
143 };
144 this.state
145 .update(cx, |state, cx| state.restart_fetch_models_task(cx));
146 this
147 }
148}
149
150impl LanguageModelProviderState for LmStudioLanguageModelProvider {
151 type ObservableEntity = State;
152
153 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
154 Some(self.state.clone())
155 }
156}
157
158impl LanguageModelProvider for LmStudioLanguageModelProvider {
159 fn id(&self) -> LanguageModelProviderId {
160 LanguageModelProviderId(PROVIDER_ID.into())
161 }
162
163 fn name(&self) -> LanguageModelProviderName {
164 LanguageModelProviderName(PROVIDER_NAME.into())
165 }
166
167 fn icon(&self) -> IconName {
168 IconName::AiLmStudio
169 }
170
171 fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
172 // We shouldn't try to select default model, because it might lead to a load call for an unloaded model.
173 // In a constrained environment where user might not have enough resources it'll be a bad UX to select something
174 // to load by default.
175 None
176 }
177
178 fn default_fast_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
179 // See explanation for default_model.
180 None
181 }
182
183 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
184 let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();
185
186 // Add models from the LM Studio API
187 for model in self.state.read(cx).available_models.iter() {
188 models.insert(model.name.clone(), model.clone());
189 }
190
191 // Override with available models from settings
192 for model in AllLanguageModelSettings::get_global(cx)
193 .lmstudio
194 .available_models
195 .iter()
196 {
197 models.insert(
198 model.name.clone(),
199 lmstudio::Model {
200 name: model.name.clone(),
201 display_name: model.display_name.clone(),
202 max_tokens: model.max_tokens,
203 supports_tool_calls: model.supports_tool_calls,
204 },
205 );
206 }
207
208 models
209 .into_values()
210 .map(|model| {
211 Arc::new(LmStudioLanguageModel {
212 id: LanguageModelId::from(model.name.clone()),
213 model: model.clone(),
214 http_client: self.http_client.clone(),
215 request_limiter: RateLimiter::new(4),
216 }) as Arc<dyn LanguageModel>
217 })
218 .collect()
219 }
220
221 fn is_authenticated(&self, cx: &App) -> bool {
222 self.state.read(cx).is_authenticated()
223 }
224
225 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
226 self.state.update(cx, |state, cx| state.authenticate(cx))
227 }
228
229 fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
230 let state = self.state.clone();
231 cx.new(|cx| ConfigurationView::new(state, cx)).into()
232 }
233
234 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
235 self.state.update(cx, |state, cx| state.fetch_models(cx))
236 }
237}
238
239pub struct LmStudioLanguageModel {
240 id: LanguageModelId,
241 model: lmstudio::Model,
242 http_client: Arc<dyn HttpClient>,
243 request_limiter: RateLimiter,
244}
245
246impl LmStudioLanguageModel {
247 fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
248 let mut messages = Vec::new();
249
250 for message in request.messages {
251 for content in message.content {
252 match content {
253 MessageContent::Text(text) => messages.push(match message.role {
254 Role::User => ChatMessage::User { content: text },
255 Role::Assistant => ChatMessage::Assistant {
256 content: Some(text),
257 tool_calls: Vec::new(),
258 },
259 Role::System => ChatMessage::System { content: text },
260 }),
261 MessageContent::Thinking { .. } => {}
262 MessageContent::RedactedThinking(_) => {}
263 MessageContent::Image(_) => {}
264 MessageContent::ToolUse(tool_use) => {
265 let tool_call = lmstudio::ToolCall {
266 id: tool_use.id.to_string(),
267 content: lmstudio::ToolCallContent::Function {
268 function: lmstudio::FunctionContent {
269 name: tool_use.name.to_string(),
270 arguments: serde_json::to_string(&tool_use.input)
271 .unwrap_or_default(),
272 },
273 },
274 };
275
276 if let Some(lmstudio::ChatMessage::Assistant { tool_calls, .. }) =
277 messages.last_mut()
278 {
279 tool_calls.push(tool_call);
280 } else {
281 messages.push(lmstudio::ChatMessage::Assistant {
282 content: None,
283 tool_calls: vec![tool_call],
284 });
285 }
286 }
287 MessageContent::ToolResult(tool_result) => {
288 match &tool_result.content {
289 LanguageModelToolResultContent::Text(text) => {
290 messages.push(lmstudio::ChatMessage::Tool {
291 content: text.to_string(),
292 tool_call_id: tool_result.tool_use_id.to_string(),
293 });
294 }
295 LanguageModelToolResultContent::Image(_) => {
296 // no support for images for now
297 }
298 };
299 }
300 }
301 }
302 }
303
304 ChatCompletionRequest {
305 model: self.model.name.clone(),
306 messages,
307 stream: true,
308 max_tokens: Some(-1),
309 stop: Some(request.stop),
310 // In LM Studio you can configure specific settings you'd like to use for your model.
311 // For example Qwen3 is recommended to be used with 0.7 temperature.
312 // It would be a bad UX to silently override these settings from Zed, so we pass no temperature as a default.
313 temperature: request.temperature.or(None),
314 tools: request
315 .tools
316 .into_iter()
317 .map(|tool| lmstudio::ToolDefinition::Function {
318 function: lmstudio::FunctionDefinition {
319 name: tool.name,
320 description: Some(tool.description),
321 parameters: Some(tool.input_schema),
322 },
323 })
324 .collect(),
325 tool_choice: request.tool_choice.map(|choice| match choice {
326 LanguageModelToolChoice::Auto => lmstudio::ToolChoice::Auto,
327 LanguageModelToolChoice::Any => lmstudio::ToolChoice::Required,
328 LanguageModelToolChoice::None => lmstudio::ToolChoice::None,
329 }),
330 }
331 }
332
333 fn stream_completion(
334 &self,
335 request: ChatCompletionRequest,
336 cx: &AsyncApp,
337 ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
338 {
339 let http_client = self.http_client.clone();
340 let Ok(api_url) = cx.update(|cx| {
341 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
342 settings.api_url.clone()
343 }) else {
344 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
345 };
346
347 let future = self.request_limiter.stream(async move {
348 let request = stream_chat_completion(http_client.as_ref(), &api_url, request);
349 let response = request.await?;
350 Ok(response)
351 });
352
353 async move { Ok(future.await?.boxed()) }.boxed()
354 }
355}
356
357impl LanguageModel for LmStudioLanguageModel {
358 fn id(&self) -> LanguageModelId {
359 self.id.clone()
360 }
361
362 fn name(&self) -> LanguageModelName {
363 LanguageModelName::from(self.model.display_name().to_string())
364 }
365
366 fn provider_id(&self) -> LanguageModelProviderId {
367 LanguageModelProviderId(PROVIDER_ID.into())
368 }
369
370 fn provider_name(&self) -> LanguageModelProviderName {
371 LanguageModelProviderName(PROVIDER_NAME.into())
372 }
373
374 fn supports_tools(&self) -> bool {
375 self.model.supports_tool_calls()
376 }
377
378 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
379 self.supports_tools()
380 && match choice {
381 LanguageModelToolChoice::Auto => true,
382 LanguageModelToolChoice::Any => true,
383 LanguageModelToolChoice::None => true,
384 }
385 }
386
387 fn supports_images(&self) -> bool {
388 false
389 }
390
391 fn telemetry_id(&self) -> String {
392 format!("lmstudio/{}", self.model.id())
393 }
394
395 fn max_token_count(&self) -> usize {
396 self.model.max_token_count()
397 }
398
399 fn count_tokens(
400 &self,
401 request: LanguageModelRequest,
402 _cx: &App,
403 ) -> BoxFuture<'static, Result<usize>> {
404 // Endpoint for this is coming soon. In the meantime, hacky estimation
405 let token_count = request
406 .messages
407 .iter()
408 .map(|msg| msg.string_contents().split_whitespace().count())
409 .sum::<usize>();
410
411 let estimated_tokens = (token_count as f64 * 0.75) as usize;
412 async move { Ok(estimated_tokens) }.boxed()
413 }
414
415 fn stream_completion(
416 &self,
417 request: LanguageModelRequest,
418 cx: &AsyncApp,
419 ) -> BoxFuture<
420 'static,
421 Result<
422 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
423 LanguageModelCompletionError,
424 >,
425 > {
426 let request = self.to_lmstudio_request(request);
427 let completions = self.stream_completion(request, cx);
428 async move {
429 let mapper = LmStudioEventMapper::new();
430 Ok(mapper.map_stream(completions.await?).boxed())
431 }
432 .boxed()
433 }
434}
435
436struct LmStudioEventMapper {
437 tool_calls_by_index: HashMap<usize, RawToolCall>,
438}
439
440impl LmStudioEventMapper {
441 fn new() -> Self {
442 Self {
443 tool_calls_by_index: HashMap::default(),
444 }
445 }
446
447 pub fn map_stream(
448 mut self,
449 events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
450 ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
451 {
452 events.flat_map(move |event| {
453 futures::stream::iter(match event {
454 Ok(event) => self.map_event(event),
455 Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
456 })
457 })
458 }
459
460 pub fn map_event(
461 &mut self,
462 event: ResponseStreamEvent,
463 ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
464 let Some(choice) = event.choices.into_iter().next() else {
465 return vec![Err(LanguageModelCompletionError::Other(anyhow!(
466 "Response contained no choices"
467 )))];
468 };
469
470 let mut events = Vec::new();
471 if let Some(content) = choice.delta.content {
472 events.push(Ok(LanguageModelCompletionEvent::Text(content)));
473 }
474
475 if let Some(reasoning_content) = choice.delta.reasoning_content {
476 events.push(Ok(LanguageModelCompletionEvent::Thinking {
477 text: reasoning_content,
478 signature: None,
479 }));
480 }
481
482 if let Some(tool_calls) = choice.delta.tool_calls {
483 for tool_call in tool_calls {
484 let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
485
486 if let Some(tool_id) = tool_call.id {
487 entry.id = tool_id;
488 }
489
490 if let Some(function) = tool_call.function {
491 if let Some(name) = function.name {
492 // At the time of writing this code LM Studio (0.3.15) is incompatible with the OpenAI API:
493 // 1. It sends function name in the first chunk
494 // 2. It sends empty string in the function name field in all subsequent chunks for arguments
495 // According to https://platform.openai.com/docs/guides/function-calling?api-mode=responses#streaming
496 // function name field should be sent only inside the first chunk.
497 if !name.is_empty() {
498 entry.name = name;
499 }
500 }
501
502 if let Some(arguments) = function.arguments {
503 entry.arguments.push_str(&arguments);
504 }
505 }
506 }
507 }
508
509 match choice.finish_reason.as_deref() {
510 Some("stop") => {
511 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
512 }
513 Some("tool_calls") => {
514 events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
515 match serde_json::Value::from_str(&tool_call.arguments) {
516 Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
517 LanguageModelToolUse {
518 id: tool_call.id.into(),
519 name: tool_call.name.into(),
520 is_input_complete: true,
521 input,
522 raw_input: tool_call.arguments,
523 },
524 )),
525 Err(error) => Err(LanguageModelCompletionError::BadInputJson {
526 id: tool_call.id.into(),
527 tool_name: tool_call.name.into(),
528 raw_input: tool_call.arguments.into(),
529 json_parse_error: error.to_string(),
530 }),
531 }
532 }));
533
534 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
535 }
536 Some(stop_reason) => {
537 log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
538 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
539 }
540 None => {}
541 }
542
543 events
544 }
545}
546
/// A tool call being assembled from streamed deltas.
#[derive(Default)]
struct RawToolCall {
    /// Tool call id reported by the server.
    id: String,
    /// Name of the function to invoke.
    name: String,
    /// JSON-encoded arguments, concatenated across chunks.
    arguments: String,
}
553
554struct ConfigurationView {
555 state: gpui::Entity<State>,
556 loading_models_task: Option<Task<()>>,
557}
558
559impl ConfigurationView {
560 pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
561 let loading_models_task = Some(cx.spawn({
562 let state = state.clone();
563 async move |this, cx| {
564 if let Some(task) = state
565 .update(cx, |state, cx| state.authenticate(cx))
566 .log_err()
567 {
568 task.await.log_err();
569 }
570 this.update(cx, |this, cx| {
571 this.loading_models_task = None;
572 cx.notify();
573 })
574 .log_err();
575 }
576 }));
577
578 Self {
579 state,
580 loading_models_task,
581 }
582 }
583
584 fn retry_connection(&self, cx: &mut App) {
585 self.state
586 .update(cx, |state, cx| state.fetch_models(cx))
587 .detach_and_log_err(cx);
588 }
589}
590
591impl Render for ConfigurationView {
592 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
593 let is_authenticated = self.state.read(cx).is_authenticated();
594
595 let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";
596
597 if self.loading_models_task.is_some() {
598 div().child(Label::new("Loading models...")).into_any()
599 } else {
600 v_flex()
601 .gap_2()
602 .child(
603 v_flex().gap_1().child(Label::new(lmstudio_intro)).child(
604 List::new()
605 .child(InstructionListItem::text_only(
606 "LM Studio needs to be running with at least one model downloaded.",
607 ))
608 .child(InstructionListItem::text_only(
609 "To get your first model, try running `lms get qwen2.5-coder-7b`",
610 )),
611 ),
612 )
613 .child(
614 h_flex()
615 .w_full()
616 .justify_between()
617 .gap_2()
618 .child(
619 h_flex()
620 .w_full()
621 .gap_2()
622 .map(|this| {
623 if is_authenticated {
624 this.child(
625 Button::new("lmstudio-site", "LM Studio")
626 .style(ButtonStyle::Subtle)
627 .icon(IconName::ArrowUpRight)
628 .icon_size(IconSize::XSmall)
629 .icon_color(Color::Muted)
630 .on_click(move |_, _window, cx| {
631 cx.open_url(LMSTUDIO_SITE)
632 })
633 .into_any_element(),
634 )
635 } else {
636 this.child(
637 Button::new(
638 "download_lmstudio_button",
639 "Download LM Studio",
640 )
641 .style(ButtonStyle::Subtle)
642 .icon(IconName::ArrowUpRight)
643 .icon_size(IconSize::XSmall)
644 .icon_color(Color::Muted)
645 .on_click(move |_, _window, cx| {
646 cx.open_url(LMSTUDIO_DOWNLOAD_URL)
647 })
648 .into_any_element(),
649 )
650 }
651 })
652 .child(
653 Button::new("view-models", "Model Catalog")
654 .style(ButtonStyle::Subtle)
655 .icon(IconName::ArrowUpRight)
656 .icon_size(IconSize::XSmall)
657 .icon_color(Color::Muted)
658 .on_click(move |_, _window, cx| {
659 cx.open_url(LMSTUDIO_CATALOG_URL)
660 }),
661 ),
662 )
663 .map(|this| {
664 if is_authenticated {
665 this.child(
666 ButtonLike::new("connected")
667 .disabled(true)
668 .cursor_style(gpui::CursorStyle::Arrow)
669 .child(
670 h_flex()
671 .gap_2()
672 .child(Indicator::dot().color(Color::Success))
673 .child(Label::new("Connected"))
674 .into_any_element(),
675 ),
676 )
677 } else {
678 this.child(
679 Button::new("retry_lmstudio_models", "Connect")
680 .icon_position(IconPosition::Start)
681 .icon_size(IconSize::XSmall)
682 .icon(IconName::Play)
683 .on_click(cx.listener(move |this, _, _window, cx| {
684 this.retry_connection(cx)
685 })),
686 )
687 }
688 }),
689 )
690 .into_any()
691 }
692 }
693}