language_model_selector.rs

  1use std::sync::Arc;
  2
  3use collections::{HashSet, IndexMap};
  4use feature_flags::{Assistant2FeatureFlag, ZedProFeatureFlag};
  5use gpui::{
  6    Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
  7    Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
  8};
  9use language_model::{
 10    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
 11    LanguageModelRegistry,
 12};
 13use picker::{Picker, PickerDelegate};
 14use proto::Plan;
 15use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
 16
 17action_with_deprecated_aliases!(
 18    assistant,
 19    ToggleModelSelector,
 20    ["assistant2::ToggleModelSelector"]
 21);
 22
 23const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 24
 25type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
 26type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
 27
 28pub struct LanguageModelSelector {
 29    picker: Entity<Picker<LanguageModelPickerDelegate>>,
 30    _authenticate_all_providers_task: Task<()>,
 31    _subscriptions: Vec<Subscription>,
 32}
 33
 34impl LanguageModelSelector {
 35    pub fn new(
 36        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 37        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 38        window: &mut Window,
 39        cx: &mut Context<Self>,
 40    ) -> Self {
 41        let on_model_changed = Arc::new(on_model_changed);
 42
 43        let all_models = Self::all_models(cx);
 44        let entries = all_models.entries();
 45
 46        let delegate = LanguageModelPickerDelegate {
 47            language_model_selector: cx.entity().downgrade(),
 48            on_model_changed: on_model_changed.clone(),
 49            all_models: Arc::new(all_models),
 50            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
 51            filtered_entries: entries,
 52            get_active_model: Arc::new(get_active_model),
 53        };
 54
 55        let picker = cx.new(|cx| {
 56            Picker::list(delegate, window, cx)
 57                .show_scrollbar(true)
 58                .width(rems(20.))
 59                .max_height(Some(rems(20.).into()))
 60        });
 61
 62        let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
 63
 64        LanguageModelSelector {
 65            picker,
 66            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
 67            _subscriptions: vec![
 68                cx.subscribe_in(
 69                    &LanguageModelRegistry::global(cx),
 70                    window,
 71                    Self::handle_language_model_registry_event,
 72                ),
 73                subscription,
 74            ],
 75        }
 76    }
 77
 78    fn handle_language_model_registry_event(
 79        &mut self,
 80        _registry: &Entity<LanguageModelRegistry>,
 81        event: &language_model::Event,
 82        window: &mut Window,
 83        cx: &mut Context<Self>,
 84    ) {
 85        match event {
 86            language_model::Event::ProviderStateChanged
 87            | language_model::Event::AddedProvider(_)
 88            | language_model::Event::RemovedProvider(_) => {
 89                self.picker.update(cx, |this, cx| {
 90                    let query = this.query(cx);
 91                    this.delegate.all_models = Arc::new(Self::all_models(cx));
 92                    // Update matches will automatically drop the previous task
 93                    // if we get a provider event again
 94                    this.update_matches(query, window, cx)
 95                });
 96            }
 97            _ => {}
 98        }
 99    }
100
101    /// Authenticates all providers in the [`LanguageModelRegistry`].
102    ///
103    /// We do this so that we can populate the language selector with all of the
104    /// models from the configured providers.
105    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
106        let authenticate_all_providers = LanguageModelRegistry::global(cx)
107            .read(cx)
108            .providers()
109            .iter()
110            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
111            .collect::<Vec<_>>();
112
113        cx.spawn(async move |_cx| {
114            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
115                if let Err(err) = authenticate_task.await {
116                    if matches!(err, AuthenticateError::CredentialsNotFound) {
117                        // Since we're authenticating these providers in the
118                        // background for the purposes of populating the
119                        // language selector, we don't care about providers
120                        // where the credentials are not found.
121                    } else {
122                        // Some providers have noisy failure states that we
123                        // don't want to spam the logs with every time the
124                        // language model selector is initialized.
125                        //
126                        // Ideally these should have more clear failure modes
127                        // that we know are safe to ignore here, like what we do
128                        // with `CredentialsNotFound` above.
129                        match provider_id.0.as_ref() {
130                            "lmstudio" | "ollama" => {
131                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
132                                //
133                                // These fail noisily, so we don't log them.
134                            }
135                            "copilot_chat" => {
136                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
137                            }
138                            _ => {
139                                log::error!(
140                                    "Failed to authenticate provider: {}: {err}",
141                                    provider_name.0
142                                );
143                            }
144                        }
145                    }
146                }
147            }
148        })
149    }
150
151    fn all_models(cx: &App) -> GroupedModels {
152        let mut recommended = Vec::new();
153        let mut recommended_set = HashSet::default();
154        for provider in LanguageModelRegistry::global(cx)
155            .read(cx)
156            .providers()
157            .iter()
158        {
159            let models = provider.recommended_models(cx);
160            recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
161            recommended.extend(
162                provider
163                    .recommended_models(cx)
164                    .into_iter()
165                    .map(move |model| ModelInfo {
166                        model: model.clone(),
167                        icon: provider.icon(),
168                    }),
169            );
170        }
171
172        let other_models = LanguageModelRegistry::global(cx)
173            .read(cx)
174            .providers()
175            .iter()
176            .map(|provider| {
177                (
178                    provider.id(),
179                    provider
180                        .provided_models(cx)
181                        .into_iter()
182                        .filter_map(|model| {
183                            let not_included =
184                                !recommended_set.contains(&(model.provider_id(), model.id()));
185                            not_included.then(|| ModelInfo {
186                                model: model.clone(),
187                                icon: provider.icon(),
188                            })
189                        })
190                        .collect::<Vec<_>>(),
191                )
192            })
193            .collect::<IndexMap<_, _>>();
194
195        GroupedModels {
196            recommended,
197            other: other_models,
198        }
199    }
200
201    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
202        (self.picker.read(cx).delegate.get_active_model)(cx)
203    }
204
205    fn get_active_model_index(
206        entries: &[LanguageModelPickerEntry],
207        active_model: Option<ConfiguredModel>,
208    ) -> usize {
209        entries
210            .iter()
211            .position(|entry| {
212                if let LanguageModelPickerEntry::Model(model) = entry {
213                    active_model
214                        .as_ref()
215                        .map(|active_model| {
216                            active_model.model.id() == model.model.id()
217                                && active_model.provider.id() == model.model.provider_id()
218                        })
219                        .unwrap_or_default()
220                } else {
221                    false
222                }
223            })
224            .unwrap_or(0)
225    }
226}
227
228impl EventEmitter<DismissEvent> for LanguageModelSelector {}
229
230impl Focusable for LanguageModelSelector {
231    fn focus_handle(&self, cx: &App) -> FocusHandle {
232        self.picker.focus_handle(cx)
233    }
234}
235
236impl Render for LanguageModelSelector {
237    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
238        self.picker.clone()
239    }
240}
241
242#[derive(IntoElement)]
243pub struct LanguageModelSelectorPopoverMenu<T, TT>
244where
245    T: PopoverTrigger + ButtonCommon,
246    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
247{
248    language_model_selector: Entity<LanguageModelSelector>,
249    trigger: T,
250    tooltip: TT,
251    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
252    anchor: Corner,
253}
254
255impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
256where
257    T: PopoverTrigger + ButtonCommon,
258    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
259{
260    pub fn new(
261        language_model_selector: Entity<LanguageModelSelector>,
262        trigger: T,
263        tooltip: TT,
264        anchor: Corner,
265    ) -> Self {
266        Self {
267            language_model_selector,
268            trigger,
269            tooltip,
270            handle: None,
271            anchor,
272        }
273    }
274
275    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
276        self.handle = Some(handle);
277        self
278    }
279}
280
281impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
282where
283    T: PopoverTrigger + ButtonCommon,
284    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
285{
286    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
287        let language_model_selector = self.language_model_selector.clone();
288
289        PopoverMenu::new("model-switcher")
290            .menu(move |_window, _cx| Some(language_model_selector.clone()))
291            .trigger_with_tooltip(self.trigger, self.tooltip)
292            .anchor(self.anchor)
293            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
294            .offset(gpui::Point {
295                x: px(0.0),
296                y: px(-2.0),
297            })
298    }
299}
300
301#[derive(Clone)]
302struct ModelInfo {
303    model: Arc<dyn LanguageModel>,
304    icon: IconName,
305}
306
307pub struct LanguageModelPickerDelegate {
308    language_model_selector: WeakEntity<LanguageModelSelector>,
309    on_model_changed: OnModelChanged,
310    get_active_model: GetActiveModel,
311    all_models: Arc<GroupedModels>,
312    filtered_entries: Vec<LanguageModelPickerEntry>,
313    selected_index: usize,
314}
315
316struct GroupedModels {
317    recommended: Vec<ModelInfo>,
318    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
319}
320
321impl GroupedModels {
322    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
323        let mut entries = Vec::new();
324
325        if !self.recommended.is_empty() {
326            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
327            entries.extend(
328                self.recommended
329                    .iter()
330                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
331            );
332        }
333
334        for models in self.other.values() {
335            if models.is_empty() {
336                continue;
337            }
338            entries.push(LanguageModelPickerEntry::Separator(
339                models[0].model.provider_name().0,
340            ));
341            entries.extend(
342                models
343                    .iter()
344                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
345            );
346        }
347        entries
348    }
349}
350
351enum LanguageModelPickerEntry {
352    Model(ModelInfo),
353    Separator(SharedString),
354}
355
356impl PickerDelegate for LanguageModelPickerDelegate {
357    type ListItem = AnyElement;
358
359    fn match_count(&self) -> usize {
360        self.filtered_entries.len()
361    }
362
363    fn selected_index(&self) -> usize {
364        self.selected_index
365    }
366
367    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
368        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
369        cx.notify();
370    }
371
372    fn can_select(
373        &mut self,
374        ix: usize,
375        _window: &mut Window,
376        _cx: &mut Context<Picker<Self>>,
377    ) -> bool {
378        match self.filtered_entries.get(ix) {
379            Some(LanguageModelPickerEntry::Model(_)) => true,
380            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
381        }
382    }
383
384    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
385        "Select a model…".into()
386    }
387
388    fn update_matches(
389        &mut self,
390        query: String,
391        window: &mut Window,
392        cx: &mut Context<Picker<Self>>,
393    ) -> Task<()> {
394        let all_models = self.all_models.clone();
395        let current_index = self.selected_index;
396
397        let language_model_registry = LanguageModelRegistry::global(cx);
398
399        let configured_providers = language_model_registry
400            .read(cx)
401            .providers()
402            .iter()
403            .filter(|provider| provider.is_authenticated(cx))
404            .map(|provider| provider.id())
405            .collect::<Vec<_>>();
406
407        cx.spawn_in(window, async move |this, cx| {
408            let filtered_models = cx
409                .background_spawn(async move {
410                    let matches = |info: &ModelInfo| {
411                        info.model
412                            .name()
413                            .0
414                            .to_lowercase()
415                            .contains(&query.to_lowercase())
416                    };
417
418                    let recommended_models = all_models
419                        .recommended
420                        .iter()
421                        .filter(|r| {
422                            configured_providers.contains(&r.model.provider_id()) && matches(r)
423                        })
424                        .cloned()
425                        .collect();
426                    let mut other_models = IndexMap::default();
427                    for (provider_id, models) in &all_models.other {
428                        if configured_providers.contains(&provider_id) {
429                            other_models.insert(
430                                provider_id.clone(),
431                                models
432                                    .iter()
433                                    .filter(|m| matches(m))
434                                    .cloned()
435                                    .collect::<Vec<_>>(),
436                            );
437                        }
438                    }
439                    GroupedModels {
440                        recommended: recommended_models,
441                        other: other_models,
442                    }
443                })
444                .await;
445
446            this.update_in(cx, |this, window, cx| {
447                this.delegate.filtered_entries = filtered_models.entries();
448                // Preserve selection focus
449                let new_index = if current_index >= this.delegate.filtered_entries.len() {
450                    0
451                } else {
452                    current_index
453                };
454                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
455                cx.notify();
456            })
457            .ok();
458        })
459    }
460
461    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
462        if let Some(LanguageModelPickerEntry::Model(model_info)) =
463            self.filtered_entries.get(self.selected_index)
464        {
465            let model = model_info.model.clone();
466            (self.on_model_changed)(model.clone(), cx);
467
468            let current_index = self.selected_index;
469            self.set_selected_index(current_index, window, cx);
470
471            cx.emit(DismissEvent);
472        }
473    }
474
475    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
476        self.language_model_selector
477            .update(cx, |_this, cx| cx.emit(DismissEvent))
478            .ok();
479    }
480
481    fn render_match(
482        &self,
483        ix: usize,
484        selected: bool,
485        _: &mut Window,
486        cx: &mut Context<Picker<Self>>,
487    ) -> Option<Self::ListItem> {
488        match self.filtered_entries.get(ix)? {
489            LanguageModelPickerEntry::Separator(title) => Some(
490                div()
491                    .px_2()
492                    .pb_1()
493                    .when(ix > 1, |this| {
494                        this.mt_1()
495                            .pt_2()
496                            .border_t_1()
497                            .border_color(cx.theme().colors().border_variant)
498                    })
499                    .child(
500                        Label::new(title)
501                            .size(LabelSize::XSmall)
502                            .color(Color::Muted),
503                    )
504                    .into_any_element(),
505            ),
506            LanguageModelPickerEntry::Model(model_info) => {
507                let active_model = (self.get_active_model)(cx);
508                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
509                let active_model_id = active_model.map(|m| m.model.id());
510
511                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
512                    && Some(model_info.model.id()) == active_model_id;
513
514                let model_icon_color = if is_selected {
515                    Color::Accent
516                } else {
517                    Color::Muted
518                };
519
520                Some(
521                    ListItem::new(ix)
522                        .inset(true)
523                        .spacing(ListItemSpacing::Sparse)
524                        .toggle_state(selected)
525                        .start_slot(
526                            Icon::new(model_info.icon)
527                                .color(model_icon_color)
528                                .size(IconSize::Small),
529                        )
530                        .child(
531                            h_flex()
532                                .w_full()
533                                .pl_0p5()
534                                .gap_1p5()
535                                .w(px(240.))
536                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
537                        )
538                        .end_slot(div().pr_3().when(is_selected, |this| {
539                            this.child(
540                                Icon::new(IconName::Check)
541                                    .color(Color::Accent)
542                                    .size(IconSize::Small),
543                            )
544                        }))
545                        .into_any_element(),
546                )
547            }
548        }
549    }
550
551    fn render_footer(
552        &self,
553        _: &mut Window,
554        cx: &mut Context<Picker<Self>>,
555    ) -> Option<gpui::AnyElement> {
556        use feature_flags::FeatureFlagAppExt;
557
558        let plan = proto::Plan::ZedPro;
559
560        Some(
561            h_flex()
562                .w_full()
563                .border_t_1()
564                .border_color(cx.theme().colors().border_variant)
565                .p_1()
566                .gap_4()
567                .justify_between()
568                .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
569                    this.child(match plan {
570                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
571                            .icon(IconName::ZedAssistant)
572                            .icon_size(IconSize::Small)
573                            .icon_color(Color::Muted)
574                            .icon_position(IconPosition::Start)
575                            .on_click(|_, window, cx| {
576                                window
577                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
578                            }),
579                        Plan::Free | Plan::ZedProTrial => Button::new(
580                            "try-pro",
581                            if plan == Plan::ZedProTrial {
582                                "Upgrade to Pro"
583                            } else {
584                                "Try Pro"
585                            },
586                        )
587                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
588                    })
589                })
590                .child(
591                    Button::new("configure", "Configure")
592                        .icon(IconName::Settings)
593                        .icon_size(IconSize::Small)
594                        .icon_color(Color::Muted)
595                        .icon_position(IconPosition::Start)
596                        .on_click(|_, window, cx| {
597                            let configure_action = if cx.has_flag::<Assistant2FeatureFlag>() {
598                                zed_actions::agent::OpenConfiguration.boxed_clone()
599                            } else {
600                                zed_actions::assistant::ShowConfiguration.boxed_clone()
601                            };
602
603                            window.dispatch_action(configure_action, cx);
604                        }),
605                )
606                .into_any(),
607        )
608    }
609}