lmstudio.rs

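//! LM Studio language model provider: discovers models served by a locally running
//! LM Studio instance, exposes them through Zed's `LanguageModel` traits, and streams
//! chat completions from the server.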
use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use lmstudio::{
    get_models, preload_model, stream_chat_completion, ChatCompletionRequest, ChatMessage,
    ModelType,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;

use crate::AllLanguageModelSettings;

const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

const PROVIDER_ID: &str = "lmstudio";
const PROVIDER_NAME: &str = "LM Studio";

#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model's name in the LM Studio API, e.g. qwen2.5-coder-7b or phi-4.
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: usize,
}

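/// The LM Studio provider: owns the HTTP client used to reach the local server and the
/// shared [`State`] entity that tracks which models it serves.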
pub struct LmStudioLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

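/// Shared provider state: the models most recently fetched from the running LM Studio
/// server. The list is refetched whenever the LM Studio settings change (see the
/// settings subscription set up in `LmStudioLanguageModelProvider::new`).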
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<lmstudio::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we check that it's up by fetching the models.
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<lmstudio::Model> = models
                .into_iter()
                .filter(|model| model.r#type != ModelType::Embeddings)
                .map(|model| lmstudio::Model::new(&model.id, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(|_this, _cx| async move { Ok(fetch_models_task.await?) })
    }
}

impl LmStudioLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for LmStudioLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for LmStudioLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiLmStudio
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();

        // Add models from the LM Studio API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .lmstudio
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                lmstudio::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(LmStudioLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

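/// A single model served by LM Studio, exposed through Zed's `LanguageModel` interface.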
pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl LmStudioLanguageModel {
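    /// Converts Zed's provider-agnostic request into an LM Studio chat completion request,
    /// flattening each message down to its plain-text contents.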
    fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: Some(msg.string_contents()),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            stream: true,
            max_tokens: Some(-1),
            stop: Some(request.stop),
            temperature: request.temperature.or(Some(0.0)),
            tools: vec![],
        }
    }
}

impl LanguageModel for LmStudioLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("lmstudio/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        // There's no token-counting endpoint yet, so fall back to a rough estimate
        // derived from the whitespace-delimited word count.
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().split_whitespace().count())
            .sum::<usize>();

        let estimated_tokens = (token_count as f64 * 0.75) as usize;
        async move { Ok(estimated_tokens) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request = self.to_lmstudio_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

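        // Each streamed fragment carries a chat-completion delta. Only deltas that parse as
        // a `ChatMessage` with non-empty text content are forwarded; empty or unrecognized
        // deltas are dropped.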
        let future = self.request_limiter.stream(async move {
            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(fragment) => {
                            // Skip fragments whose delta is an empty object.
                            if fragment.choices[0]
                                .delta
                                .as_object()
                                .is_some_and(|delta| delta.is_empty())
                            {
                                return None;
                            }

                            // Try to parse the delta as a ChatMessage.
                            if let Ok(chat_message) = serde_json::from_value::<ChatMessage>(
                                fragment.choices[0].delta.clone(),
                            ) {
                                let content = match chat_message {
                                    ChatMessage::User { content } => content,
                                    ChatMessage::Assistant { content, .. } => {
                                        content.unwrap_or_default()
                                    }
                                    ChatMessage::System { content } => content,
                                };
                                if !content.is_empty() {
                                    Some(Ok(content))
                                } else {
                                    None
                                }
                            } else {
                                None
                            }
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }

    fn use_any_tool(
        &self,
        _request: LanguageModelRequest,
        _tool_name: String,
        _tool_description: String,
        _schema: serde_json::Value,
        _cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        async move { Ok(futures::stream::empty().boxed()) }.boxed()
    }
}

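/// Configuration panel for LM Studio: shows setup instructions and download/catalog links,
/// plus either a "Connect" retry button or a "Connected" indicator depending on whether
/// any models have been fetched.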
struct ConfigurationView {
    state: gpui::Entity<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(&mut cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";
        let lmstudio_reqs =
            "To use LM Studio as a provider for the Zed assistant, it needs to be running with at least one model downloaded.";

        let inline_code_bg = cx.theme().colors().editor_foreground.opacity(0.05);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(lmstudio_intro))
                        .child(Label::new(lmstudio_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("To get your first model, try running"))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_md()
                                        .child(Label::new("lms get qwen2.5-coder-7b")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("lmstudio-site", "LM Studio")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _window, cx| {
                                                    cx.open_url(LMSTUDIO_SITE)
                                                })
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_lmstudio_button",
                                                "Download LM Studio",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "Model Catalog")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_CATALOG_URL)
                                        }),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is only a button so the spacing stays consistent;
                            // it should remain disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this will never be clickable, use the arrow cursor.
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_lmstudio_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, _window, cx| {
                                    this.retry_connection(cx)
                                }))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}