1use anyhow::{Result, anyhow};
2use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
3use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
4use http_client::HttpClient;
5use language_model::{
6 AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
7 LanguageModelToolChoice,
8};
9use language_model::{
10 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
11 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
12 LanguageModelRequest, RateLimiter, Role,
13};
14use lmstudio::{
15 ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model,
16 stream_chat_completion,
17};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use settings::{Settings, SettingsStore};
21use std::{collections::BTreeMap, sync::Arc};
22use ui::{ButtonLike, Indicator, List, prelude::*};
23use util::ResultExt;
24
25use crate::AllLanguageModelSettings;
26use crate::ui::InstructionListItem;
27
28const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
29const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
30const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";
31
32const PROVIDER_ID: &str = "lmstudio";
33const PROVIDER_NAME: &str = "LM Studio";
34
/// Provider-level settings for LM Studio, as read from Zed's settings store
/// (see `AllLanguageModelSettings::get_global(cx).lmstudio`).
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    /// Base URL of the local LM Studio server's HTTP API.
    pub api_url: String,
    /// Models declared in settings; these override same-named models reported
    /// by the server (see `provided_models`).
    pub available_models: Vec<AvailableModel>,
}
40
/// A model entry the user can declare in settings to expose (or override)
/// an LM Studio model in the model picker.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: usize,
}
50
/// Language-model provider backed by a locally running LM Studio server.
pub struct LmStudioLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    // Shared state holding the fetched model list; observable by the UI.
    state: gpui::Entity<State>,
}
55
/// Shared provider state: the model list fetched from the server plus the
/// plumbing needed to refresh it when settings change.
pub struct State {
    http_client: Arc<dyn HttpClient>,
    // Models reported by the LM Studio server (embeddings models filtered out).
    available_models: Vec<lmstudio::Model>,
    // Holds the in-flight fetch; replaced on every refresh.
    fetch_model_task: Option<Task<Result<()>>>,
    // Settings-store observation that triggers a refetch when LM Studio
    // settings change; kept alive for the lifetime of this state.
    _subscription: Subscription,
}
62
impl State {
    /// LM Studio has no real credentials; a non-empty model list serves as
    /// the "authenticated" signal that the local server is reachable.
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    /// Fetches the model list from the server configured in settings,
    /// dropping embeddings models, sorting by name, and storing the result
    /// (with a notify so observers re-render).
    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<lmstudio::Model> = models
                .into_iter()
                .filter(|model| model.r#type != ModelType::Embeddings)
                .map(|model| lmstudio::Model::new(&model.id, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    /// Starts a fresh fetch, replacing (and thereby dropping) any previously
    /// stored fetch task.
    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    /// No-op once models are loaded; otherwise fetches them, surfacing any
    /// failure as an `AuthenticateError`.
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}
106
impl LmStudioLanguageModelProvider {
    /// Creates the provider: wires up a settings observer that refetches the
    /// model list whenever LM Studio settings actually change, then kicks off
    /// the initial fetch.
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    // Snapshot of the current settings, used to ignore
                    // global-settings notifications that don't touch LM Studio.
                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        // Fetch models eagerly so the provider is usable without waiting for
        // a settings change or an explicit authenticate call.
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}
137
impl LanguageModelProviderState for LmStudioLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the provider state so callers can observe model-list changes.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
145
impl LanguageModelProvider for LmStudioLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiLmStudio
    }

    /// The default model is the first entry of `provided_models`; since models
    /// are keyed by name in a `BTreeMap`, that is the alphabetically first one.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.provided_models(cx).into_iter().next()
    }

    // No separate "fast" model distinction for LM Studio.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.default_model(cx)
    }

    /// Merges server-reported models with models declared in settings
    /// (settings win on name collision) and wraps each in a
    /// `LmStudioLanguageModel`, sorted by name via the `BTreeMap`.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();

        // Add models from the LM Studio API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .lmstudio
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                lmstudio::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(LmStudioLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    // Allow up to 4 concurrent requests per model.
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    /// Fire-and-forget request asking the server to preload the given model;
    /// failures are only logged.
    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, cx)).into()
    }

    /// There are no stored credentials for LM Studio; "resetting" just
    /// refetches the model list from the server.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}
