lmstudio.rs

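//! LM Studio language model provider: discovers models from a locally running
//! LM Studio server, registers them with Zed's assistant, and streams chat
//! completions from it.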
use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use lmstudio::{
    ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model,
    stream_chat_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc};
use ui::{ButtonLike, Indicator, prelude::*};
use util::ResultExt;

use crate::AllLanguageModelSettings;

const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

const PROVIDER_ID: &str = "lmstudio";
const PROVIDER_NAME: &str = "LM Studio";

#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

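/// A model manually declared in Zed's settings rather than discovered from the
/// running LM Studio server. In `provided_models`, an entry here overrides a
/// server-reported model with the same name. A minimal sketch of one entry,
/// with illustrative values:
///
/// ```ignore
/// let model = AvailableModel {
///     name: "qwen2.5-coder-7b".into(),
///     display_name: Some("Qwen 2.5 Coder 7B".into()),
///     max_tokens: 32768, // illustrative context window size
/// };
/// ```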
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the LM Studio API, e.g. qwen2.5-coder-7b, phi-4, etc.
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: usize,
}

pub struct LmStudioLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

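/// Provider state held in a GPUI entity: the models most recently fetched from
/// the LM Studio server, the in-flight fetch task, and a settings subscription
/// that triggers a re-fetch whenever the LM Studio settings change.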
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<lmstudio::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
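    /// LM Studio needs no API key, so "authenticated" simply means the server
    /// has been reached and reported at least one model.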
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

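    /// Queries the LM Studio server for its model list, drops embedding
    /// models, sorts the rest by name, and stores them on the state.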
    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models.
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<lmstudio::Model> = models
                .into_iter()
                .filter(|model| model.r#type != ModelType::Embeddings)
                .map(|model| lmstudio::Model::new(&model.id, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}

impl LmStudioLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for LmStudioLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for LmStudioLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiLmStudio
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.provided_models(cx).into_iter().next()
    }

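    /// Merges the models reported by the LM Studio server with the
    /// `available_models` configured in settings; a settings entry with the
    /// same name takes precedence.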
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();

        // Add models from the LM Studio API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .lmstudio
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                lmstudio::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(LmStudioLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

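    /// Asks the LM Studio server to preload the given model ahead of its first
    /// completion request.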
    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl LmStudioLanguageModel {
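    /// Converts Zed's request into an LM Studio chat completion request.
    /// Streaming is always enabled, `max_tokens` is set to -1 (leaving the
    /// limit to the server), and the temperature defaults to 0.0 when the
    /// request doesn't specify one.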
    fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: Some(msg.string_contents()),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            stream: true,
            max_tokens: Some(-1),
            stop: Some(request.stop),
            temperature: request.temperature.or(Some(0.0)),
            tools: vec![],
        }
    }
}

impl LanguageModel for LmStudioLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        false
    }

    fn telemetry_id(&self) -> String {
        format!("lmstudio/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        // A token-counting endpoint is coming soon; in the meantime, use a rough word-count estimate.
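        // e.g. a request whose messages contain 100 whitespace-separated words
        // is reported as 75 tokens (100 * 0.75).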
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().split_whitespace().count())
            .sum::<usize>();

        let estimated_tokens = (token_count as f64 * 0.75) as usize;
        async move { Ok(estimated_tokens) }.boxed()
    }

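    /// Streams a completion from LM Studio, skipping empty deltas and emitting
    /// each non-empty content delta as a text event.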
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request = self.to_lmstudio_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(fragment) => {
                            // Skip empty deltas
                            if fragment.choices[0].delta.is_object()
                                && fragment.choices[0].delta.as_object().unwrap().is_empty()
                            {
                                return None;
                            }

                            // Try to parse the delta as ChatMessage
                            if let Ok(chat_message) = serde_json::from_value::<ChatMessage>(
                                fragment.choices[0].delta.clone(),
                            ) {
                                let content = match chat_message {
                                    ChatMessage::User { content } => content,
                                    ChatMessage::Assistant { content, .. } => {
                                        content.unwrap_or_default()
                                    }
                                    ChatMessage::System { content } => content,
                                };
                                if !content.is_empty() {
                                    Some(Ok(content))
                                } else {
                                    None
                                }
                            } else {
                                None
                            }
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }

    fn use_any_tool(
        &self,
        _request: LanguageModelRequest,
        _tool_name: String,
        _tool_description: String,
        _schema: serde_json::Value,
        _cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        async move { Ok(futures::stream::empty().boxed()) }.boxed()
    }
}

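/// Settings UI for the LM Studio provider: shows an intro, links to download
/// LM Studio and browse its model catalog, and either a "Connected" indicator
/// or a button that retries the connection by re-fetching models.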
struct ConfigurationView {
    state: gpui::Entity<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";
        let lmstudio_reqs = "To use LM Studio as a provider for Zed's assistant, it needs to be running with at least one model downloaded.";

        let inline_code_bg = cx.theme().colors().editor_foreground.opacity(0.05);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(lmstudio_intro))
                        .child(Label::new(lmstudio_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("To get your first model, try running"))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_sm()
                                        .child(Label::new("lms get qwen2.5-coder-7b")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("lmstudio-site", "LM Studio")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _window, cx| {
                                                    cx.open_url(LMSTUDIO_SITE)
                                                })
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_lmstudio_button",
                                                "Download LM Studio",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "Model Catalog")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_CATALOG_URL)
                                        }),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is only a button to ensure the spacing is correct;
                            // it should stay disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use the arrow cursor
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_lmstudio_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, _window, cx| {
                                    this.retry_connection(cx)
                                }))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}