lmstudio.rs

use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use lmstudio::{
    get_models, preload_model, stream_chat_completion, ChatCompletionRequest, ChatMessage,
    ModelType,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;

use crate::AllLanguageModelSettings;

const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

const PROVIDER_ID: &str = "lmstudio";
const PROVIDER_NAME: &str = "LM Studio";

#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model's name in the LM Studio API, e.g. qwen2.5-coder-7b or phi-4.
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: usize,
}
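// Illustrative settings snippet for configuring this provider. The settings
// path and the default `api_url` below are assumptions for the sake of the
// example, not guarantees; only the per-model fields mirror `AvailableModel`:
//
//   "language_models": {
//     "lmstudio": {
//       "api_url": "http://localhost:1234/api/v0",
//       "available_models": [
//         {
//           "name": "qwen2.5-coder-7b",
//           "display_name": "Qwen 2.5 Coder 7B",
//           "max_tokens": 32768
//         }
//       ]
//     }
//   }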
pub struct LmStudioLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<lmstudio::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we check whether it's up by fetching the models.
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<lmstudio::Model> = models
                .into_iter()
                .filter(|model| model.r#type != ModelType::Embeddings)
                .map(|model| lmstudio::Model::new(&model.id, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}

impl LmStudioLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for LmStudioLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for LmStudioLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiLmStudio
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.provided_models(cx).into_iter().next()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();

        // Add models from the LM Studio API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .lmstudio
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                lmstudio::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(LmStudioLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl LmStudioLanguageModel {
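    /// Converts a Zed `LanguageModelRequest` into LM Studio's OpenAI-compatible
    /// chat completion request. As a rough sketch, the serialized body is assumed
    /// to look like the following (field names depend on the `lmstudio` crate's
    /// serde layout, and `max_tokens: -1` is assumed to mean "no explicit limit"):
    ///
    /// ```json
    /// {
    ///   "model": "qwen2.5-coder-7b",
    ///   "messages": [{ "role": "user", "content": "Hello" }],
    ///   "stream": true,
    ///   "max_tokens": -1,
    ///   "stop": [],
    ///   "temperature": 0.0
    /// }
    /// ```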
    fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: Some(msg.string_contents()),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            stream: true,
            max_tokens: Some(-1),
            stop: Some(request.stop),
            temperature: request.temperature.or(Some(0.0)),
            tools: vec![],
        }
    }
}

impl LanguageModel for LmStudioLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("lmstudio/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        // An endpoint for this is coming soon. In the meantime, fall back to a
        // rough whitespace-based estimate.
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().split_whitespace().count())
            .sum::<usize>();

        let estimated_tokens = (token_count as f64 * 0.75) as usize;
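        // e.g. messages totaling 100 whitespace-separated words come out to an
        // estimate of 75 tokens under this heuristic.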
        async move { Ok(estimated_tokens) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request = self.to_lmstudio_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(fragment) => {
                            // Skip empty deltas
                            if fragment.choices[0].delta.is_object()
                                && fragment.choices[0].delta.as_object().unwrap().is_empty()
                            {
                                return None;
                            }

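                            // A non-empty delta is assumed to arrive in LM Studio's
                            // OpenAI-compatible streaming shape, roughly
                            // `{"role": "assistant", "content": "..."}`, which is
                            // what the `ChatMessage` parse below expects.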
                            // Try to parse the delta as a ChatMessage
                            if let Ok(chat_message) = serde_json::from_value::<ChatMessage>(
                                fragment.choices[0].delta.clone(),
                            ) {
                                let content = match chat_message {
                                    ChatMessage::User { content } => content,
                                    ChatMessage::Assistant { content, .. } => {
                                        content.unwrap_or_default()
                                    }
                                    ChatMessage::System { content } => content,
                                };
                                if !content.is_empty() {
                                    Some(Ok(content))
                                } else {
                                    None
                                }
                            } else {
                                None
                            }
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }

    fn use_any_tool(
        &self,
        _request: LanguageModelRequest,
        _tool_name: String,
        _tool_description: String,
        _schema: serde_json::Value,
        _cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        async move { Ok(futures::stream::empty().boxed()) }.boxed()
    }
}

struct ConfigurationView {
    state: gpui::Entity<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";
        let lmstudio_reqs = "To use LM Studio as a provider for Zed's assistant, it needs to be running with at least one model downloaded.";

        let inline_code_bg = cx.theme().colors().editor_foreground.opacity(0.05);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(lmstudio_intro))
                        .child(Label::new(lmstudio_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("To get your first model, try running"))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_sm()
                                        .child(Label::new("lms get qwen2.5-coder-7b")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("lmstudio-site", "LM Studio")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _window, cx| {
                                                    cx.open_url(LMSTUDIO_SITE)
                                                })
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_lmstudio_button",
                                                "Download LM Studio",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "Model Catalog")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_CATALOG_URL)
                                        }),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is a button only so the spacing matches the "Connect"
                            // button in the other branch; it should stay disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use the arrow cursor
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_lmstudio_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, _window, cx| {
                                    this.retry_connection(cx)
                                }))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}