1use anyhow::{Result, anyhow};
2use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
3use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
4use http_client::HttpClient;
5use language_model::{
6 AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
7 LanguageModelToolChoice,
8};
9use language_model::{
10 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
11 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
12 LanguageModelRequest, RateLimiter, Role,
13};
14use lmstudio::{
15 ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model,
16 stream_chat_completion,
17};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use settings::{Settings, SettingsStore};
21use std::{collections::BTreeMap, sync::Arc};
22use ui::{ButtonLike, Indicator, List, prelude::*};
23use util::ResultExt;
24
25use crate::AllLanguageModelSettings;
26use crate::ui::InstructionListItem;
27
/// Where users can download the LM Studio desktop app.
const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
/// Browsable catalog of models runnable in LM Studio.
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
/// LM Studio's home page.
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

/// Machine-readable provider id (used by `id()` / `provider_id()`).
const PROVIDER_ID: &str = "lmstudio";
/// Human-readable provider name shown in the UI.
const PROVIDER_NAME: &str = "LM Studio";
34
/// User-configurable settings for the LM Studio provider.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    /// Base URL of the local LM Studio server's API.
    pub api_url: String,
    /// Models declared in settings; these override server-reported models
    /// with the same name (see `provided_models`).
    pub available_models: Vec<AvailableModel>,
}
40
/// A model entry as declared in user settings (as opposed to one reported
/// by the running LM Studio server).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: usize,
}
50
/// Language-model provider backed by a locally running LM Studio server.
pub struct LmStudioLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    /// Shared observable state (fetched models, subscriptions).
    state: gpui::Entity<State>,
}
55
/// Observable provider state: the model list fetched from the server plus
/// the machinery to keep it fresh.
pub struct State {
    http_client: Arc<dyn HttpClient>,
    /// Models reported by the LM Studio server (embeddings filtered out).
    available_models: Vec<lmstudio::Model>,
    /// In-flight fetch, if any; replaced (and thus cancelled) on restart.
    fetch_model_task: Option<Task<Result<()>>>,
    /// Keeps the settings-change observer alive for the lifetime of the state.
    _subscription: Subscription,
}
62
63impl State {
64 fn is_authenticated(&self) -> bool {
65 !self.available_models.is_empty()
66 }
67
68 fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
69 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
70 let http_client = self.http_client.clone();
71 let api_url = settings.api_url.clone();
72
73 // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
74 cx.spawn(async move |this, cx| {
75 let models = get_models(http_client.as_ref(), &api_url, None).await?;
76
77 let mut models: Vec<lmstudio::Model> = models
78 .into_iter()
79 .filter(|model| model.r#type != ModelType::Embeddings)
80 .map(|model| lmstudio::Model::new(&model.id, None, None))
81 .collect();
82
83 models.sort_by(|a, b| a.name.cmp(&b.name));
84
85 this.update(cx, |this, cx| {
86 this.available_models = models;
87 cx.notify();
88 })
89 })
90 }
91
92 fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
93 let task = self.fetch_models(cx);
94 self.fetch_model_task.replace(task);
95 }
96
97 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
98 if self.is_authenticated() {
99 return Task::ready(Ok(()));
100 }
101
102 let fetch_models_task = self.fetch_models(cx);
103 cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
104 }
105}
106
impl LmStudioLanguageModelProvider {
    /// Creates the provider and immediately kicks off a fetch of the server's
    /// model list. A `SettingsStore` observer re-fetches whenever the LM Studio
    /// settings actually change.
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    // Keep a copy of the last-seen settings so we only refetch
                    // when they differ from the new global value.
                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        // Initial fetch so models are available as soon as possible.
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}
137
impl LanguageModelProviderState for LmStudioLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the provider's state entity so the UI can observe model-list changes.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
145
146impl LanguageModelProvider for LmStudioLanguageModelProvider {
147 fn id(&self) -> LanguageModelProviderId {
148 LanguageModelProviderId(PROVIDER_ID.into())
149 }
150
151 fn name(&self) -> LanguageModelProviderName {
152 LanguageModelProviderName(PROVIDER_NAME.into())
153 }
154
155 fn icon(&self) -> IconName {
156 IconName::AiLmStudio
157 }
158
159 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
160 self.provided_models(cx).into_iter().next()
161 }
162
163 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
164 self.default_model(cx)
165 }
166
167 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
168 let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();
169
170 // Add models from the LM Studio API
171 for model in self.state.read(cx).available_models.iter() {
172 models.insert(model.name.clone(), model.clone());
173 }
174
175 // Override with available models from settings
176 for model in AllLanguageModelSettings::get_global(cx)
177 .lmstudio
178 .available_models
179 .iter()
180 {
181 models.insert(
182 model.name.clone(),
183 lmstudio::Model {
184 name: model.name.clone(),
185 display_name: model.display_name.clone(),
186 max_tokens: model.max_tokens,
187 },
188 );
189 }
190
191 models
192 .into_values()
193 .map(|model| {
194 Arc::new(LmStudioLanguageModel {
195 id: LanguageModelId::from(model.name.clone()),
196 model: model.clone(),
197 http_client: self.http_client.clone(),
198 request_limiter: RateLimiter::new(4),
199 }) as Arc<dyn LanguageModel>
200 })
201 .collect()
202 }
203
204 fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
205 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
206 let http_client = self.http_client.clone();
207 let api_url = settings.api_url.clone();
208 let id = model.id().0.to_string();
209 cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
210 .detach_and_log_err(cx);
211 }
212
213 fn is_authenticated(&self, cx: &App) -> bool {
214 self.state.read(cx).is_authenticated()
215 }
216
217 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
218 self.state.update(cx, |state, cx| state.authenticate(cx))
219 }
220
221 fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
222 let state = self.state.clone();
223 cx.new(|cx| ConfigurationView::new(state, cx)).into()
224 }
225
226 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
227 self.state.update(cx, |state, cx| state.fetch_models(cx))
228 }
229}
230
/// A single LM Studio model exposed to Zed as a `LanguageModel`.
pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    /// Caps concurrent requests to this model (4 at construction time).
    request_limiter: RateLimiter,
}
237
238impl LmStudioLanguageModel {
239 fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
240 ChatCompletionRequest {
241 model: self.model.name.clone(),
242 messages: request
243 .messages
244 .into_iter()
245 .map(|msg| match msg.role {
246 Role::User => ChatMessage::User {
247 content: msg.string_contents(),
248 },
249 Role::Assistant => ChatMessage::Assistant {
250 content: Some(msg.string_contents()),
251 tool_calls: None,
252 },
253 Role::System => ChatMessage::System {
254 content: msg.string_contents(),
255 },
256 })
257 .collect(),
258 stream: true,
259 max_tokens: Some(-1),
260 stop: Some(request.stop),
261 temperature: request.temperature.or(Some(0.0)),
262 tools: vec![],
263 }
264 }
265}
266
impl LanguageModel for LmStudioLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    /// Tool calling is not implemented for LM Studio yet (see `LmStudioStreamMapper`).
    fn supports_tools(&self) -> bool {
        false
    }

    fn supports_images(&self) -> bool {
        false
    }

    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
        false
    }

    fn telemetry_id(&self) -> String {
        format!("lmstudio/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    /// Rough token estimate: whitespace-delimited word count scaled by 0.75.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        // Endpoint for this is coming soon. In the meantime, hacky estimation
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().split_whitespace().count())
            .sum::<usize>();

        let estimated_tokens = (token_count as f64 * 0.75) as usize;
        async move { Ok(estimated_tokens) }.boxed()
    }

    /// Streams completion text from the LM Studio server, gated through the
    /// shared `request_limiter`. Empty deltas are filtered out of the stream.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    > {
        let request = self.to_lmstudio_request(request);

        let http_client = self.http_client.clone();
        // Read the API URL up front; the app may be gone once the future runs.
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;

            // Create a stream mapper to handle content across multiple deltas
            let stream_mapper = LmStudioStreamMapper::new();

            let stream = response
                .map(move |response| {
                    response.and_then(|fragment| stream_mapper.process_fragment(fragment))
                })
                // Drop `Ok(None)` fragments (empty deltas, finish markers).
                .filter_map(|result| async move {
                    match result {
                        Ok(Some(content)) => Some(Ok(content)),
                        Ok(None) => None,
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();

            Ok(stream)
        });

        // Wrap raw strings/errors into the completion-event types callers expect.
        async move {
            Ok(future
                .await?
                .map(|result| {
                    result
                        .map(LanguageModelCompletionEvent::Text)
                        .map_err(LanguageModelCompletionError::Other)
                })
                .boxed())
        }
        .boxed()
    }
}
375
376// This will be more useful when we implement tool calling. Currently keeping it empty.
377struct LmStudioStreamMapper {}
378
379impl LmStudioStreamMapper {
380 fn new() -> Self {
381 Self {}
382 }
383
384 fn process_fragment(&self, fragment: lmstudio::ChatResponse) -> Result<Option<String>> {
385 // Most of the time, there will be only one choice
386 let Some(choice) = fragment.choices.first() else {
387 return Ok(None);
388 };
389
390 // Extract the delta content
391 if let Ok(delta) =
392 serde_json::from_value::<lmstudio::ResponseMessageDelta>(choice.delta.clone())
393 {
394 if let Some(content) = delta.content {
395 if !content.is_empty() {
396 return Ok(Some(content));
397 }
398 }
399 }
400
401 // If there's a finish_reason, we're done
402 if choice.finish_reason.is_some() {
403 return Ok(None);
404 }
405
406 Ok(None)
407 }
408}
409
/// Settings-panel view for connecting Zed to a local LM Studio server.
struct ConfigurationView {
    state: gpui::Entity<State>,
    /// `Some` while the initial authentication/model fetch is running;
    /// drives the "Loading models..." UI state.
    loading_models_task: Option<Task<()>>,
}
414
impl ConfigurationView {
    /// Builds the view and spawns a background authentication attempt; the
    /// loading state clears when the task finishes, whether or not it succeeded.
    pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // Errors are logged, not propagated; the UI just shows
                    // the unauthenticated state afterwards.
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    /// Re-attempts the server connection by refetching the model list.
    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}
