lmstudio.rs

  1use anyhow::{Result, anyhow};
  2use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
  3use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
  4use http_client::HttpClient;
  5use language_model::{AuthenticateError, LanguageModelCompletionEvent};
  6use language_model::{
  7    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
  8    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  9    LanguageModelRequest, RateLimiter, Role,
 10};
 11use lmstudio::{
 12    ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model,
 13    stream_chat_completion,
 14};
 15use schemars::JsonSchema;
 16use serde::{Deserialize, Serialize};
 17use settings::{Settings, SettingsStore};
 18use std::{collections::BTreeMap, sync::Arc};
 19use ui::{ButtonLike, Indicator, List, prelude::*};
 20use util::ResultExt;
 21
 22use crate::AllLanguageModelSettings;
 23use crate::ui::InstructionListItem;
 24
/// Opened by the "Download LM Studio" button in the configuration view.
const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
/// Opened by the "Model Catalog" button in the configuration view.
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
/// Opened by the "LM Studio" button that the configuration view shows once connected.
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

/// Value wrapped in `LanguageModelProviderId` by both the provider and its models.
const PROVIDER_ID: &str = "lmstudio";
/// Value wrapped in `LanguageModelProviderName` by both the provider and its models.
const PROVIDER_NAME: &str = "LM Studio";
 31
/// LM Studio section of `AllLanguageModelSettings`.
///
/// The provider compares successive values of this struct (via `PartialEq`) to decide
/// when the model list has to be re-fetched.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    /// URL passed to `get_models`, `preload_model` and `stream_chat_completion`.
    pub api_url: String,
    /// Models declared in settings; in `provided_models` they are inserted after the
    /// server-reported models and therefore replace entries with the same name.
    pub available_models: Vec<AvailableModel>,
}
 37
// A model entry declared in the user's settings. `provided_models` converts each entry
// into an `lmstudio::Model`, keyed by `name`.
// NOTE(review): the `///` comments below are presumably surfaced as descriptions in the
// derived `JsonSchema` — confirm before rewording them.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: usize,
}
 47
/// Language-model provider backed by an LM Studio server.
pub struct LmStudioLanguageModelProvider {
    /// Client handed to every `LmStudioLanguageModel` and used for `preload_model`.
    http_client: Arc<dyn HttpClient>,
    /// Observable entity holding the fetched model list (see `State`).
    state: gpui::Entity<State>,
}
 52
/// Observable state shared by the provider and its configuration view.
pub struct State {
    /// Client used by `fetch_models` to query the server.
    http_client: Arc<dyn HttpClient>,
    /// Models stored by the last successful `fetch_models`; a non-empty list is what
    /// `is_authenticated` treats as "connected".
    available_models: Vec<lmstudio::Model>,
    /// Fetch started by `restart_fetch_models_task`; starting a new one replaces (and
    /// drops) the previous task.
    fetch_model_task: Option<Task<Result<()>>>,
    /// Settings observer created in `LmStudioLanguageModelProvider::new`.
    /// NOTE(review): presumably stored only to keep the subscription alive — confirm.
    _subscription: Subscription,
}
 59
 60impl State {
 61    fn is_authenticated(&self) -> bool {
 62        !self.available_models.is_empty()
 63    }
 64
 65    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 66        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
 67        let http_client = self.http_client.clone();
 68        let api_url = settings.api_url.clone();
 69
 70        // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
 71        cx.spawn(async move |this, cx| {
 72            let models = get_models(http_client.as_ref(), &api_url, None).await?;
 73
 74            let mut models: Vec<lmstudio::Model> = models
 75                .into_iter()
 76                .filter(|model| model.r#type != ModelType::Embeddings)
 77                .map(|model| lmstudio::Model::new(&model.id, None, None))
 78                .collect();
 79
 80            models.sort_by(|a, b| a.name.cmp(&b.name));
 81
 82            this.update(cx, |this, cx| {
 83                this.available_models = models;
 84                cx.notify();
 85            })
 86        })
 87    }
 88
 89    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
 90        let task = self.fetch_models(cx);
 91        self.fetch_model_task.replace(task);
 92    }
 93
 94    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 95        if self.is_authenticated() {
 96            return Task::ready(Ok(()));
 97        }
 98
 99        let fetch_models_task = self.fetch_models(cx);
