language_model_selector.rs

  1use std::{rc::Rc, sync::Arc};
  2
  3use feature_flags::ZedPro;
  4use gpui::{
  5    action_with_deprecated_aliases, Action, AnyElement, App, Corner, DismissEvent, Entity,
  6    EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
  7};
  8use language_model::{
  9    AuthenticateError, LanguageModel, LanguageModelAvailability, LanguageModelRegistry,
 10};
 11use picker::{Picker, PickerDelegate};
 12use proto::Plan;
 13use ui::{
 14    prelude::*, ButtonLike, IconButtonShape, ListItem, ListItemSpacing, PopoverMenu,
 15    PopoverMenuHandle, Tooltip,
 16};
 17use workspace::ShowConfiguration;
 18
 19action_with_deprecated_aliases!(
 20    assistant,
 21    ToggleModelSelector,
 22    ["assistant2::ToggleModelSelector"]
 23);
 24
 25const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 26
 27type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
 28
 29pub struct LanguageModelSelector {
 30    picker: Entity<Picker<LanguageModelPickerDelegate>>,
 31    /// The task used to update the picker's matches when there is a change to
 32    /// the language model registry.
 33    update_matches_task: Option<Task<()>>,
 34    popover_menu_handle: PopoverMenuHandle<LanguageModelSelector>,
 35    _authenticate_all_providers_task: Task<()>,
 36    _subscriptions: Vec<Subscription>,
 37}
 38
 39impl LanguageModelSelector {
 40    pub fn new(
 41        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
 42        window: &mut Window,
 43        cx: &mut Context<Self>,
 44    ) -> Self {
 45        let on_model_changed = Arc::new(on_model_changed);
 46
 47        let all_models = Self::all_models(cx);
 48        let delegate = LanguageModelPickerDelegate {
 49            language_model_selector: cx.entity().downgrade(),
 50            on_model_changed: on_model_changed.clone(),
 51            all_models: all_models.clone(),
 52            filtered_models: all_models,
 53            selected_index: Self::get_active_model_index(cx),
 54        };
 55
 56        let picker = cx.new(|cx| {
 57            Picker::uniform_list(delegate, window, cx)
 58                .show_scrollbar(true)
 59                .width(rems(20.))
 60                .max_height(Some(rems(20.).into()))
 61        });
 62
 63        LanguageModelSelector {
 64            picker,
 65            update_matches_task: None,
 66            popover_menu_handle: PopoverMenuHandle::default(),
 67            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
 68            _subscriptions: vec![cx.subscribe_in(
 69                &LanguageModelRegistry::global(cx),
 70                window,
 71                Self::handle_language_model_registry_event,
 72            )],
 73        }
 74    }
 75
 76    pub fn toggle_model_selector(
 77        &mut self,
 78        _: &ToggleModelSelector,
 79        window: &mut Window,
 80        cx: &mut Context<Self>,
 81    ) {
 82        self.popover_menu_handle.toggle(window, cx);
 83    }
 84
 85    fn handle_language_model_registry_event(
 86        &mut self,
 87        _registry: &Entity<LanguageModelRegistry>,
 88        event: &language_model::Event,
 89        window: &mut Window,
 90        cx: &mut Context<Self>,
 91    ) {
 92        match event {
 93            language_model::Event::ProviderStateChanged
 94            | language_model::Event::AddedProvider(_)
 95            | language_model::Event::RemovedProvider(_) => {
 96                let task = self.picker.update(cx, |this, cx| {
 97                    let query = this.query(cx);
 98                    this.delegate.all_models = Self::all_models(cx);
 99                    this.delegate.update_matches(query, window, cx)
100                });
101                self.update_matches_task = Some(task);
102            }
103            _ => {}
104        }
105    }
106
107    /// Authenticates all providers in the [`LanguageModelRegistry`].
108    ///
109    /// We do this so that we can populate the language selector with all of the
110    /// models from the configured providers.
111    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
112        let authenticate_all_providers = LanguageModelRegistry::global(cx)
113            .read(cx)
114            .providers()
115            .iter()
116            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
117            .collect::<Vec<_>>();
118
119        cx.spawn(|_cx| async move {
120            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
121                if let Err(err) = authenticate_task.await {
122                    if matches!(err, AuthenticateError::CredentialsNotFound) {
123                        // Since we're authenticating these providers in the
124                        // background for the purposes of populating the
125                        // language selector, we don't care about providers
126                        // where the credentials are not found.
127                    } else {
128                        // Some providers have noisy failure states that we
129                        // don't want to spam the logs with every time the
130                        // language model selector is initialized.
131                        //
132                        // Ideally these should have more clear failure modes
133                        // that we know are safe to ignore here, like what we do
134                        // with `CredentialsNotFound` above.
135                        match provider_id.0.as_ref() {
136                            "lmstudio" | "ollama" => {
137                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
138                                //
139                                // These fail noisily, so we don't log them.
140                            }
141                            "copilot_chat" => {
142                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
143                            }
144                            _ => {
145                                log::error!(
146                                    "Failed to authenticate provider: {}: {err}",
147                                    provider_name.0
148                                );
149                            }
150                        }
151                    }
152                }
153            }
154        })
155    }
156
157    fn all_models(cx: &App) -> Vec<ModelInfo> {
158        LanguageModelRegistry::global(cx)
159            .read(cx)
160            .providers()
161            .iter()
162            .flat_map(|provider| {
163                let icon = provider.icon();
164
165                provider.provided_models(cx).into_iter().map(move |model| {
166                    let model = model.clone();
167                    let icon = model.icon().unwrap_or(icon);
168
169                    ModelInfo {
170                        model: model.clone(),
171                        icon,
172                        availability: model.availability(),
173                    }
174                })
175            })
176            .collect::<Vec<_>>()
177    }
178
179    fn get_active_model_index(cx: &App) -> usize {
180        let active_model = LanguageModelRegistry::read_global(cx).active_model();
181        Self::all_models(cx)
182            .iter()
183            .position(|model_info| {
184                Some(model_info.model.id()) == active_model.as_ref().map(|model| model.id())
185            })
186            .unwrap_or(0)
187    }
188}
189
190impl EventEmitter<DismissEvent> for LanguageModelSelector {}
191
192impl Focusable for LanguageModelSelector {
193    fn focus_handle(&self, cx: &App) -> FocusHandle {
194        self.picker.focus_handle(cx)
195    }
196}
197
198impl Render for LanguageModelSelector {
199    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
200        self.picker.clone()
201    }
202}
203
204#[derive(Clone)]
205struct ModelInfo {
206    model: Arc<dyn LanguageModel>,
207    icon: IconName,
208    availability: LanguageModelAvailability,
209}
210
211pub struct LanguageModelPickerDelegate {
212    language_model_selector: WeakEntity<LanguageModelSelector>,
213    on_model_changed: OnModelChanged,
214    all_models: Vec<ModelInfo>,
215    filtered_models: Vec<ModelInfo>,
216    selected_index: usize,
217}
218
219impl PickerDelegate for LanguageModelPickerDelegate {
220    type ListItem = ListItem;
221
222    fn match_count(&self) -> usize {
223        self.filtered_models.len()
224    }
225
226    fn selected_index(&self) -> usize {
227        self.selected_index
228    }
229
230    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
231        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
232        cx.notify();
233    }
234
235    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
236        "Select a model...".into()
237    }
238
239    fn update_matches(
240        &mut self,
241        query: String,
242        window: &mut Window,
243        cx: &mut Context<Picker<Self>>,
244    ) -> Task<()> {
245        let all_models = self.all_models.clone();
246        let current_index = self.selected_index;
247
248        let language_model_registry = LanguageModelRegistry::global(cx);
249
250        let configured_providers = language_model_registry
251            .read(cx)
252            .providers()
253            .iter()
254            .filter(|provider| provider.is_authenticated(cx))
255            .map(|provider| provider.id())
256            .collect::<Vec<_>>();
257
258        cx.spawn_in(window, |this, mut cx| async move {
259            let filtered_models = cx
260                .background_spawn(async move {
261                    let displayed_models = if configured_providers.is_empty() {
262                        all_models
263                    } else {
264                        all_models
265                            .into_iter()
266                            .filter(|model_info| {
267                                configured_providers.contains(&model_info.model.provider_id())
268                            })
269                            .collect::<Vec<_>>()
270                    };
271
272                    if query.is_empty() {
273                        displayed_models
274                    } else {
275                        displayed_models
276                            .into_iter()
277                            .filter(|model_info| {
278                                model_info
279                                    .model
280                                    .name()
281                                    .0
282                                    .to_lowercase()
283                                    .contains(&query.to_lowercase())
284                            })
285                            .collect()
286                    }
287                })
288                .await;
289
290            this.update_in(&mut cx, |this, window, cx| {
291                this.delegate.filtered_models = filtered_models;
292                // Preserve selection focus
293                let new_index = if current_index >= this.delegate.filtered_models.len() {
294                    0
295                } else {
296                    current_index
297                };
298                this.delegate.set_selected_index(new_index, window, cx);
299                cx.notify();
300            })
301            .ok();
302        })
303    }
304
305    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
306        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
307            let model = model_info.model.clone();
308            (self.on_model_changed)(model.clone(), cx);
309
310            let current_index = self.selected_index;
311            self.set_selected_index(current_index, window, cx);
312
313            cx.emit(DismissEvent);
314        }
315    }
316
317    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
318        self.language_model_selector
319            .update(cx, |_this, cx| cx.emit(DismissEvent))
320            .ok();
321    }
322
323    fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
324        let configured_models_count = LanguageModelRegistry::global(cx)
325            .read(cx)
326            .providers()
327            .iter()
328            .filter(|provider| provider.is_authenticated(cx))
329            .count();
330
331        if configured_models_count > 0 {
332            Some(
333                Label::new("Configured Models")
334                    .size(LabelSize::Small)
335                    .color(Color::Muted)
336                    .mt_1()
337                    .mb_0p5()
338                    .ml_2()
339                    .into_any_element(),
340            )
341        } else {
342            None
343        }
344    }
345
346    fn render_match(
347        &self,
348        ix: usize,
349        selected: bool,
350        _: &mut Window,
351        cx: &mut Context<Picker<Self>>,
352    ) -> Option<Self::ListItem> {
353        use feature_flags::FeatureFlagAppExt;
354        let show_badges = cx.has_flag::<ZedPro>();
355
356        let model_info = self.filtered_models.get(ix)?;
357        let provider_name: String = model_info.model.provider_name().0.clone().into();
358
359        let active_provider_id = LanguageModelRegistry::read_global(cx)
360            .active_provider()
361            .map(|m| m.id());
362
363        let active_model_id = LanguageModelRegistry::read_global(cx)
364            .active_model()
365            .map(|m| m.id());
366
367        let is_selected = Some(model_info.model.provider_id()) == active_provider_id
368            && Some(model_info.model.id()) == active_model_id;
369
370        let model_icon_color = if is_selected {
371            Color::Accent
372        } else {
373            Color::Muted
374        };
375
376        Some(
377            ListItem::new(ix)
378                .inset(true)
379                .spacing(ListItemSpacing::Sparse)
380                .toggle_state(selected)
381                .start_slot(
382                    Icon::new(model_info.icon)
383                        .color(model_icon_color)
384                        .size(IconSize::Small),
385                )
386                .child(
387                    h_flex()
388                        .w_full()
389                        .items_center()
390                        .gap_1p5()
391                        .pl_0p5()
392                        .w(px(240.))
393                        .child(
394                            div()
395                                .max_w_40()
396                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
397                        )
398                        .child(
399                            h_flex()
400                                .gap_0p5()
401                                .child(
402                                    Label::new(provider_name)
403                                        .size(LabelSize::XSmall)
404                                        .color(Color::Muted),
405                                )
406                                .children(match model_info.availability {
407                                    LanguageModelAvailability::Public => None,
408                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
409                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
410                                        show_badges.then(|| {
411                                            Label::new("Pro")
412                                                .size(LabelSize::XSmall)
413                                                .color(Color::Muted)
414                                        })
415                                    }
416                                }),
417                        ),
418                )
419                .end_slot(div().pr_3().when(is_selected, |this| {
420                    this.child(
421                        Icon::new(IconName::Check)
422                            .color(Color::Accent)
423                            .size(IconSize::Small),
424                    )
425                })),
426        )
427    }
428
429    fn render_footer(
430        &self,
431        _: &mut Window,
432        cx: &mut Context<Picker<Self>>,
433    ) -> Option<gpui::AnyElement> {
434        use feature_flags::FeatureFlagAppExt;
435
436        let plan = proto::Plan::ZedPro;
437        let is_trial = false;
438
439        Some(
440            h_flex()
441                .w_full()
442                .border_t_1()
443                .border_color(cx.theme().colors().border_variant)
444                .p_1()
445                .gap_4()
446                .justify_between()
447                .when(cx.has_flag::<ZedPro>(), |this| {
448                    this.child(match plan {
449                        // Already a Zed Pro subscriber
450                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
451                            .icon(IconName::ZedAssistant)
452                            .icon_size(IconSize::Small)
453                            .icon_color(Color::Muted)
454                            .icon_position(IconPosition::Start)
455                            .on_click(|_, window, cx| {
456                                window
457                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
458                            }),
459                        // Free user
460                        Plan::Free => Button::new(
461                            "try-pro",
462                            if is_trial {
463                                "Upgrade to Pro"
464                            } else {
465                                "Try Pro"
466                            },
467                        )
468                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
469                    })
470                })
471                .child(
472                    Button::new("configure", "Configure")
473                        .icon(IconName::Settings)
474                        .icon_size(IconSize::Small)
475                        .icon_color(Color::Muted)
476                        .icon_position(IconPosition::Start)
477                        .on_click(|_, window, cx| {
478                            window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
479                        }),
480                )
481                .into_any(),
482        )
483    }
484}
485
486pub fn inline_language_model_selector(
487    f: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
488) -> AnyElement {
489    let f = Rc::new(f);
490    PopoverMenu::new("popover-button")
491        .menu(move |window, cx| {
492            Some(cx.new(|cx| {
493                LanguageModelSelector::new(
494                    {
495                        let f = f.clone();
496                        move |model, cx| {
497                            f(model, cx);
498                        }
499                    },
500                    window,
501                    cx,
502                )
503            }))
504        })
505        .trigger_with_tooltip(
506            IconButton::new("context", IconName::SettingsAlt)
507                .shape(IconButtonShape::Square)
508                .icon_size(IconSize::Small)
509                .icon_color(Color::Muted),
510            move |window, cx| {
511                Tooltip::with_meta(
512                    format!(
513                        "Using {}",
514                        LanguageModelRegistry::read_global(cx)
515                            .active_model()
516                            .map(|model| model.name().0)
517                            .unwrap_or_else(|| "No model selected".into()),
518                    ),
519                    None,
520                    "Change Model",
521                    window,
522                    cx,
523                )
524            },
525        )
526        .anchor(gpui::Corner::TopRight)
527        // .when_some(menu_handle, |el, handle| el.with_handle(handle))
528        .offset(gpui::Point {
529            x: px(0.0),
530            y: px(-2.0),
531        })
532        .into_any_element()
533}
534
535pub fn assistant_language_model_selector(
536    keybinding_target: FocusHandle,
537    menu_handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
538    cx: &App,
539    f: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
540) -> AnyElement {
541    let active_model = LanguageModelRegistry::read_global(cx).active_model();
542    let model_name = match active_model {
543        Some(model) => model.name().0,
544        _ => SharedString::from("No model selected"),
545    };
546
547    let f = Rc::new(f);
548
549    PopoverMenu::new("popover-button")
550        .menu(move |window, cx| {
551            Some(cx.new(|cx| {
552                LanguageModelSelector::new(
553                    {
554                        let f = f.clone();
555                        move |model, cx| {
556                            f(model, cx);
557                        }
558                    },
559                    window,
560                    cx,
561                )
562            }))
563        })
564        .trigger_with_tooltip(
565            ButtonLike::new("active-model")
566                .style(ButtonStyle::Subtle)
567                .child(
568                    h_flex()
569                        .gap_0p5()
570                        .child(
571                            Label::new(model_name)
572                                .size(LabelSize::Small)
573                                .color(Color::Muted),
574                        )
575                        .child(
576                            Icon::new(IconName::ChevronDown)
577                                .color(Color::Muted)
578                                .size(IconSize::XSmall),
579                        ),
580                ),
581            move |window, cx| {
582                Tooltip::for_action_in(
583                    "Change Model",
584                    &ToggleModelSelector,
585                    &keybinding_target,
586                    window,
587                    cx,
588                )
589            },
590        )
591        .anchor(Corner::BottomRight)
592        .when_some(menu_handle, |el, handle| el.with_handle(handle))
593        .offset(gpui::Point {
594            x: px(0.0),
595            y: px(-2.0),
596        })
597        .into_any_element()
598}