lmstudio.rs

use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use lmstudio::{
    ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model,
    stream_chat_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc};
use ui::{ButtonLike, Indicator, prelude::*};
use util::ResultExt;

use crate::AllLanguageModelSettings;

const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

const PROVIDER_ID: &str = "lmstudio";
const PROVIDER_NAME: &str = "LM Studio";

#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: usize,
}

pub struct LmStudioLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

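/// Holds the models fetched from the LM Studio server, the task that fetches
/// them, and a settings subscription that triggers a refetch when the
/// LM Studio settings change.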
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<lmstudio::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
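    /// LM Studio has no real authentication; a non-empty model list is used as
    /// a proxy for the server being reachable.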
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

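    /// Fetches the models available on the LM Studio server, dropping
    /// embedding models and sorting the rest by name.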
    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models.
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<lmstudio::Model> = models
                .into_iter()
                .filter(|model| model.r#type != ModelType::Embeddings)
                .map(|model| lmstudio::Model::new(&model.id, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}

impl LmStudioLanguageModelProvider {
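    /// Creates the provider and its shared state, subscribing to settings
    /// changes so the model list is refetched whenever the LM Studio settings
    /// change, then kicks off an initial fetch.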
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for LmStudioLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for LmStudioLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiLmStudio
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.provided_models(cx).into_iter().next()
    }

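    /// Merges the models reported by the server with any models declared in
    /// settings; a settings entry with the same name overrides the server's.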
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();

        // Add models from the LM Studio API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .lmstudio
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                lmstudio::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(LmStudioLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl LmStudioLanguageModel {
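    /// Converts Zed's provider-agnostic request into LM Studio's
    /// OpenAI-compatible chat completion request. Streaming is always
    /// requested, and `max_tokens` is sent as -1 rather than a concrete limit.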
    fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: Some(msg.string_contents()),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            stream: true,
            max_tokens: Some(-1),
            stop: Some(request.stop),
            temperature: request.temperature.or(Some(0.0)),
            tools: vec![],
        }
    }
}

impl LanguageModel for LmStudioLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        false
    }

    fn telemetry_id(&self) -> String {
        format!("lmstudio/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        // Endpoint for this is coming soon. In the meantime, hacky estimation
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().split_whitespace().count())
            .sum::<usize>();

        let estimated_tokens = (token_count as f64 * 0.75) as usize;
        async move { Ok(estimated_tokens) }.boxed()
    }

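    /// Streams a chat completion from LM Studio, forwarding each non-empty
    /// content delta as a text event and skipping empty deltas.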
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request = self.to_lmstudio_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(fragment) => {
                            // Skip empty deltas
                            if fragment.choices[0].delta.is_object()
                                && fragment.choices[0].delta.as_object().unwrap().is_empty()
                            {
                                return None;
                            }

                            // Try to parse the delta as ChatMessage
                            if let Ok(chat_message) = serde_json::from_value::<ChatMessage>(
                                fragment.choices[0].delta.clone(),
                            ) {
                                let content = match chat_message {
                                    ChatMessage::User { content } => content,
                                    ChatMessage::Assistant { content, .. } => {
                                        content.unwrap_or_default()
                                    }
                                    ChatMessage::System { content } => content,
                                };
                                if !content.is_empty() {
                                    Some(Ok(content))
                                } else {
                                    None
                                }
                            } else {
                                None
                            }
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }
}

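/// Settings UI for the LM Studio provider: links to download LM Studio and
/// browse the model catalog, shows connection status, and offers a Connect
/// button that retries fetching models.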
struct ConfigurationView {
    state: gpui::Entity<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";
        let lmstudio_reqs = "To use LM Studio as a provider for Zed's assistant, it needs to be running with at least one model downloaded.";

        let inline_code_bg = cx.theme().colors().editor_foreground.opacity(0.05);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(lmstudio_intro))
                        .child(Label::new(lmstudio_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("To get your first model, try running"))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_sm()
                                        .child(Label::new("lms get qwen2.5-coder-7b")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("lmstudio-site", "LM Studio")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _window, cx| {
                                                    cx.open_url(LMSTUDIO_SITE)
                                                })
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_lmstudio_button",
                                                "Download LM Studio",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "Model Catalog")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_CATALOG_URL)
                                        }),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is only a button to ensure the spacing is correct;
                            // it should stay disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use the arrow cursor
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_lmstudio_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, _window, cx| {
                                    this.retry_connection(cx)
                                }))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}