language_model_selector.rs

  1use std::sync::Arc;
  2
  3use collections::{HashSet, IndexMap};
  4use feature_flags::{Assistant2FeatureFlag, ZedPro};
  5use gpui::{
  6    Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
  7    Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
  8};
  9use language_model::{
 10    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 11};
 12use picker::{Picker, PickerDelegate};
 13use proto::Plan;
 14use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
 15
// Declares the `ToggleModelSelector` action under the `assistant` namespace.
// The bracketed list names `assistant2::ToggleModelSelector` as a deprecated
// alias for the same action (see the macro's documentation for exact semantics).
action_with_deprecated_aliases!(
    assistant,
    ToggleModelSelector,
    ["assistant2::ToggleModelSelector"]
);
 21
/// URL opened when the "try-pro" button in the picker footer is clicked.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 23
/// Shared callback that the picker delegate invokes with the chosen model
/// when the user confirms a selection.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
 25
/// A view that wraps a [`Picker`] listing the language models exposed by the
/// providers in the [`LanguageModelRegistry`].
pub struct LanguageModelSelector {
    /// The picker entity; it is what gets rendered and what supplies the focus handle.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// Handle of the background task started by `authenticate_all_providers` in `new`.
    _authenticate_all_providers_task: Task<()>,
    /// Subscriptions to the registry and to the picker, held for the selector's lifetime.
    _subscriptions: Vec<Subscription>,
}
 31
 32impl LanguageModelSelector {
 33    pub fn new(
 34        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
 35        window: &mut Window,
 36        cx: &mut Context<Self>,
 37    ) -> Self {
 38        let on_model_changed = Arc::new(on_model_changed);
 39
 40        let all_models = Self::all_models(cx);
 41        let entries = all_models.entries();
 42
 43        let delegate = LanguageModelPickerDelegate {
 44            language_model_selector: cx.entity().downgrade(),
 45            on_model_changed: on_model_changed.clone(),
 46            all_models: Arc::new(all_models),
 47            selected_index: Self::get_active_model_index(&entries, cx),
 48            filtered_entries: entries,
 49        };
 50
 51        let picker = cx.new(|cx| {
 52            Picker::list(delegate, window, cx)
 53                .show_scrollbar(true)
 54                .width(rems(20.))
 55                .max_height(Some(rems(20.).into()))
 56        });
 57
 58        let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
 59
 60        LanguageModelSelector {
 61            picker,
 62            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
 63            _subscriptions: vec![
 64                cx.subscribe_in(
 65                    &LanguageModelRegistry::global(cx),
 66                    window,
 67                    Self::handle_language_model_registry_event,
 68                ),
 69                subscription,
 70            ],
 71        }
 72    }
 73
 74    fn handle_language_model_registry_event(
 75        &mut self,
 76        _registry: &Entity<LanguageModelRegistry>,
 77        event: &language_model::Event,
 78        window: &mut Window,
 79        cx: &mut Context<Self>,
 80    ) {
 81        match event {
 82            language_model::Event::ProviderStateChanged
 83            | language_model::Event::AddedProvider(_)
 84            | language_model::Event::RemovedProvider(_) => {
 85                self.picker.update(cx, |this, cx| {
 86                    let query = this.query(cx);
 87                    this.delegate.all_models = Arc::new(Self::all_models(cx));
 88                    // Update matches will automatically drop the previous task
 89                    // if we get a provider event again
 90                    this.update_matches(query, window, cx)
 91                });
 92            }
 93            _ => {}
 94        }
 95    }
 96
 97    /// Authenticates all providers in the [`LanguageModelRegistry`].
 98    ///
 99    /// We do this so that we can populate the language selector with all of the
100    /// models from the configured providers.
101    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
102        let authenticate_all_providers = LanguageModelRegistry::global(cx)
103            .read(cx)
104            .providers()
105            .iter()
106            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
107            .collect::<Vec<_>>();
108
109        cx.spawn(async move |_cx| {
110            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
111                if let Err(err) = authenticate_task.await {
112                    if matches!(err, AuthenticateError::CredentialsNotFound) {
113                        // Since we're authenticating these providers in the
114                        // background for the purposes of populating the
115                        // language selector, we don't care about providers
116                        // where the credentials are not found.
117                    } else {
118                        // Some providers have noisy failure states that we
119                        // don't want to spam the logs with every time the
120                        // language model selector is initialized.
121                        //
122                        // Ideally these should have more clear failure modes
123                        // that we know are safe to ignore here, like what we do
124                        // with `CredentialsNotFound` above.
125                        match provider_id.0.as_ref() {
126                            "lmstudio" | "ollama" => {
127                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
128                                //
129                                // These fail noisily, so we don't log them.
130                            }
131                            "copilot_chat" => {
132                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
133                            }
134                            _ => {
135                                log::error!(
136                                    "Failed to authenticate provider: {}: {err}",
137                                    provider_name.0
138                                );
139                            }
140                        }
141                    }
142                }
143            }
144        })
145    }
146
147    fn all_models(cx: &App) -> GroupedModels {
148        let mut recommended = Vec::new();
149        let mut recommended_set = HashSet::default();
150        for provider in LanguageModelRegistry::global(cx)
151            .read(cx)
152            .providers()
153            .iter()
154        {
155            let models = provider.recommended_models(cx);
156            recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
157            recommended.extend(
158                provider
159                    .recommended_models(cx)
160                    .into_iter()
161                    .map(move |model| ModelInfo {
162                        model: model.clone(),
163                        icon: provider.icon(),
164                    }),
165            );
166        }
167
168        let other_models = LanguageModelRegistry::global(cx)
169            .read(cx)
170            .providers()
171            .iter()
172            .map(|provider| {
173                (
174                    provider.id(),
175                    provider
176                        .provided_models(cx)
177                        .into_iter()
178                        .filter_map(|model| {
179                            let not_included =
180                                !recommended_set.contains(&(model.provider_id(), model.id()));
181                            not_included.then(|| ModelInfo {
182                                model: model.clone(),
183                                icon: provider.icon(),
184                            })
185                        })
186                        .collect::<Vec<_>>(),
187                )
188            })
189            .collect::<IndexMap<_, _>>();
190
191        GroupedModels {
192            recommended,
193            other: other_models,
194        }
195    }
196
197    fn get_active_model_index(entries: &[LanguageModelPickerEntry], cx: &App) -> usize {
198        let active_model = LanguageModelRegistry::read_global(cx).default_model();
199        entries
200            .iter()
201            .position(|entry| {
202                if let LanguageModelPickerEntry::Model(model) = entry {
203                    active_model
204                        .as_ref()
205                        .map(|active_model| {
206                            active_model.model.id() == model.model.id()
207                                && active_model.model.provider_id() == model.model.provider_id()
208                        })
209                        .unwrap_or_default()
210                } else {
211                    false
212                }
213            })
214            .unwrap_or(0)
215    }
216}
217
// The selector emits `DismissEvent` from its picker subscription (set up in
// `new`) and from the delegate's `dismissed` hook.
impl EventEmitter<DismissEvent> for LanguageModelSelector {}
219
impl Focusable for LanguageModelSelector {
    /// Returns the inner picker's focus handle, so focusing the selector
    /// focuses the picker.
    fn focus_handle(&self, cx: &App) -> FocusHandle {
        self.picker.focus_handle(cx)
    }
}
225
impl Render for LanguageModelSelector {
    /// Renders the inner picker entity; the selector adds no elements of its own.
    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
        self.picker.clone()
    }
}
231
/// A popover menu whose content is a [`LanguageModelSelector`], opened from a
/// caller-supplied trigger element.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T, TT>
where
    T: PopoverTrigger + ButtonCommon,
    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
    /// Selector entity returned as the menu's content.
    language_model_selector: Entity<LanguageModelSelector>,
    /// Element passed to the popover as its trigger.
    trigger: T,
    /// Callback that builds the tooltip view shown for the trigger.
    tooltip: TT,
    /// Optional externally owned popover handle, set through `with_handle`.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
    /// Corner forwarded to `PopoverMenu::anchor`.
    anchor: Corner,
}
244
impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
where
    T: PopoverTrigger + ButtonCommon,
    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
    /// Creates a popover menu for `language_model_selector` with the given
    /// trigger, tooltip builder, and anchor corner. No popover handle is
    /// attached; use [`Self::with_handle`] to add one.
    pub fn new(
        language_model_selector: Entity<LanguageModelSelector>,
        trigger: T,
        tooltip: TT,
        anchor: Corner,
    ) -> Self {
        Self {
            language_model_selector,
            trigger,
            tooltip,
            handle: None,
            anchor,
        }
    }

    /// Attaches a popover handle, replacing any handle set previously.
    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
        self.handle = Some(handle);
        self
    }
}
270
271impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
272where
273    T: PopoverTrigger + ButtonCommon,
274    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
275{
276    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
277        let language_model_selector = self.language_model_selector.clone();
278
279        PopoverMenu::new("model-switcher")
280            .menu(move |_window, _cx| Some(language_model_selector.clone()))
281            .trigger_with_tooltip(self.trigger, self.tooltip)
282            .anchor(self.anchor)
283            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
284            .offset(gpui::Point {
285                x: px(0.0),
286                y: px(-2.0),
287            })
288    }
289}
290
/// A language model paired with the icon of the provider that listed it.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// Icon shown in the leading slot of the model's row.
    icon: IconName,
}
296
/// [`PickerDelegate`] backing the [`LanguageModelSelector`]'s picker.
pub struct LanguageModelPickerDelegate {
    /// Weak reference to the owning selector, used to emit `DismissEvent` on it.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Callback invoked with the chosen model when a selection is confirmed.
    on_model_changed: OnModelChanged,
    /// Every known model, unfiltered; the source for `update_matches`.
    all_models: Arc<GroupedModels>,
    /// Rows currently shown in the picker (separators and models).
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
}
304
/// Models split into a "recommended" group and per-provider groups.
struct GroupedModels {
    /// Models listed under the "Recommended" separator.
    recommended: Vec<ModelInfo>,
    /// Remaining models keyed by provider id, in insertion order.
    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
309
310impl GroupedModels {
311    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
312        let mut entries = Vec::new();
313
314        if !self.recommended.is_empty() {
315            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
316            entries.extend(
317                self.recommended
318                    .iter()
319                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
320            );
321        }
322
323        for models in self.other.values() {
324            if models.is_empty() {
325                continue;
326            }
327            entries.push(LanguageModelPickerEntry::Separator(
328                models[0].model.provider_name().0,
329            ));
330            entries.extend(
331                models
332                    .iter()
333                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
334            );
335        }
336        entries
337    }
338}
339
/// A single row in the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable section heading carrying its title.
    Separator(SharedString),
}
344
345impl PickerDelegate for LanguageModelPickerDelegate {
    /// Rows are type-erased because separators and model rows are different element types.
    type ListItem = AnyElement;
347
    /// Number of rows currently shown, counting separators as well as models.
    fn match_count(&self) -> usize {
        self.filtered_entries.len()
    }
351
    /// Index into `filtered_entries` of the highlighted row.
    fn selected_index(&self) -> usize {
        self.selected_index
    }
355
356    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
357        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
358        cx.notify();
359    }
360
361    fn can_select(
362        &mut self,
363        ix: usize,
364        _window: &mut Window,
365        _cx: &mut Context<Picker<Self>>,
366    ) -> bool {
367        match self.filtered_entries.get(ix) {
368            Some(LanguageModelPickerEntry::Model(_)) => true,
369            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
370        }
371    }
372
    /// Placeholder text supplied to the picker for its query input.
    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
        "Select a model…".into()
    }