230
/// A single LM Studio chat model exposed through Zed's `LanguageModel` trait.
pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    // Bounds concurrent requests to this model (used in `stream_completion`).
    request_limiter: RateLimiter,
}
237
238impl LmStudioLanguageModel {
239 fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
240 ChatCompletionRequest {
241 model: self.model.name.clone(),
242 messages: request
243 .messages
244 .into_iter()
245 .map(|msg| match msg.role {
246 Role::User => ChatMessage::User {
247 content: msg.string_contents(),
248 },
249 Role::Assistant => ChatMessage::Assistant {
250 content: Some(msg.string_contents()),
251 tool_calls: None,
252 },
253 Role::System => ChatMessage::System {
254 content: msg.string_contents(),
255 },
256 })
257 .collect(),
258 stream: true,
259 max_tokens: Some(-1),
260 stop: Some(request.stop),
261 temperature: request.temperature.or(Some(0.0)),
262 tools: vec![],
263 }
264 }
265}
266
impl LanguageModel for LmStudioLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    // Tool calling is not implemented yet (see `LmStudioStreamMapper`).
    fn supports_tools(&self) -> bool {
        false
    }

    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
        false
    }

    fn telemetry_id(&self) -> String {
        format!("lmstudio/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    /// Rough token estimate: 0.75 × the whitespace-separated word count
    /// across all request messages.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        // Endpoint for this is coming soon. In the meantime, hacky estimation
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().split_whitespace().count())
            .sum::<usize>();

        let estimated_tokens = (token_count as f64 * 0.75) as usize;
        async move { Ok(estimated_tokens) }.boxed()
    }

    /// Streams a chat completion from the LM Studio server, mapping each
    /// text fragment to a `LanguageModelCompletionEvent::Text` event.
    /// Concurrency is bounded by `request_limiter`.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    > {
        let request = self.to_lmstudio_request(request);

        let http_client = self.http_client.clone();
        // Read the API URL from settings; this errors only when the app
        // state has already been dropped.
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;

            // Create a stream mapper to handle content across multiple deltas
            let stream_mapper = LmStudioStreamMapper::new();

            let stream = response
                .map(move |response| {
                    response.and_then(|fragment| stream_mapper.process_fragment(fragment))
                })
                // Drop fragments that carried no text (empty or undecodable
                // deltas, terminal fragments); keep text and errors.
                .filter_map(|result| async move {
                    match result {
                        Ok(Some(content)) => Some(Ok(content)),
                        Ok(None) => None,
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();

            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| {
                    result
                        .map(LanguageModelCompletionEvent::Text)
                        .map_err(LanguageModelCompletionError::Other)
                })
                .boxed())
        }
        .boxed()
    }
}
371
// This will be more useful when we implement tool calling. Currently keeping it empty.
/// Maps raw LM Studio stream fragments into plain text chunks.
struct LmStudioStreamMapper {}
374
375impl LmStudioStreamMapper {
376 fn new() -> Self {
377 Self {}
378 }
379
380 fn process_fragment(&self, fragment: lmstudio::ChatResponse) -> Result<Option<String>> {
381 // Most of the time, there will be only one choice
382 let Some(choice) = fragment.choices.first() else {
383 return Ok(None);
384 };
385
386 // Extract the delta content
387 if let Ok(delta) =
388 serde_json::from_value::<lmstudio::ResponseMessageDelta>(choice.delta.clone())
389 {
390 if let Some(content) = delta.content {
391 if !content.is_empty() {
392 return Ok(Some(content));
393 }
394 }
395 }
396
397 // If there's a finish_reason, we're done
398 if choice.finish_reason.is_some() {
399 return Ok(None);
400 }
401
402 Ok(None)
403 }
404}
405
/// Settings-panel view for the LM Studio provider.
struct ConfigurationView {
    state: gpui::Entity<State>,
    // Present while the initial model fetch is running; drives the
    // "Loading models..." placeholder in `render`.
    loading_models_task: Option<Task<()>>,
}
410
impl ConfigurationView {
    /// Creates the view and immediately starts authenticating (i.e. fetching
    /// models); clears `loading_models_task` when the attempt finishes —
    /// success or failure — so the UI leaves its loading state.
    pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    /// Re-attempts the model fetch; errors are logged rather than surfaced.
    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}
442
impl Render for ConfigurationView {
    /// Renders either a loading placeholder (while the initial model fetch is
    /// in flight) or the full configuration panel: intro text, setup
    /// instructions, external links, and a connection status / retry control.
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";

        if self.loading_models_task.is_some() {
            // Initial fetch still running: show only a placeholder.
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .gap_2()
                .child(
                    // Intro blurb plus the two setup instructions.
                    v_flex().gap_1().child(Label::new(lmstudio_intro)).child(
                        List::new()
                            .child(InstructionListItem::text_only(
                                "LM Studio needs to be running with at least one model downloaded.",
                            ))
                            .child(InstructionListItem::text_only(
                                "To get your first model, try running `lms get qwen2.5-coder-7b`",
                            )),
                    ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .justify_between()
                        .gap_2()
                        .child(
                            // Left side: site link when connected, download
                            // link otherwise, plus the model-catalog link.
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("lmstudio-site", "LM Studio")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _window, cx| {
                                                    cx.open_url(LMSTUDIO_SITE)
                                                })
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_lmstudio_button",
                                                "Download LM Studio",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "Model Catalog")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_CATALOG_URL)
                                        }),
                                ),
                        )
                        .map(|this| {
                            // Right side: "Connected" indicator, or a button
                            // that retries the model fetch.
                            if is_authenticated {
                                this.child(
                                    ButtonLike::new("connected")
                                        .disabled(true)
                                        .cursor_style(gpui::CursorStyle::Arrow)
                                        .child(
                                            h_flex()
                                                .gap_2()
                                                .child(Indicator::dot().color(Color::Success))
                                                .child(Label::new("Connected"))
                                                .into_any_element(),
                                        ),
                                )
                            } else {
                                this.child(
                                    Button::new("retry_lmstudio_models", "Connect")
                                        .icon_position(IconPosition::Start)
                                        .icon_size(IconSize::XSmall)
                                        .icon(IconName::Play)
                                        .on_click(cx.listener(move |this, _, _window, cx| {
                                            this.retry_connection(cx)
                                        })),
                                )
                            }
                        }),
                )
                .into_any()
        }
    }
}