446
impl Render for ConfigurationView {
    /// Renders either a bare loading label (while the initial fetch runs) or
    /// the full setup UI: intro text, site/download + catalog buttons, and a
    /// "Connected" indicator or "Connect" retry button depending on state.
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .gap_2()
                .child(
                    v_flex().gap_1().child(Label::new(lmstudio_intro)).child(
                        List::new()
                            .child(InstructionListItem::text_only(
                                "LM Studio needs to be running with at least one model downloaded.",
                            ))
                            .child(InstructionListItem::text_only(
                                "To get your first model, try running `lms get qwen2.5-coder-7b`",
                            )),
                    ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                // Show a site link once connected, otherwise a
                                // download button for first-time setup.
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("lmstudio-site", "LM Studio")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _window, cx| {
                                                    cx.open_url(LMSTUDIO_SITE)
                                                })
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_lmstudio_button",
                                                "Download LM Studio",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "Model Catalog")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_CATALOG_URL)
                                        }),
                                ),
                        )
                        // Right side: connection status or a retry affordance.
                        .map(|this| {
                            if is_authenticated {
                                this.child(
                                    ButtonLike::new("connected")
                                        .disabled(true)
                                        .cursor_style(gpui::CursorStyle::Arrow)
                                        .child(
                                            h_flex()
                                                .gap_2()
                                                .child(Indicator::dot().color(Color::Success))
                                                .child(Label::new("Connected"))
                                                .into_any_element(),
                                        ),
                                )
                            } else {
                                this.child(
                                    Button::new("retry_lmstudio_models", "Connect")
                                        .icon_position(IconPosition::Start)
                                        .icon_size(IconSize::XSmall)
                                        .icon(IconName::Play)
                                        .on_click(cx.listener(move |this, _, _window, cx| {
                                            this.retry_connection(cx)
                                        })),
                                )
                            }
                        }),
                )
                .into_any()
        }
    }
}