376
377    fn update_matches(
378        &mut self,
379        query: String,
380        window: &mut Window,
381        cx: &mut Context<Picker<Self>>,
382    ) -> Task<()> {
383        let all_models = self.all_models.clone();
384        let current_index = self.selected_index;
385
386        let language_model_registry = LanguageModelRegistry::global(cx);
387
388        let configured_providers = language_model_registry
389            .read(cx)
390            .providers()
391            .iter()
392            .filter(|provider| provider.is_authenticated(cx))
393            .map(|provider| provider.id())
394            .collect::<Vec<_>>();
395
396        cx.spawn_in(window, async move |this, cx| {
397            let filtered_models = cx
398                .background_spawn(async move {
399                    let filter_models = |model_infos: &[ModelInfo]| {
400                        model_infos
401                            .iter()
402                            .filter(|model_info| {
403                                model_info
404                                    .model
405                                    .name()
406                                    .0
407                                    .to_lowercase()
408                                    .contains(&query.to_lowercase())
409                            })
410                            .cloned()
411                            .collect::<Vec<_>>()
412                    };
413
414                    let recommended_models = filter_models(&all_models.recommended);
415                    let mut other_models = IndexMap::default();
416                    for (provider_id, models) in &all_models.other {
417                        if configured_providers.contains(&provider_id) {
418                            other_models.insert(provider_id.clone(), filter_models(models));
419                        }
420                    }
421                    GroupedModels {
422                        recommended: recommended_models,
423                        other: other_models,
424                    }
425                })
426                .await;
427
428            this.update_in(cx, |this, window, cx| {
429                this.delegate.filtered_entries = filtered_models.entries();
430                // Preserve selection focus
431                let new_index = if current_index >= this.delegate.filtered_entries.len() {
432                    0
433                } else {
434                    current_index
435                };
436                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
437                cx.notify();
438            })
439            .ok();
440        })
441    }
442
443    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
444        if let Some(LanguageModelPickerEntry::Model(model_info)) =
445            self.filtered_entries.get(self.selected_index)
446        {
447            let model = model_info.model.clone();
448            (self.on_model_changed)(model.clone(), cx);
449
450            let current_index = self.selected_index;
451            self.set_selected_index(current_index, window, cx);
452
453            cx.emit(DismissEvent);
454        }
455    }
456
    /// Forwards the dismissal to the owning selector by emitting
    /// `DismissEvent` on it. If the selector has already been dropped the
    /// weak-entity update fails and the error is intentionally discarded.
    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
        self.language_model_selector
            .update(cx, |_this, cx| cx.emit(DismissEvent))
            .ok();
    }
