use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use lmstudio::{
    ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model,
    stream_chat_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc};
use ui::{ButtonLike, Indicator, List, prelude::*};
use util::ResultExt;

use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;

const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

const PROVIDER_ID: &str = "lmstudio";
const PROVIDER_NAME: &str = "LM Studio";

#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

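/// A model entry configured by the user in Zed's settings. As a sketch of the
/// expected shape (assuming these fields are read from the `lmstudio` section
/// of the language-model settings; the exact nesting in `settings.json` may
/// differ), an entry could look like:
///
/// ```json
/// {
///   "name": "qwen2.5-coder-7b",
///   "display_name": "Qwen 2.5 Coder 7B",
///   "max_tokens": 32768
/// }
/// ```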
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model's name in the LM Studio API, e.g. `qwen2.5-coder-7b` or `phi-4`.
    pub name: String,
    /// The model's name as shown in Zed's UI, such as the model selector dropdown in the assistant panel.
    pub display_name: Option<String>,
    /// The size of the model's context window, in tokens.
    pub max_tokens: usize,
}

pub struct LmStudioLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<lmstudio::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
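    // LM Studio's local server has no credentials; "authenticated" is a proxy
    // for "the server is reachable and reports at least one model".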
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we check that it's up
        // by fetching the available models.
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<lmstudio::Model> = models
                .into_iter()
                .filter(|model| model.r#type != ModelType::Embeddings)
                .map(|model| lmstudio::Model::new(&model.id, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}

impl LmStudioLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
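                // Watch the global settings store and refetch the model list
                // whenever the LM Studio settings change.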
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for LmStudioLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for LmStudioLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiLmStudio
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.provided_models(cx).into_iter().next()
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.default_model(cx)
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();

        // Add models reported by the LM Studio API.
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings.
        for model in AllLanguageModelSettings::get_global(cx)
            .lmstudio
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                lmstudio::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(LmStudioLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

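    // Preloading asks the LM Studio server to load the model into memory ahead
    // of time, so the first completion request doesn't pay the load cost.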
    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        // There are no credentials to reset; refetch the model list instead.
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl LmStudioLanguageModel {
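    /// Converts a `LanguageModelRequest` into the OpenAI-compatible payload that
    /// LM Studio's chat completions endpoint expects. As a rough sketch of the
    /// wire format (field names follow the `lmstudio` crate's
    /// `ChatCompletionRequest`; the exact serialization may differ), the body
    /// looks something like:
    ///
    /// ```json
    /// {
    ///   "model": "qwen2.5-coder-7b",
    ///   "messages": [{ "role": "user", "content": "Hello" }],
    ///   "stream": true,
    ///   "max_tokens": -1,
    ///   "stop": [],
    ///   "temperature": 0.0
    /// }
    /// ```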
    fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: Some(msg.string_contents()),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            stream: true,
            // -1 tells LM Studio not to cap the completion length.
            max_tokens: Some(-1),
            stop: Some(request.stop),
            // Default to deterministic sampling when no temperature is given.
            temperature: request.temperature.or(Some(0.0)),
            tools: vec![],
        }
    }
}

impl LanguageModel for LmStudioLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        false
    }

    fn telemetry_id(&self) -> String {
        format!("lmstudio/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        // LM Studio doesn't expose a token-counting endpoint yet, so estimate:
        // count whitespace-delimited words and scale by 0.75 as a rough
        // words-to-tokens ratio.
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().split_whitespace().count())
            .sum::<usize>();

        let estimated_tokens = (token_count as f64 * 0.75) as usize;
        async move { Ok(estimated_tokens) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request = self.to_lmstudio_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(fragment) => {
                            // Guard against fragments with no choices instead of
                            // panicking on an out-of-bounds index.
                            let delta = &fragment.choices.first()?.delta;

                            // Skip empty deltas.
                            if delta.as_object().is_some_and(|delta| delta.is_empty()) {
                                return None;
                            }

                            // Try to parse the delta as a ChatMessage.
                            if let Ok(chat_message) =
                                serde_json::from_value::<ChatMessage>(delta.clone())
                            {
                                let content = match chat_message {
                                    ChatMessage::User { content } => content,
                                    ChatMessage::Assistant { content, .. } => {
                                        content.unwrap_or_default()
                                    }
                                    ChatMessage::System { content } => content,
                                };
                                if !content.is_empty() {
                                    Some(Ok(content))
                                } else {
                                    None
                                }
                            } else {
                                None
                            }
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                // Surface each text chunk as a completion event.
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }
}

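/// Settings UI for the LM Studio provider. It kicks off a model fetch when
/// opened and lets the user retry the connection or follow setup links.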
struct ConfigurationView {
    state: gpui::Entity<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .gap_2()
                .child(
                    v_flex().gap_1().child(Label::new(lmstudio_intro)).child(
                        List::new()
                            .child(InstructionListItem::text_only(
                                "LM Studio needs to be running with at least one model downloaded.",
                            ))
                            .child(InstructionListItem::text_only(
                                "To get your first model, try running `lms get qwen2.5-coder-7b`.",
                            )),
                    ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("lmstudio-site", "LM Studio")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _window, cx| {
                                                    cx.open_url(LMSTUDIO_SITE)
                                                })
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_lmstudio_button",
                                                "Download LM Studio",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "Model Catalog")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_CATALOG_URL)
                                        }),
                                ),
                        )
                        .map(|this| {
                            if is_authenticated {
                                this.child(
                                    ButtonLike::new("connected")
                                        .disabled(true)
                                        .cursor_style(gpui::CursorStyle::Arrow)
                                        .child(
                                            h_flex()
                                                .gap_2()
                                                .child(Indicator::dot().color(Color::Success))
                                                .child(Label::new("Connected"))
                                                .into_any_element(),
                                        ),
                                )
                            } else {
                                this.child(
                                    Button::new("retry_lmstudio_models", "Connect")
                                        .icon_position(IconPosition::Start)
                                        .icon_size(IconSize::XSmall)
                                        .icon(IconName::Play)
                                        .on_click(cx.listener(move |this, _, _window, cx| {
                                            this.retry_connection(cx)
                                        })),
                                )
                            }
                        }),
                )
                .into_any()
        }
    }
}