// language_model_selector.rs

  1use std::sync::Arc;
  2
  3use collections::{HashSet, IndexMap};
  4use feature_flags::{Assistant2FeatureFlag, ZedProFeatureFlag};
  5use gpui::{
  6    Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
  7    Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
  8};
  9use language_model::{
 10    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
 11    LanguageModelRegistry,
 12};
 13use picker::{Picker, PickerDelegate};
 14use proto::Plan;
 15use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
 16
 17action_with_deprecated_aliases!(
 18    agent,
 19    ToggleModelSelector,
 20    [
 21        "assistant::ToggleModelSelector",
 22        "assistant2::ToggleModelSelector"
 23    ]
 24);
 25
 26const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 27
 28type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
 29type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
 30
 31pub struct LanguageModelSelector {
 32    picker: Entity<Picker<LanguageModelPickerDelegate>>,
 33    _authenticate_all_providers_task: Task<()>,
 34    _subscriptions: Vec<Subscription>,
 35}
 36
 37impl LanguageModelSelector {
 38    pub fn new(
 39        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 40        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 41        window: &mut Window,
 42        cx: &mut Context<Self>,
 43    ) -> Self {
 44        let on_model_changed = Arc::new(on_model_changed);
 45
 46        let all_models = Self::all_models(cx);
 47        let entries = all_models.entries();
 48
 49        let delegate = LanguageModelPickerDelegate {
 50            language_model_selector: cx.entity().downgrade(),
 51            on_model_changed: on_model_changed.clone(),
 52            all_models: Arc::new(all_models),
 53            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
 54            filtered_entries: entries,
 55            get_active_model: Arc::new(get_active_model),
 56        };
 57
 58        let picker = cx.new(|cx| {
 59            Picker::list(delegate, window, cx)
 60                .show_scrollbar(true)
 61                .width(rems(20.))
 62                .max_height(Some(rems(20.).into()))
 63        });
 64
 65        let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
 66
 67        LanguageModelSelector {
 68            picker,
 69            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
 70            _subscriptions: vec![
 71                cx.subscribe_in(
 72                    &LanguageModelRegistry::global(cx),
 73                    window,
 74                    Self::handle_language_model_registry_event,
 75                ),
 76                subscription,
 77            ],
 78        }
 79    }
 80
 81    fn handle_language_model_registry_event(
 82        &mut self,
 83        _registry: &Entity<LanguageModelRegistry>,
 84        event: &language_model::Event,
 85        window: &mut Window,
 86        cx: &mut Context<Self>,
 87    ) {
 88        match event {
 89            language_model::Event::ProviderStateChanged
 90            | language_model::Event::AddedProvider(_)
 91            | language_model::Event::RemovedProvider(_) => {
 92                self.picker.update(cx, |this, cx| {
 93                    let query = this.query(cx);
 94                    this.delegate.all_models = Arc::new(Self::all_models(cx));
 95                    // Update matches will automatically drop the previous task
 96                    // if we get a provider event again
 97                    this.update_matches(query, window, cx)
 98                });
 99            }
100            _ => {}
101        }
102    }
103
104    /// Authenticates all providers in the [`LanguageModelRegistry`].
105    ///
106    /// We do this so that we can populate the language selector with all of the
107    /// models from the configured providers.
108    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
109        let authenticate_all_providers = LanguageModelRegistry::global(cx)
110            .read(cx)
111            .providers()
112            .iter()
113            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
114            .collect::<Vec<_>>();
115
116        cx.spawn(async move |_cx| {
117            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
118                if let Err(err) = authenticate_task.await {
119                    if matches!(err, AuthenticateError::CredentialsNotFound) {
120                        // Since we're authenticating these providers in the
121                        // background for the purposes of populating the
122                        // language selector, we don't care about providers
123                        // where the credentials are not found.
124                    } else {
125                        // Some providers have noisy failure states that we
126                        // don't want to spam the logs with every time the
127                        // language model selector is initialized.
128                        //
129                        // Ideally these should have more clear failure modes
130                        // that we know are safe to ignore here, like what we do
131                        // with `CredentialsNotFound` above.
132                        match provider_id.0.as_ref() {
133                            "lmstudio" | "ollama" => {
134                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
135                                //
136                                // These fail noisily, so we don't log them.
137                            }
138                            "copilot_chat" => {
139                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
140                            }
141                            _ => {
142                                log::error!(
143                                    "Failed to authenticate provider: {}: {err}",
144                                    provider_name.0
145                                );
146                            }
147                        }
148                    }
149                }
150            }
151        })
152    }
153
154    fn all_models(cx: &App) -> GroupedModels {
155        let mut recommended = Vec::new();
156        let mut recommended_set = HashSet::default();
157        for provider in LanguageModelRegistry::global(cx)
158            .read(cx)
159            .providers()
160            .iter()
161        {
162            let models = provider.recommended_models(cx);
163            recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
164            recommended.extend(
165                provider
166                    .recommended_models(cx)
167                    .into_iter()
168                    .map(move |model| ModelInfo {
169                        model: model.clone(),
170                        icon: provider.icon(),
171                    }),
172            );
173        }
174
175        let other_models = LanguageModelRegistry::global(cx)
176            .read(cx)
177            .providers()
178            .iter()
179            .map(|provider| {
180                (
181                    provider.id(),
182                    provider
183                        .provided_models(cx)
184                        .into_iter()
185                        .filter_map(|model| {
186                            let not_included =
187                                !recommended_set.contains(&(model.provider_id(), model.id()));
188                            not_included.then(|| ModelInfo {
189                                model: model.clone(),
190                                icon: provider.icon(),
191                            })
192                        })
193                        .collect::<Vec<_>>(),
194                )
195            })
196            .collect::<IndexMap<_, _>>();
197
198        GroupedModels {
199            recommended,
200            other: other_models,
201        }
202    }
203
204    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
205        (self.picker.read(cx).delegate.get_active_model)(cx)
206    }
207
208    fn get_active_model_index(
209        entries: &[LanguageModelPickerEntry],
210        active_model: Option<ConfiguredModel>,
211    ) -> usize {
212        entries
213            .iter()
214            .position(|entry| {
215                if let LanguageModelPickerEntry::Model(model) = entry {
216                    active_model
217                        .as_ref()
218                        .map(|active_model| {
219                            active_model.model.id() == model.model.id()
220                                && active_model.provider.id() == model.model.provider_id()
221                        })
222                        .unwrap_or_default()
223                } else {
224                    false
225                }
226            })
227            .unwrap_or(0)
228    }
229}
230
231impl EventEmitter<DismissEvent> for LanguageModelSelector {}
232
233impl Focusable for LanguageModelSelector {
234    fn focus_handle(&self, cx: &App) -> FocusHandle {
235        self.picker.focus_handle(cx)
236    }
237}
238
239impl Render for LanguageModelSelector {
240    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
241        self.picker.clone()
242    }
243}
244
245#[derive(IntoElement)]
246pub struct LanguageModelSelectorPopoverMenu<T, TT>
247where
248    T: PopoverTrigger + ButtonCommon,
249    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
250{
251    language_model_selector: Entity<LanguageModelSelector>,
252    trigger: T,
253    tooltip: TT,
254    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
255    anchor: Corner,
256}
257
258impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
259where
260    T: PopoverTrigger + ButtonCommon,
261    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
262{
263    pub fn new(
264        language_model_selector: Entity<LanguageModelSelector>,
265        trigger: T,
266        tooltip: TT,
267        anchor: Corner,
268    ) -> Self {
269        Self {
270            language_model_selector,
271            trigger,
272            tooltip,
273            handle: None,
274            anchor,
275        }
276    }
277
278    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
279        self.handle = Some(handle);
280        self
281    }
282}
283
284impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
285where
286    T: PopoverTrigger + ButtonCommon,
287    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
288{
289    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
290        let language_model_selector = self.language_model_selector.clone();
291
292        PopoverMenu::new("model-switcher")
293            .menu(move |_window, _cx| Some(language_model_selector.clone()))
294            .trigger_with_tooltip(self.trigger, self.tooltip)
295            .anchor(self.anchor)
296            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
297            .offset(gpui::Point {
298                x: px(0.0),
299                y: px(-2.0),
300            })
301    }
302}
303
304#[derive(Clone)]
305struct ModelInfo {
306    model: Arc<dyn LanguageModel>,
307    icon: IconName,
308}
309
310pub struct LanguageModelPickerDelegate {
311    language_model_selector: WeakEntity<LanguageModelSelector>,
312    on_model_changed: OnModelChanged,
313    get_active_model: GetActiveModel,
314    all_models: Arc<GroupedModels>,
315    filtered_entries: Vec<LanguageModelPickerEntry>,
316    selected_index: usize,
317}
318
319struct GroupedModels {
320    recommended: Vec<ModelInfo>,
321    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
322}
323
324impl GroupedModels {
325    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
326        let mut entries = Vec::new();
327
328        if !self.recommended.is_empty() {
329            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
330            entries.extend(
331                self.recommended
332                    .iter()
333                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
334            );
335        }
336
337        for models in self.other.values() {
338            if models.is_empty() {
339                continue;
340            }
341            entries.push(LanguageModelPickerEntry::Separator(
342                models[0].model.provider_name().0,
343            ));
344            entries.extend(
345                models
346                    .iter()
347                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
348            );
349        }
350        entries
351    }
352}
353
354enum LanguageModelPickerEntry {
355    Model(ModelInfo),
356    Separator(SharedString),
357}
358
359impl PickerDelegate for LanguageModelPickerDelegate {
360    type ListItem = AnyElement;
361
362    fn match_count(&self) -> usize {
363        self.filtered_entries.len()
364    }
365
366    fn selected_index(&self) -> usize {
367        self.selected_index
368    }
369
370    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
371        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
372        cx.notify();
373    }
374
375    fn can_select(
376        &mut self,
377        ix: usize,
378        _window: &mut Window,
379        _cx: &mut Context<Picker<Self>>,
380    ) -> bool {
381        match self.filtered_entries.get(ix) {
382            Some(LanguageModelPickerEntry::Model(_)) => true,
383            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
384        }
385    }
386
387    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
388        "Select a model…".into()
389    }
390
391    fn update_matches(
392        &mut self,
393        query: String,
394        window: &mut Window,
395        cx: &mut Context<Picker<Self>>,
396    ) -> Task<()> {
397        let all_models = self.all_models.clone();
398        let current_index = self.selected_index;
399
400        let language_model_registry = LanguageModelRegistry::global(cx);
401
402        let configured_providers = language_model_registry
403            .read(cx)
404            .providers()
405            .iter()
406            .filter(|provider| provider.is_authenticated(cx))
407            .map(|provider| provider.id())
408            .collect::<Vec<_>>();
409
410        cx.spawn_in(window, async move |this, cx| {
411            let filtered_models = cx
412                .background_spawn(async move {
413                    let matches = |info: &ModelInfo| {
414                        info.model
415                            .name()
416                            .0
417                            .to_lowercase()
418                            .contains(&query.to_lowercase())
419                    };
420
421                    let recommended_models = all_models
422                        .recommended
423                        .iter()
424                        .filter(|r| {
425                            configured_providers.contains(&r.model.provider_id()) && matches(r)
426                        })
427                        .cloned()
428                        .collect();
429                    let mut other_models = IndexMap::default();
430                    for (provider_id, models) in &all_models.other {
431                        if configured_providers.contains(&provider_id) {
432                            other_models.insert(
433                                provider_id.clone(),
434                                models
435                                    .iter()
436                                    .filter(|m| matches(m))
437                                    .cloned()
438                                    .collect::<Vec<_>>(),
439                            );
440                        }
441                    }
442                    GroupedModels {
443                        recommended: recommended_models,
444                        other: other_models,
445                    }
446                })
447                .await;
448
449            this.update_in(cx, |this, window, cx| {
450                this.delegate.filtered_entries = filtered_models.entries();
451                // Preserve selection focus
452                let new_index = if current_index >= this.delegate.filtered_entries.len() {
453                    0
454                } else {
455                    current_index
456                };
457                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
458                cx.notify();
459            })
460            .ok();
461        })
462    }
463
464    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
465        if let Some(LanguageModelPickerEntry::Model(model_info)) =
466            self.filtered_entries.get(self.selected_index)
467        {
468            let model = model_info.model.clone();
469            (self.on_model_changed)(model.clone(), cx);
470
471            let current_index = self.selected_index;
472            self.set_selected_index(current_index, window, cx);
473
474            cx.emit(DismissEvent);
475        }
476    }
477
478    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
479        self.language_model_selector
480            .update(cx, |_this, cx| cx.emit(DismissEvent))
481            .ok();
482    }
483
484    fn render_match(
485        &self,
486        ix: usize,
487        selected: bool,
488        _: &mut Window,
489        cx: &mut Context<Picker<Self>>,
490    ) -> Option<Self::ListItem> {
491        match self.filtered_entries.get(ix)? {
492            LanguageModelPickerEntry::Separator(title) => Some(
493                div()
494                    .px_2()
495                    .pb_1()
496                    .when(ix > 1, |this| {
497                        this.mt_1()
498                            .pt_2()
499                            .border_t_1()
500                            .border_color(cx.theme().colors().border_variant)
501                    })
502                    .child(
503                        Label::new(title)
504                            .size(LabelSize::XSmall)
505                            .color(Color::Muted),
506                    )
507                    .into_any_element(),
508            ),
509            LanguageModelPickerEntry::Model(model_info) => {
510                let active_model = (self.get_active_model)(cx);
511                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
512                let active_model_id = active_model.map(|m| m.model.id());
513
514                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
515                    && Some(model_info.model.id()) == active_model_id;
516
517                let model_icon_color = if is_selected {
518                    Color::Accent
519                } else {
520                    Color::Muted
521                };
522
523                Some(
524                    ListItem::new(ix)
525                        .inset(true)
526                        .spacing(ListItemSpacing::Sparse)
527                        .toggle_state(selected)
528                        .start_slot(
529                            Icon::new(model_info.icon)
530                                .color(model_icon_color)
531                                .size(IconSize::Small),
532                        )
533                        .child(
534                            h_flex()
535                                .w_full()
536                                .pl_0p5()
537                                .gap_1p5()
538                                .w(px(240.))
539                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
540                        )
541                        .end_slot(div().pr_3().when(is_selected, |this| {
542                            this.child(
543                                Icon::new(IconName::Check)
544                                    .color(Color::Accent)
545                                    .size(IconSize::Small),
546                            )
547                        }))
548                        .into_any_element(),
549                )
550            }
551        }
552    }
553
554    fn render_footer(
555        &self,
556        _: &mut Window,
557        cx: &mut Context<Picker<Self>>,
558    ) -> Option<gpui::AnyElement> {
559        use feature_flags::FeatureFlagAppExt;
560
561        let plan = proto::Plan::ZedPro;
562
563        Some(
564            h_flex()
565                .w_full()
566                .border_t_1()
567                .border_color(cx.theme().colors().border_variant)
568                .p_1()
569                .gap_4()
570                .justify_between()
571                .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
572                    this.child(match plan {
573                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
574                            .icon(IconName::ZedAssistant)
575                            .icon_size(IconSize::Small)
576                            .icon_color(Color::Muted)
577                            .icon_position(IconPosition::Start)
578                            .on_click(|_, window, cx| {
579                                window
580                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
581                            }),
582                        Plan::Free | Plan::ZedProTrial => Button::new(
583                            "try-pro",
584                            if plan == Plan::ZedProTrial {
585                                "Upgrade to Pro"
586                            } else {
587                                "Try Pro"
588                            },
589                        )
590                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
591                    })
592                })
593                .child(
594                    Button::new("configure", "Configure")
595                        .icon(IconName::Settings)
596                        .icon_size(IconSize::Small)
597                        .icon_color(Color::Muted)
598                        .icon_position(IconPosition::Start)
599                        .on_click(|_, window, cx| {
600                            let configure_action = if cx.has_flag::<Assistant2FeatureFlag>() {
601                                zed_actions::agent::OpenConfiguration.boxed_clone()
602                            } else {
603                                zed_actions::assistant::ShowConfiguration.boxed_clone()
604                            };
605
606                            window.dispatch_action(configure_action, cx);
607                        }),
608                )
609                .into_any(),
610        )
611    }
612}