462
463    fn render_match(
464        &self,
465        ix: usize,
466        selected: bool,
467        _: &mut Window,
468        cx: &mut Context<Picker<Self>>,
469    ) -> Option<Self::ListItem> {
470        match self.filtered_entries.get(ix)? {
471            LanguageModelPickerEntry::Separator(title) => Some(
472                div()
473                    .px_2()
474                    .pb_1()
475                    .when(ix > 1, |this| {
476                        this.mt_1()
477                            .pt_2()
478                            .border_t_1()
479                            .border_color(cx.theme().colors().border_variant)
480                    })
481                    .child(
482                        Label::new(title)
483                            .size(LabelSize::XSmall)
484                            .color(Color::Muted),
485                    )
486                    .into_any_element(),
487            ),
488            LanguageModelPickerEntry::Model(model_info) => {
489                let active_model = LanguageModelRegistry::read_global(cx).default_model();
490
491                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
492                let active_model_id = active_model.map(|m| m.model.id());
493
494                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
495                    && Some(model_info.model.id()) == active_model_id;
496
497                let model_icon_color = if is_selected {
498                    Color::Accent
499                } else {
500                    Color::Muted
501                };
502
503                Some(
504                    ListItem::new(ix)
505                        .inset(true)
506                        .spacing(ListItemSpacing::Sparse)
507                        .toggle_state(selected)
508                        .start_slot(
509                            Icon::new(model_info.icon)
510                                .color(model_icon_color)
511                                .size(IconSize::Small),
512                        )
513                        .child(
514                            h_flex()
515                                .w_full()
516                                .pl_0p5()
517                                .gap_1p5()
518                                .w(px(240.))
519                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
520                        )
521                        .end_slot(div().pr_3().when(is_selected, |this| {
522                            this.child(
523                                Icon::new(IconName::Check)
524                                    .color(Color::Accent)
525                                    .size(IconSize::Small),
526                            )
527                        }))
528                        .into_any_element(),
529                )
530            }
531        }
532    }
533
534    fn render_footer(
535        &self,
536        _: &mut Window,
537        cx: &mut Context<Picker<Self>>,
538    ) -> Option<gpui::AnyElement> {
539        use feature_flags::FeatureFlagAppExt;
540
541        let plan = proto::Plan::ZedPro;
542        let is_trial = false;
543
544        Some(
545            h_flex()
546                .w_full()
547                .border_t_1()
548                .border_color(cx.theme().colors().border_variant)
549                .p_1()
550                .gap_4()
551                .justify_between()
552                .when(cx.has_flag::<ZedPro>(), |this| {
553                    this.child(match plan {
554                        // Already a Zed Pro subscriber
555                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
556                            .icon(IconName::ZedAssistant)
557                            .icon_size(IconSize::Small)
558                            .icon_color(Color::Muted)
559                            .icon_position(IconPosition::Start)
560                            .on_click(|_, window, cx| {
561                                window
562                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
563                            }),
564                        // Free user
565                        Plan::Free => Button::new(
566                            "try-pro",
567                            if is_trial {
568                                "Upgrade to Pro"
569                            } else {
570                                "Try Pro"
571                            },
572                        )
573                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
574                    })
575                })
576                .child(
577                    Button::new("configure", "Configure")
578                        .icon(IconName::Settings)
579                        .icon_size(IconSize::Small)
580                        .icon_color(Color::Muted)
581                        .icon_position(IconPosition::Start)
582                        .on_click(|_, window, cx| {
583                            let configure_action = if cx.has_flag::<Assistant2FeatureFlag>() {
584                                zed_actions::agent::OpenConfiguration.boxed_clone()
585                            } else {
586                                zed_actions::assistant::ShowConfiguration.boxed_clone()
587                            };
588
589                            window.dispatch_action(configure_action, cx);
590                        }),
591                )
592                .into_any(),
593        )
594    }
595}