1use anyhow::{Result, anyhow};
2use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
3use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
4use http_client::HttpClient;
5use language_model::{
6 AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
7};
8use language_model::{
9 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
10 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
11 LanguageModelRequest, RateLimiter, Role,
12};
13use lmstudio::{
14 ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model,
15 stream_chat_completion,
16};
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19use settings::{Settings, SettingsStore};
20use std::{collections::BTreeMap, sync::Arc};
21use ui::{ButtonLike, Indicator, List, prelude::*};
22use util::ResultExt;
23
24use crate::AllLanguageModelSettings;
25use crate::ui::InstructionListItem;
26
/// Page where users can download the LM Studio application.
const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
/// Model catalog linked from the "Model Catalog" button.
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
/// LM Studio home page, linked once the server is reachable.
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

/// Identifier reported by the provider and by each of its models.
const PROVIDER_ID: &str = "lmstudio";
/// Human-readable provider name reported by the provider and its models.
const PROVIDER_NAME: &str = "LM Studio";
33
34#[derive(Default, Debug, Clone, PartialEq)]
35pub struct LmStudioSettings {
36 pub api_url: String,
37 pub available_models: Vec<AvailableModel>,
38}
39
40#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
41pub struct AvailableModel {
42 /// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc
43 pub name: String,
44 /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
45 pub display_name: Option<String>,
46 /// The model's context window size.
47 pub max_tokens: usize,
48}
49
50pub struct LmStudioLanguageModelProvider {
51 http_client: Arc<dyn HttpClient>,
52 state: gpui::Entity<State>,
53}
54
55pub struct State {
56 http_client: Arc<dyn HttpClient>,
57 available_models: Vec<lmstudio::Model>,
58 fetch_model_task: Option<Task<Result<()>>>,
59 _subscription: Subscription,
60}
61
62impl State {
63 fn is_authenticated(&self) -> bool {
64 !self.available_models.is_empty()
65 }
66
67 fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
68 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
69 let http_client = self.http_client.clone();
70 let api_url = settings.api_url.clone();
71
72 // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
73 cx.spawn(async move |this, cx| {
74 let models = get_models(http_client.as_ref(), &api_url, None).await?;
75
76 let mut models: Vec<lmstudio::Model> = models
77 .into_iter()
78 .filter(|model| model.r#type != ModelType::Embeddings)
79 .map(|model| lmstudio::Model::new(&model.id, None, None))
80 .collect();
81
82 models.sort_by(|a, b| a.name.cmp(&b.name));
83
84 this.update(cx, |this, cx| {
85 this.available_models = models;
86 cx.notify();
87 })
88 })
89 }
90
91 fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
92 let task = self.fetch_models(cx);
93 self.fetch_model_task.replace(task);
94 }
95
96 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
97 if self.is_authenticated() {
98 return Task::ready(Ok(()));
99 }
100
101 let fetch_models_task = self.fetch_models(cx);
102 cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
103 }
104}
105
106impl LmStudioLanguageModelProvider {
107 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
108 let this = Self {
109 http_client: http_client.clone(),
110 state: cx.new(|cx| {
111 let subscription = cx.observe_global::<SettingsStore>({
112 let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
113 move |this: &mut State, cx| {
114 let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
115 if &settings != new_settings {
116 settings = new_settings.clone();
117 this.restart_fetch_models_task(cx);
118 cx.notify();
119 }
120 }
121 });
122
123 State {
124 http_client,
125 available_models: Default::default(),
126 fetch_model_task: None,
127 _subscription: subscription,
128 }
129 }),
130 };
131 this.state
132 .update(cx, |state, cx| state.restart_fetch_models_task(cx));
133 this
134 }
135}
136
137impl LanguageModelProviderState for LmStudioLanguageModelProvider {
138 type ObservableEntity = State;
139
140 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
141 Some(self.state.clone())
142 }
143}
144
145impl LanguageModelProvider for LmStudioLanguageModelProvider {
146 fn id(&self) -> LanguageModelProviderId {
147 LanguageModelProviderId(PROVIDER_ID.into())
148 }
149
150 fn name(&self) -> LanguageModelProviderName {
151 LanguageModelProviderName(PROVIDER_NAME.into())
152 }
153
154 fn icon(&self) -> IconName {
155 IconName::AiLmStudio
156 }
157
158 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
159 self.provided_models(cx).into_iter().next()
160 }
161
162 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
163 self.default_model(cx)
164 }
165
166 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
167 let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();
168
169 // Add models from the LM Studio API
170 for model in self.state.read(cx).available_models.iter() {
171 models.insert(model.name.clone(), model.clone());
172 }
173
174 // Override with available models from settings
175 for model in AllLanguageModelSettings::get_global(cx)
176 .lmstudio
177 .available_models
178 .iter()
179 {
180 models.insert(
181 model.name.clone(),
182 lmstudio::Model {
183 name: model.name.clone(),
184 display_name: model.display_name.clone(),
185 max_tokens: model.max_tokens,
186 },
187 );
188 }
189
190 models
191 .into_values()
192 .map(|model| {
193 Arc::new(LmStudioLanguageModel {
194 id: LanguageModelId::from(model.name.clone()),
195 model: model.clone(),
196 http_client: self.http_client.clone(),
197 request_limiter: RateLimiter::new(4),
198 }) as Arc<dyn LanguageModel>
199 })
200 .collect()
201 }
202
203 fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
204 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
205 let http_client = self.http_client.clone();
206 let api_url = settings.api_url.clone();
207 let id = model.id().0.to_string();
208 cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
209 .detach_and_log_err(cx);
210 }
211
212 fn is_authenticated(&self, cx: &App) -> bool {
213 self.state.read(cx).is_authenticated()
214 }
215
216 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
217 self.state.update(cx, |state, cx| state.authenticate(cx))
218 }
219
220 fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
221 let state = self.state.clone();
222 cx.new(|cx| ConfigurationView::new(state, cx)).into()
223 }
224
225 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
226 self.state.update(cx, |state, cx| state.fetch_models(cx))
227 }
228}
229
230pub struct LmStudioLanguageModel {
231 id: LanguageModelId,
232 model: lmstudio::Model,
233 http_client: Arc<dyn HttpClient>,
234 request_limiter: RateLimiter,
235}
236
237impl LmStudioLanguageModel {
238 fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
239 ChatCompletionRequest {
240 model: self.model.name.clone(),
241 messages: request
242 .messages
243 .into_iter()
244 .map(|msg| match msg.role {
245 Role::User => ChatMessage::User {
246 content: msg.string_contents(),
247 },
248 Role::Assistant => ChatMessage::Assistant {
249 content: Some(msg.string_contents()),
250 tool_calls: None,
251 },
252 Role::System => ChatMessage::System {
253 content: msg.string_contents(),
254 },
255 })
256 .collect(),
257 stream: true,
258 max_tokens: Some(-1),
259 stop: Some(request.stop),
260 temperature: request.temperature.or(Some(0.0)),
261 tools: vec![],
262 }
263 }
264}
265
266impl LanguageModel for LmStudioLanguageModel {
267 fn id(&self) -> LanguageModelId {
268 self.id.clone()
269 }
270
271 fn name(&self) -> LanguageModelName {
272 LanguageModelName::from(self.model.display_name().to_string())
273 }
274
275 fn provider_id(&self) -> LanguageModelProviderId {
276 LanguageModelProviderId(PROVIDER_ID.into())
277 }
278
279 fn provider_name(&self) -> LanguageModelProviderName {
280 LanguageModelProviderName(PROVIDER_NAME.into())
281 }
282
283 fn supports_tools(&self) -> bool {
284 false
285 }
286
287 fn telemetry_id(&self) -> String {
288 format!("lmstudio/{}", self.model.id())
289 }
290
291 fn max_token_count(&self) -> usize {
292 self.model.max_token_count()
293 }
294
295 fn count_tokens(
296 &self,
297 request: LanguageModelRequest,
298 _cx: &App,
299 ) -> BoxFuture<'static, Result<usize>> {
300 // Endpoint for this is coming soon. In the meantime, hacky estimation
301 let token_count = request
302 .messages
303 .iter()
304 .map(|msg| msg.string_contents().split_whitespace().count())
305 .sum::<usize>();
306
307 let estimated_tokens = (token_count as f64 * 0.75) as usize;
308 async move { Ok(estimated_tokens) }.boxed()
309 }
310
311 fn stream_completion(
312 &self,
313 request: LanguageModelRequest,
314 cx: &AsyncApp,
315 ) -> BoxFuture<
316 'static,
317 Result<
318 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
319 >,
320 > {
321 let request = self.to_lmstudio_request(request);
322
323 let http_client = self.http_client.clone();
324 let Ok(api_url) = cx.update(|cx| {
325 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
326 settings.api_url.clone()
327 }) else {
328 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
329 };
330
331 let future = self.request_limiter.stream(async move {
332 let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
333
334 // Create a stream mapper to handle content across multiple deltas
335 let stream_mapper = LmStudioStreamMapper::new();
336
337 let stream = response
338 .map(move |response| {
339 response.and_then(|fragment| stream_mapper.process_fragment(fragment))
340 })
341 .filter_map(|result| async move {
342 match result {
343 Ok(Some(content)) => Some(Ok(content)),
344 Ok(None) => None,
345 Err(error) => Some(Err(error)),
346 }
347 })
348 .boxed();
349
350 Ok(stream)
351 });
352
353 async move {
354 Ok(future
355 .await?
356 .map(|result| {
357 result
358 .map(LanguageModelCompletionEvent::Text)
359 .map_err(LanguageModelCompletionError::Other)
360 })
361 .boxed())
362 }
363 .boxed()
364 }
365}
366
367// This will be more useful when we implement tool calling. Currently keeping it empty.
368struct LmStudioStreamMapper {}
369
370impl LmStudioStreamMapper {
371 fn new() -> Self {
372 Self {}
373 }
374
375 fn process_fragment(&self, fragment: lmstudio::ChatResponse) -> Result<Option<String>> {
376 // Most of the time, there will be only one choice
377 let Some(choice) = fragment.choices.first() else {
378 return Ok(None);
379 };
380
381 // Extract the delta content
382 if let Ok(delta) =
383 serde_json::from_value::<lmstudio::ResponseMessageDelta>(choice.delta.clone())
384 {
385 if let Some(content) = delta.content {
386 if !content.is_empty() {
387 return Ok(Some(content));
388 }
389 }
390 }
391
392 // If there's a finish_reason, we're done
393 if choice.finish_reason.is_some() {
394 return Ok(None);
395 }
396
397 Ok(None)
398 }
399}
400
401struct ConfigurationView {
402 state: gpui::Entity<State>,
403 loading_models_task: Option<Task<()>>,
404}
405
406impl ConfigurationView {
407 pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
408 let loading_models_task = Some(cx.spawn({
409 let state = state.clone();
410 async move |this, cx| {
411 if let Some(task) = state
412 .update(cx, |state, cx| state.authenticate(cx))
413 .log_err()
414 {
415 task.await.log_err();
416 }
417 this.update(cx, |this, cx| {
418 this.loading_models_task = None;
419 cx.notify();
420 })
421 .log_err();
422 }
423 }));
424
425 Self {
426 state,
427 loading_models_task,
428 }
429 }
430
431 fn retry_connection(&self, cx: &mut App) {
432 self.state
433 .update(cx, |state, cx| state.fetch_models(cx))
434 .detach_and_log_err(cx);
435 }
436}
437
438impl Render for ConfigurationView {
439 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
440 let is_authenticated = self.state.read(cx).is_authenticated();
441
442 let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";
443
444 if self.loading_models_task.is_some() {
445 div().child(Label::new("Loading models...")).into_any()
446 } else {
447 v_flex()
448 .gap_2()
449 .child(
450 v_flex().gap_1().child(Label::new(lmstudio_intro)).child(
451 List::new()
452 .child(InstructionListItem::text_only(
453 "LM Studio needs to be running with at least one model downloaded.",
454 ))
455 .child(InstructionListItem::text_only(
456 "To get your first model, try running `lms get qwen2.5-coder-7b`",
457 )),
458 ),
459 )
460 .child(
461 h_flex()
462 .w_full()
463 .justify_between()
464 .gap_2()
465 .child(
466 h_flex()
467 .w_full()
468 .gap_2()
469 .map(|this| {
470 if is_authenticated {
471 this.child(
472 Button::new("lmstudio-site", "LM Studio")
473 .style(ButtonStyle::Subtle)
474 .icon(IconName::ArrowUpRight)
475 .icon_size(IconSize::XSmall)
476 .icon_color(Color::Muted)
477 .on_click(move |_, _window, cx| {
478 cx.open_url(LMSTUDIO_SITE)
479 })
480 .into_any_element(),
481 )
482 } else {
483 this.child(
484 Button::new(
485 "download_lmstudio_button",
486 "Download LM Studio",
487 )
488 .style(ButtonStyle::Subtle)
489 .icon(IconName::ArrowUpRight)
490 .icon_size(IconSize::XSmall)
491 .icon_color(Color::Muted)
492 .on_click(move |_, _window, cx| {
493 cx.open_url(LMSTUDIO_DOWNLOAD_URL)
494 })
495 .into_any_element(),
496 )
497 }
498 })
499 .child(
500 Button::new("view-models", "Model Catalog")
501 .style(ButtonStyle::Subtle)
502 .icon(IconName::ArrowUpRight)
503 .icon_size(IconSize::XSmall)
504 .icon_color(Color::Muted)
505 .on_click(move |_, _window, cx| {
506 cx.open_url(LMSTUDIO_CATALOG_URL)
507 }),
508 ),
509 )
510 .map(|this| {
511 if is_authenticated {
512 this.child(
513 ButtonLike::new("connected")
514 .disabled(true)
515 .cursor_style(gpui::CursorStyle::Arrow)
516 .child(
517 h_flex()
518 .gap_2()
519 .child(Indicator::dot().color(Color::Success))
520 .child(Label::new("Connected"))
521 .into_any_element(),
522 ),
523 )
524 } else {
525 this.child(
526 Button::new("retry_lmstudio_models", "Connect")
527 .icon_position(IconPosition::Start)
528 .icon_size(IconSize::XSmall)
529 .icon(IconName::Play)
530 .on_click(cx.listener(move |this, _, _window, cx| {
531 this.retry_connection(cx)
532 })),
533 )
534 }
535 }),
536 )
537 .into_any()
538 }
539 }
540}