100        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
101    }
102}
103
104impl LmStudioLanguageModelProvider {
105    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
106        let this = Self {
107            http_client: http_client.clone(),
108            state: cx.new(|cx| {
109                let subscription = cx.observe_global::<SettingsStore>({
110                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
111                    move |this: &mut State, cx| {
112                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
113                        if &settings != new_settings {
114                            settings = new_settings.clone();
115                            this.restart_fetch_models_task(cx);
116                            cx.notify();
117                        }
118                    }
119                });
120
121                State {
122                    http_client,
123                    available_models: Default::default(),
124                    fetch_model_task: None,
125                    _subscription: subscription,
126                }
127            }),
128        };
129        this.state
130            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
131        this
132    }
133}
134
135impl LanguageModelProviderState for LmStudioLanguageModelProvider {
136    type ObservableEntity = State;
137
138    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
139        Some(self.state.clone())
140    }
141}
142
143impl LanguageModelProvider for LmStudioLanguageModelProvider {
144    fn id(&self) -> LanguageModelProviderId {
145        LanguageModelProviderId(PROVIDER_ID.into())
146    }
147
148    fn name(&self) -> LanguageModelProviderName {
149        LanguageModelProviderName(PROVIDER_NAME.into())
150    }
151
152    fn icon(&self) -> IconName {
153        IconName::AiLmStudio
154    }
155
156    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
157        self.provided_models(cx).into_iter().next()
158    }
159
160    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
161        let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();
162
163        // Add models from the LM Studio API
164        for model in self.state.read(cx).available_models.iter() {
165            models.insert(model.name.clone(), model.clone());
166        }
167
168        // Override with available models from settings
169        for model in AllLanguageModelSettings::get_global(cx)
170            .lmstudio
171            .available_models
172            .iter()
173        {
174            models.insert(
175                model.name.clone(),
176                lmstudio::Model {
177                    name: model.name.clone(),
178                    display_name: model.display_name.clone(),
179                    max_tokens: model.max_tokens,
180                },
181            );
182        }
183
184        models
185            .into_values()
186            .map(|model| {
187                Arc::new(LmStudioLanguageModel {
188                    id: LanguageModelId::from(model.name.clone()),
189                    model: model.clone(),
190                    http_client: self.http_client.clone(),
191                    request_limiter: RateLimiter::new(4),
192                }) as Arc<dyn LanguageModel>
193            })
194            .collect()
195    }
196
197    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
198        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
199        let http_client = self.http_client.clone();
200        let api_url = settings.api_url.clone();
201        let id = model.id().0.to_string();
202        cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
203            .detach_and_log_err(cx);
204    }
205
206    fn is_authenticated(&self, cx: &App) -> bool {
207        self.state.read(cx).is_authenticated()
208    }
209
210    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
211        self.state.update(cx, |state, cx| state.authenticate(cx))
212    }
213
214    fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
215        let state = self.state.clone();
216        cx.new(|cx| ConfigurationView::new(state, cx)).into()
217    }
218
219    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
220        self.state.update(cx, |state, cx| state.fetch_models(cx))
221    }
222}
223
/// A single LM Studio model exposed through the `LanguageModel` trait.
pub struct LmStudioLanguageModel {
    /// Identifier derived from the model's name in `provided_models`.
    id: LanguageModelId,
    /// Model description (name, optional display name, token limit).
    model: lmstudio::Model,
    /// Client used to stream chat completions.
    http_client: Arc<dyn HttpClient>,
    /// Wraps every completion request; constructed with `RateLimiter::new(4)`.
    request_limiter: RateLimiter,
}
230
231impl LmStudioLanguageModel {
232    fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
233        ChatCompletionRequest {
234            model: self.model.name.clone(),
235            messages: request
236                .messages
237                .into_iter()
238                .map(|msg| match msg.role {
239                    Role::User => ChatMessage::User {
240                        content: msg.string_contents(),
241                    },
242                    Role::Assistant => ChatMessage::Assistant {
243                        content: Some(msg.string_contents()),
244                        tool_calls: None,
245                    },
246                    Role::System => ChatMessage::System {
247                        content: msg.string_contents(),
248                    },
249                })
250                .collect(),
251            stream: true,
252            max_tokens: Some(-1),
253            stop: Some(request.stop),
254            temperature: request.temperature.or(Some(0.0)),
255            tools: vec![],
256        }
257    }
258}
259
260impl LanguageModel for LmStudioLanguageModel {
261    fn id(&self) -> LanguageModelId {
262        self.id.clone()
263    }
264
265    fn name(&self) -> LanguageModelName {
266        LanguageModelName::from(self.model.display_name().to_string())
267    }
268
269    fn provider_id(&self) -> LanguageModelProviderId {
270        LanguageModelProviderId(PROVIDER_ID.into())
271    }
272
273    fn provider_name(&self) -> LanguageModelProviderName {
274        LanguageModelProviderName(PROVIDER_NAME.into())
275    }
276
277    fn supports_tools(&self) -> bool {
278        false
279    }
280
281    fn telemetry_id(&self) -> String {
282        format!("lmstudio/{}", self.model.id())
283    }
284
285    fn max_token_count(&self) -> usize {
286        self.model.max_token_count()
287    }
288
289    fn count_tokens(
290        &self,
291        request: LanguageModelRequest,
292        _cx: &App,
293    ) -> BoxFuture<'static, Result<usize>> {
294        // Endpoint for this is coming soon. In the meantime, hacky estimation
295        let token_count = request
296            .messages
297            .iter()
298            .map(|msg| msg.string_contents().split_whitespace().count())
299            .sum::<usize>();
300
301        let estimated_tokens = (token_count as f64 * 0.75) as usize;
302        async move { Ok(estimated_tokens) }.boxed()
303    }
304
305    fn stream_completion(
306        &self,
307        request: LanguageModelRequest,
308        cx: &AsyncApp,
309    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
310        let request = self.to_lmstudio_request(request);
311
312        let http_client = self.http_client.clone();
313        let Ok(api_url) = cx.update(|cx| {
314            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
315            settings.api_url.clone()
316        }) else {
317            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
318        };
319
320        let future = self.request_limiter.stream(async move {
321            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
322            let stream = response
323                .filter_map(|response| async move {
324                    match response {
325                        Ok(fragment) => {
326                            // Skip empty deltas
327                            if fragment.choices[0].delta.is_object()
328                                && fragment.choices[0].delta.as_object().unwrap().is_empty()
329                            {
330                                return None;
331                            }
332
333                            // Try to parse the delta as ChatMessage
334                            if let Ok(chat_message) = serde_json::from_value::<ChatMessage>(
335                                fragment.choices[0].delta.clone(),
336                            ) {
337                                let content = match chat_message {
338                                    ChatMessage::User { content } => content,
339                                    ChatMessage::Assistant { content, .. } => {
340                                        content.unwrap_or_default()
341                                    }
342                                    ChatMessage::System { content } => content,
343                                };
344                                if !content.is_empty() {
345                                    Some(Ok(content))
346                                } else {
347                                    None
348                                }
349                            } else {
350                                None
351                            }
352                        }
353                        Err(error) => Some(Err(error)),
354                    }
355                })
356                .boxed();
357            Ok(stream)
358        });
359
360        async move {
361            Ok(future
362                .await?
363                .map(|result| result.map(LanguageModelCompletionEvent::Text))
364                .boxed())
365        }
366        .boxed()
367    }
368}
369
/// Settings-panel view showing LM Studio setup hints and connection status.
struct ConfigurationView {
    /// Shared provider state; read for the connection status and used to retry.
    state: gpui::Entity<State>,
    /// Task for the initial authentication attempt. While it is `Some`, the view
    /// renders "Loading models..."; the task sets it to `None` when it finishes.
    loading_models_task: Option<Task<()>>,
}
374
375impl ConfigurationView {
376    pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
377        let loading_models_task = Some(cx.spawn({
378            let state = state.clone();
379            async move |this, cx| {
380                if let Some(task) = state
381                    .update(cx, |state, cx| state.authenticate(cx))
382                    .log_err()
383                {
384                    task.await.log_err();
385                }
386                this.update(cx, |this, cx| {
387                    this.loading_models_task = None;
388                    cx.notify();
389                })
390                .log_err();
391            }
392        }));
393
394        Self {
395            state,
396            loading_models_task,
397        }
398    }
399
400    fn retry_connection(&self, cx: &mut App) {
401        self.state
402            .update(cx, |state, cx| state.fetch_models(cx))
403            .detach_and_log_err(cx);
404    }
405}
406
407impl Render for ConfigurationView {
408    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
409        let is_authenticated = self.state.read(cx).is_authenticated();
410
411        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";
412
413        if self.loading_models_task.is_some() {
414            div().child(Label::new("Loading models...")).into_any()
415        } else {
416            v_flex()
417                .gap_2()
418                .child(
419                    v_flex().gap_1().child(Label::new(lmstudio_intro)).child(
420                        List::new()
421                            .child(InstructionListItem::text_only(
422                                "LM Studio needs to be running with at least one model downloaded.",
423                            ))
424                            .child(InstructionListItem::text_only(
425                                "To get your first model, try running `lms get qwen2.5-coder-7b`",
426                            )),
427                    ),
428                )
429                .child(
430                    h_flex()
431                        .w_full()
432                        .justify_between()
433                        .gap_2()
434                        .child(
435                            h_flex()
436                                .w_full()
437                                .gap_2()
438                                .map(|this| {
439                                    if is_authenticated {
440                                        this.child(
441                                            Button::new("lmstudio-site", "LM Studio")
442                                                .style(ButtonStyle::Subtle)
443                                                .icon(IconName::ArrowUpRight)
444                                                .icon_size(IconSize::XSmall)
445                                                .icon_color(Color::Muted)
446                                                .on_click(move |_, _window, cx| {
447                                                    cx.open_url(LMSTUDIO_SITE)
448                                                })
449                                                .into_any_element(),
450                                        )
451                                    } else {
452                                        this.child(
453                                            Button::new(
454                                                "download_lmstudio_button",
455                                                "Download LM Studio",
456                                            )
457                                            .style(ButtonStyle::Subtle)
458                                            .icon(IconName::ArrowUpRight)
459                                            .icon_size(IconSize::XSmall)
460                                            .icon_color(Color::Muted)
461                                            .on_click(move |_, _window, cx| {
462                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
463                                            })
464                                            .into_any_element(),
465                                        )
466                                    }
467                                })
468                                .child(
469                                    Button::new("view-models", "Model Catalog")
470                                        .style(ButtonStyle::Subtle)
471                                        .icon(IconName::ArrowUpRight)
472                                        .icon_size(IconSize::XSmall)
473                                        .icon_color(Color::Muted)
474                                        .on_click(move |_, _window, cx| {
475                                            cx.open_url(LMSTUDIO_CATALOG_URL)
476                                        }),
477                                ),
478                        )
479                        .map(|this| {
480                            if is_authenticated {
481                                this.child(
482                                    ButtonLike::new("connected")
483                                        .disabled(true)
484                                        .cursor_style(gpui::CursorStyle::Arrow)
485                                        .child(
486                                            h_flex()
487                                                .gap_2()
488                                                .child(Indicator::dot().color(Color::Success))
489                                                .child(Label::new("Connected"))
490                                                .into_any_element(),
491                                        ),
492                                )
493                            } else {
494                                this.child(
495                                    Button::new("retry_lmstudio_models", "Connect")
496                                        .icon_position(IconPosition::Start)
497                                        .icon_size(IconSize::XSmall)
498                                        .icon(IconName::Play)
499                                        .on_click(cx.listener(move |this, _, _window, cx| {
500                                            this.retry_connection(cx)
501                                        })),
502                                )
503                            }
504                        }),
505                )
506                .into_any()
507        }
508    }
509}