language_model_selector.rs

  1use std::sync::Arc;
  2
  3use collections::{HashSet, IndexMap};
  4use feature_flags::{Assistant2FeatureFlag, ZedProFeatureFlag};
  5use gpui::{
  6    Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
  7    Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
  8};
  9use language_model::{
 10    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
 11    LanguageModelRegistry,
 12};
 13use picker::{Picker, PickerDelegate};
 14use proto::Plan;
 15use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
 16
 17action_with_deprecated_aliases!(
 18    assistant,
 19    ToggleModelSelector,
 20    ["assistant2::ToggleModelSelector"]
 21);
 22
 23const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 24
 25type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
 26
 27pub struct LanguageModelSelector {
 28    picker: Entity<Picker<LanguageModelPickerDelegate>>,
 29    _authenticate_all_providers_task: Task<()>,
 30    _subscriptions: Vec<Subscription>,
 31}
 32
 33#[derive(Clone, Copy)]
 34pub enum ModelType {
 35    Default,
 36    InlineAssistant,
 37}
 38
 39impl LanguageModelSelector {
 40    pub fn new(
 41        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
 42        model_type: ModelType,
 43        window: &mut Window,
 44        cx: &mut Context<Self>,
 45    ) -> Self {
 46        let on_model_changed = Arc::new(on_model_changed);
 47
 48        let all_models = Self::all_models(cx);
 49        let entries = all_models.entries();
 50
 51        let delegate = LanguageModelPickerDelegate {
 52            language_model_selector: cx.entity().downgrade(),
 53            on_model_changed: on_model_changed.clone(),
 54            all_models: Arc::new(all_models),
 55            selected_index: Self::get_active_model_index(&entries, model_type, cx),
 56            filtered_entries: entries,
 57            model_type,
 58        };
 59
 60        let picker = cx.new(|cx| {
 61            Picker::list(delegate, window, cx)
 62                .show_scrollbar(true)
 63                .width(rems(20.))
 64                .max_height(Some(rems(20.).into()))
 65        });
 66
 67        let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
 68
 69        LanguageModelSelector {
 70            picker,
 71            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
 72            _subscriptions: vec![
 73                cx.subscribe_in(
 74                    &LanguageModelRegistry::global(cx),
 75                    window,
 76                    Self::handle_language_model_registry_event,
 77                ),
 78                subscription,
 79            ],
 80        }
 81    }
 82
 83    fn handle_language_model_registry_event(
 84        &mut self,
 85        _registry: &Entity<LanguageModelRegistry>,
 86        event: &language_model::Event,
 87        window: &mut Window,
 88        cx: &mut Context<Self>,
 89    ) {
 90        match event {
 91            language_model::Event::ProviderStateChanged
 92            | language_model::Event::AddedProvider(_)
 93            | language_model::Event::RemovedProvider(_) => {
 94                self.picker.update(cx, |this, cx| {
 95                    let query = this.query(cx);
 96                    this.delegate.all_models = Arc::new(Self::all_models(cx));
 97                    // Update matches will automatically drop the previous task
 98                    // if we get a provider event again
 99                    this.update_matches(query, window, cx)
100                });
101            }
102            _ => {}
103        }
104    }
105
106    /// Authenticates all providers in the [`LanguageModelRegistry`].
107    ///
108    /// We do this so that we can populate the language selector with all of the
109    /// models from the configured providers.
110    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
111        let authenticate_all_providers = LanguageModelRegistry::global(cx)
112            .read(cx)
113            .providers()
114            .iter()
115            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
116            .collect::<Vec<_>>();
117
118        cx.spawn(async move |_cx| {
119            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
120                if let Err(err) = authenticate_task.await {
121                    if matches!(err, AuthenticateError::CredentialsNotFound) {
122                        // Since we're authenticating these providers in the
123                        // background for the purposes of populating the
124                        // language selector, we don't care about providers
125                        // where the credentials are not found.
126                    } else {
127                        // Some providers have noisy failure states that we
128                        // don't want to spam the logs with every time the
129                        // language model selector is initialized.
130                        //
131                        // Ideally these should have more clear failure modes
132                        // that we know are safe to ignore here, like what we do
133                        // with `CredentialsNotFound` above.
134                        match provider_id.0.as_ref() {
135                            "lmstudio" | "ollama" => {
136                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
137                                //
138                                // These fail noisily, so we don't log them.
139                            }
140                            "copilot_chat" => {
141                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
142                            }
143                            _ => {
144                                log::error!(
145                                    "Failed to authenticate provider: {}: {err}",
146                                    provider_name.0
147                                );
148                            }
149                        }
150                    }
151                }
152            }
153        })
154    }
155
156    fn all_models(cx: &App) -> GroupedModels {
157        let mut recommended = Vec::new();
158        let mut recommended_set = HashSet::default();
159        for provider in LanguageModelRegistry::global(cx)
160            .read(cx)
161            .providers()
162            .iter()
163        {
164            let models = provider.recommended_models(cx);
165            recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
166            recommended.extend(
167                provider
168                    .recommended_models(cx)
169                    .into_iter()
170                    .map(move |model| ModelInfo {
171                        model: model.clone(),
172                        icon: provider.icon(),
173                    }),
174            );
175        }
176
177        let other_models = LanguageModelRegistry::global(cx)
178            .read(cx)
179            .providers()
180            .iter()
181            .map(|provider| {
182                (
183                    provider.id(),
184                    provider
185                        .provided_models(cx)
186                        .into_iter()
187                        .filter_map(|model| {
188                            let not_included =
189                                !recommended_set.contains(&(model.provider_id(), model.id()));
190                            not_included.then(|| ModelInfo {
191                                model: model.clone(),
192                                icon: provider.icon(),
193                            })
194                        })
195                        .collect::<Vec<_>>(),
196                )
197            })
198            .collect::<IndexMap<_, _>>();
199
200        GroupedModels {
201            recommended,
202            other: other_models,
203        }
204    }
205
206    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
207        let model_type = self.picker.read(cx).delegate.model_type;
208        Self::active_model_by_type(model_type, cx)
209    }
210
211    fn active_model_by_type(model_type: ModelType, cx: &App) -> Option<ConfiguredModel> {
212        match model_type {
213            ModelType::Default => LanguageModelRegistry::read_global(cx).default_model(),
214            ModelType::InlineAssistant => {
215                LanguageModelRegistry::read_global(cx).inline_assistant_model()
216            }
217        }
218    }
219
220    fn get_active_model_index(
221        entries: &[LanguageModelPickerEntry],
222        model_type: ModelType,
223        cx: &App,
224    ) -> usize {
225        let active_model = Self::active_model_by_type(model_type, cx);
226
227        entries
228            .iter()
229            .position(|entry| {
230                if let LanguageModelPickerEntry::Model(model) = entry {
231                    active_model
232                        .as_ref()
233                        .map(|active_model| {
234                            active_model.model.id() == model.model.id()
235                                && active_model.model.provider_id() == model.model.provider_id()
236                        })
237                        .unwrap_or_default()
238                } else {
239                    false
240                }
241            })
242            .unwrap_or(0)
243    }
244}
245
246impl EventEmitter<DismissEvent> for LanguageModelSelector {}
247
248impl Focusable for LanguageModelSelector {
249    fn focus_handle(&self, cx: &App) -> FocusHandle {
250        self.picker.focus_handle(cx)
251    }
252}
253
254impl Render for LanguageModelSelector {
255    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
256        self.picker.clone()
257    }
258}
259
260#[derive(IntoElement)]
261pub struct LanguageModelSelectorPopoverMenu<T, TT>
262where
263    T: PopoverTrigger + ButtonCommon,
264    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
265{
266    language_model_selector: Entity<LanguageModelSelector>,
267    trigger: T,
268    tooltip: TT,
269    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
270    anchor: Corner,
271}
272
273impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
274where
275    T: PopoverTrigger + ButtonCommon,
276    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
277{
278    pub fn new(
279        language_model_selector: Entity<LanguageModelSelector>,
280        trigger: T,
281        tooltip: TT,
282        anchor: Corner,
283    ) -> Self {
284        Self {
285            language_model_selector,
286            trigger,
287            tooltip,
288            handle: None,
289            anchor,
290        }
291    }
292
293    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
294        self.handle = Some(handle);
295        self
296    }
297}
298
299impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
300where
301    T: PopoverTrigger + ButtonCommon,
302    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
303{
304    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
305        let language_model_selector = self.language_model_selector.clone();
306
307        PopoverMenu::new("model-switcher")
308            .menu(move |_window, _cx| Some(language_model_selector.clone()))
309            .trigger_with_tooltip(self.trigger, self.tooltip)
310            .anchor(self.anchor)
311            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
312            .offset(gpui::Point {
313                x: px(0.0),
314                y: px(-2.0),
315            })
316    }
317}
318
319#[derive(Clone)]
320struct ModelInfo {
321    model: Arc<dyn LanguageModel>,
322    icon: IconName,
323}
324
325pub struct LanguageModelPickerDelegate {
326    language_model_selector: WeakEntity<LanguageModelSelector>,
327    on_model_changed: OnModelChanged,
328    all_models: Arc<GroupedModels>,
329    filtered_entries: Vec<LanguageModelPickerEntry>,
330    selected_index: usize,
331    model_type: ModelType,
332}
333
334struct GroupedModels {
335    recommended: Vec<ModelInfo>,
336    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
337}
338
339impl GroupedModels {
340    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
341        let mut entries = Vec::new();
342
343        if !self.recommended.is_empty() {
344            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
345            entries.extend(
346                self.recommended
347                    .iter()
348                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
349            );
350        }
351
352        for models in self.other.values() {
353            if models.is_empty() {
354                continue;
355            }
356            entries.push(LanguageModelPickerEntry::Separator(
357                models[0].model.provider_name().0,
358            ));
359            entries.extend(
360                models
361                    .iter()
362                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
363            );
364        }
365        entries
366    }
367}
368
369enum LanguageModelPickerEntry {
370    Model(ModelInfo),
371    Separator(SharedString),
372}
373
374impl PickerDelegate for LanguageModelPickerDelegate {
375    type ListItem = AnyElement;
376
377    fn match_count(&self) -> usize {
378        self.filtered_entries.len()
379    }
380
381    fn selected_index(&self) -> usize {
382        self.selected_index
383    }
384
385    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
386        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
387        cx.notify();
388    }
389
390    fn can_select(
391        &mut self,
392        ix: usize,
393        _window: &mut Window,
394        _cx: &mut Context<Picker<Self>>,
395    ) -> bool {
396        match self.filtered_entries.get(ix) {
397            Some(LanguageModelPickerEntry::Model(_)) => true,
398            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
399        }
400    }
401
402    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
403        "Select a model…".into()
404    }
405
406    fn update_matches(
407        &mut self,
408        query: String,
409        window: &mut Window,
410        cx: &mut Context<Picker<Self>>,
411    ) -> Task<()> {
412        let all_models = self.all_models.clone();
413        let current_index = self.selected_index;
414
415        let language_model_registry = LanguageModelRegistry::global(cx);
416
417        let configured_providers = language_model_registry
418            .read(cx)
419            .providers()
420            .iter()
421            .filter(|provider| provider.is_authenticated(cx))
422            .map(|provider| provider.id())
423            .collect::<Vec<_>>();
424
425        cx.spawn_in(window, async move |this, cx| {
426            let filtered_models = cx
427                .background_spawn(async move {
428                    let matches = |info: &ModelInfo| {
429                        info.model
430                            .name()
431                            .0
432                            .to_lowercase()
433                            .contains(&query.to_lowercase())
434                    };
435
436                    let recommended_models = all_models
437                        .recommended
438                        .iter()
439                        .filter(|r| {
440                            configured_providers.contains(&r.model.provider_id()) && matches(r)
441                        })
442                        .cloned()
443                        .collect();
444                    let mut other_models = IndexMap::default();
445                    for (provider_id, models) in &all_models.other {
446                        if configured_providers.contains(&provider_id) {
447                            other_models.insert(
448                                provider_id.clone(),
449                                models
450                                    .iter()
451                                    .filter(|m| matches(m))
452                                    .cloned()
453                                    .collect::<Vec<_>>(),
454                            );
455                        }
456                    }
457                    GroupedModels {
458                        recommended: recommended_models,
459                        other: other_models,
460                    }
461                })
462                .await;
463
464            this.update_in(cx, |this, window, cx| {
465                this.delegate.filtered_entries = filtered_models.entries();
466                // Preserve selection focus
467                let new_index = if current_index >= this.delegate.filtered_entries.len() {
468                    0
469                } else {
470                    current_index
471                };
472                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
473                cx.notify();
474            })
475            .ok();
476        })
477    }
478
479    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
480        if let Some(LanguageModelPickerEntry::Model(model_info)) =
481            self.filtered_entries.get(self.selected_index)
482        {
483            let model = model_info.model.clone();
484            (self.on_model_changed)(model.clone(), cx);
485
486            let current_index = self.selected_index;
487            self.set_selected_index(current_index, window, cx);
488
489            cx.emit(DismissEvent);
490        }
491    }
492
493    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
494        self.language_model_selector
495            .update(cx, |_this, cx| cx.emit(DismissEvent))
496            .ok();
497    }
498
499    fn render_match(
500        &self,
501        ix: usize,
502        selected: bool,
503        _: &mut Window,
504        cx: &mut Context<Picker<Self>>,
505    ) -> Option<Self::ListItem> {
506        match self.filtered_entries.get(ix)? {
507            LanguageModelPickerEntry::Separator(title) => Some(
508                div()
509                    .px_2()
510                    .pb_1()
511                    .when(ix > 1, |this| {
512                        this.mt_1()
513                            .pt_2()
514                            .border_t_1()
515                            .border_color(cx.theme().colors().border_variant)
516                    })
517                    .child(
518                        Label::new(title)
519                            .size(LabelSize::XSmall)
520                            .color(Color::Muted),
521                    )
522                    .into_any_element(),
523            ),
524            LanguageModelPickerEntry::Model(model_info) => {
525                let active_model = LanguageModelSelector::active_model_by_type(self.model_type, cx);
526
527                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
528                let active_model_id = active_model.map(|m| m.model.id());
529
530                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
531                    && Some(model_info.model.id()) == active_model_id;
532
533                let model_icon_color = if is_selected {
534                    Color::Accent
535                } else {
536                    Color::Muted
537                };
538
539                Some(
540                    ListItem::new(ix)
541                        .inset(true)
542                        .spacing(ListItemSpacing::Sparse)
543                        .toggle_state(selected)
544                        .start_slot(
545                            Icon::new(model_info.icon)
546                                .color(model_icon_color)
547                                .size(IconSize::Small),
548                        )
549                        .child(
550                            h_flex()
551                                .w_full()
552                                .pl_0p5()
553                                .gap_1p5()
554                                .w(px(240.))
555                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
556                        )
557                        .end_slot(div().pr_3().when(is_selected, |this| {
558                            this.child(
559                                Icon::new(IconName::Check)
560                                    .color(Color::Accent)
561                                    .size(IconSize::Small),
562                            )
563                        }))
564                        .into_any_element(),
565                )
566            }
567        }
568    }
569
570    fn render_footer(
571        &self,
572        _: &mut Window,
573        cx: &mut Context<Picker<Self>>,
574    ) -> Option<gpui::AnyElement> {
575        use feature_flags::FeatureFlagAppExt;
576
577        let plan = proto::Plan::ZedPro;
578
579        Some(
580            h_flex()
581                .w_full()
582                .border_t_1()
583                .border_color(cx.theme().colors().border_variant)
584                .p_1()
585                .gap_4()
586                .justify_between()
587                .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
588                    this.child(match plan {
589                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
590                            .icon(IconName::ZedAssistant)
591                            .icon_size(IconSize::Small)
592                            .icon_color(Color::Muted)
593                            .icon_position(IconPosition::Start)
594                            .on_click(|_, window, cx| {
595                                window
596                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
597                            }),
598                        Plan::Free | Plan::ZedProTrial => Button::new(
599                            "try-pro",
600                            if plan == Plan::ZedProTrial {
601                                "Upgrade to Pro"
602                            } else {
603                                "Try Pro"
604                            },
605                        )
606                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
607                    })
608                })
609                .child(
610                    Button::new("configure", "Configure")
611                        .icon(IconName::Settings)
612                        .icon_size(IconSize::Small)
613                        .icon_color(Color::Muted)
614                        .icon_position(IconPosition::Start)
615                        .on_click(|_, window, cx| {
616                            let configure_action = if cx.has_flag::<Assistant2FeatureFlag>() {
617                                zed_actions::agent::OpenConfiguration.boxed_clone()
618                            } else {
619                                zed_actions::assistant::ShowConfiguration.boxed_clone()
620                            };
621
622                            window.dispatch_action(configure_action, cx);
623                        }),
624                )
625                .into_any(),
626        )
627    }
628}