lmstudio.rs

use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
};
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use lmstudio::{
    ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model,
    stream_chat_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc};
use ui::{ButtonLike, Indicator, List, prelude::*};
use util::ResultExt;

use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;

const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

const PROVIDER_ID: &str = "lmstudio";
const PROVIDER_NAME: &str = "LM Studio";

#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model's name in the LM Studio API, e.g. `qwen2.5-coder-7b` or `phi-4`.
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: usize,
}

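/// Language model provider backed by a locally running LM Studio server.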
pub struct LmStudioLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

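/// Holds the models fetched from the LM Studio server, plus the settings
/// subscription that restarts the fetch whenever the LM Studio settings change.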
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<lmstudio::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

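    /// Queries the LM Studio server for its models, filtering out embedding
    /// models and sorting the rest by name.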
    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models.
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<lmstudio::Model> = models
                .into_iter()
                .filter(|model| model.r#type != ModelType::Embeddings)
                .map(|model| lmstudio::Model::new(&model.id, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

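    /// LM Studio needs no API key; successfully listing models is treated as
    /// being "authenticated".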
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}

impl LmStudioLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for LmStudioLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for LmStudioLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiLmStudio
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.provided_models(cx).into_iter().next()
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.default_model(cx)
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();

        // Add models from the LM Studio API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .lmstudio
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                lmstudio::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(LmStudioLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

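    /// Asks the LM Studio server to preload the model so it's ready before the
    /// first completion request.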
    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

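/// A single LM Studio model exposed to Zed, identified by its name in the LM Studio API.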
pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl LmStudioLanguageModel {
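    /// Converts Zed's provider-agnostic request into an LM Studio chat
    /// completion request. Streaming is always enabled, no tools are sent, and
    /// the temperature defaults to 0.0 when the request doesn't specify one.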
    fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: Some(msg.string_contents()),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            stream: true,
            max_tokens: Some(-1),
            stop: Some(request.stop),
            temperature: request.temperature.or(Some(0.0)),
            tools: vec![],
        }
    }
}

impl LanguageModel for LmStudioLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        false
    }

    fn telemetry_id(&self) -> String {
        format!("lmstudio/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

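    /// Estimates the token count from whitespace-separated word counts; LM
    /// Studio does not yet expose a token-counting endpoint.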
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        // An endpoint for this is coming soon. In the meantime, make a rough whitespace-based estimate.
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().split_whitespace().count())
            .sum::<usize>();

        let estimated_tokens = (token_count as f64 * 0.75) as usize;
        async move { Ok(estimated_tokens) }.boxed()
    }

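    /// Streams a chat completion from LM Studio, mapping each streamed
    /// fragment to a text event. Requests pass through the rate limiter, which
    /// allows at most 4 concurrent requests.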
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    > {
        let request = self.to_lmstudio_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;

            // Create a stream mapper to handle content across multiple deltas
            let stream_mapper = LmStudioStreamMapper::new();

            let stream = response
                .map(move |response| {
                    response.and_then(|fragment| stream_mapper.process_fragment(fragment))
                })
                .filter_map(|result| async move {
                    match result {
                        Ok(Some(content)) => Some(Ok(content)),
                        Ok(None) => None,
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();

            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| {
                    result
                        .map(LanguageModelCompletionEvent::Text)
                        .map_err(LanguageModelCompletionError::Other)
                })
                .boxed())
        }
        .boxed()
    }
}

// This will be more useful when we implement tool calling. Currently keeping it empty.
struct LmStudioStreamMapper {}

impl LmStudioStreamMapper {
    fn new() -> Self {
        Self {}
    }

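    /// Extracts the text content from a single streamed chunk, returning
    /// `Ok(None)` for empty deltas and end-of-stream markers.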
    fn process_fragment(&self, fragment: lmstudio::ChatResponse) -> Result<Option<String>> {
        // Most of the time, there will be only one choice
        let Some(choice) = fragment.choices.first() else {
            return Ok(None);
        };

        // Extract the delta content
        if let Ok(delta) =
            serde_json::from_value::<lmstudio::ResponseMessageDelta>(choice.delta.clone())
        {
            if let Some(content) = delta.content {
                if !content.is_empty() {
                    return Ok(Some(content));
                }
            }
        }

        // If there's a finish_reason, we're done
        if choice.finish_reason.is_some() {
            return Ok(None);
        }

        Ok(None)
    }
}

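/// Configuration UI for the LM Studio provider: shows connection status, links
/// to the LM Studio site and model catalog, and offers download/retry actions
/// when no server is reachable.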
struct ConfigurationView {
    state: gpui::Entity<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .gap_2()
                .child(
                    v_flex().gap_1().child(Label::new(lmstudio_intro)).child(
                        List::new()
                            .child(InstructionListItem::text_only(
                                "LM Studio needs to be running with at least one model downloaded.",
                            ))
                            .child(InstructionListItem::text_only(
                                "To get your first model, try running `lms get qwen2.5-coder-7b`",
                            )),
                    ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("lmstudio-site", "LM Studio")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _window, cx| {
                                                    cx.open_url(LMSTUDIO_SITE)
                                                })
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_lmstudio_button",
                                                "Download LM Studio",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "Model Catalog")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_CATALOG_URL)
                                        }),
                                ),
                        )
                        .map(|this| {
                            if is_authenticated {
                                this.child(
                                    ButtonLike::new("connected")
                                        .disabled(true)
                                        .cursor_style(gpui::CursorStyle::Arrow)
                                        .child(
                                            h_flex()
                                                .gap_2()
                                                .child(Indicator::dot().color(Color::Success))
                                                .child(Label::new("Connected"))
                                                .into_any_element(),
                                        ),
                                )
                            } else {
                                this.child(
                                    Button::new("retry_lmstudio_models", "Connect")
                                        .icon_position(IconPosition::Start)
                                        .icon_size(IconSize::XSmall)
                                        .icon(IconName::Play)
                                        .on_click(cx.listener(move |this, _, _window, cx| {
                                            this.retry_connection(cx)
                                        })),
                                )
                            }
                        }),
                )
                .into_any()
        }
    }
}