language_model_selector.rs

  1use std::sync::Arc;
  2
  3use collections::{HashSet, IndexMap};
  4use feature_flags::{Assistant2FeatureFlag, ZedPro};
  5use gpui::{
  6    Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
  7    Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
  8};
  9use language_model::{
 10    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 11};
 12use picker::{Picker, PickerDelegate};
 13use proto::Plan;
 14use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
 15
// Defines the `ToggleModelSelector` action in the `assistant` namespace, with
// `assistant2::ToggleModelSelector` listed as a deprecated alias for it.
action_with_deprecated_aliases!(
    assistant,
    ToggleModelSelector,
    ["assistant2::ToggleModelSelector"]
);

/// Page opened by the picker footer's "Try Pro" / "Upgrade to Pro" button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Callback invoked with the model the user confirmed in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
 25
/// A picker-backed view for choosing a language model.
///
/// The list is rebuilt whenever the [`LanguageModelRegistry`] reports a provider
/// being added, removed, or changing state. Confirming an entry invokes the
/// `on_model_changed` callback given to [`LanguageModelSelector::new`].
pub struct LanguageModelSelector {
    /// The picker that renders and filters the model list.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// Background task authenticating every provider; stored so it lives as
    /// long as the selector.
    _authenticate_all_providers_task: Task<()>,
    /// Registry and picker subscriptions; stored so they stay registered.
    _subscriptions: Vec<Subscription>,
}
 31
 32impl LanguageModelSelector {
 33    pub fn new(
 34        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
 35        window: &mut Window,
 36        cx: &mut Context<Self>,
 37    ) -> Self {
 38        let on_model_changed = Arc::new(on_model_changed);
 39
 40        let all_models = Self::all_models(cx);
 41        let entries = all_models.entries();
 42
 43        let delegate = LanguageModelPickerDelegate {
 44            language_model_selector: cx.entity().downgrade(),
 45            on_model_changed: on_model_changed.clone(),
 46            all_models: Arc::new(all_models),
 47            selected_index: Self::get_active_model_index(&entries, cx),
 48            filtered_entries: entries,
 49        };
 50
 51        let picker = cx.new(|cx| {
 52            Picker::list(delegate, window, cx)
 53                .show_scrollbar(true)
 54                .width(rems(20.))
 55                .max_height(Some(rems(20.).into()))
 56        });
 57
 58        let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
 59
 60        LanguageModelSelector {
 61            picker,
 62            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
 63            _subscriptions: vec![
 64                cx.subscribe_in(
 65                    &LanguageModelRegistry::global(cx),
 66                    window,
 67                    Self::handle_language_model_registry_event,
 68                ),
 69                subscription,
 70            ],
 71        }
 72    }
 73
 74    fn handle_language_model_registry_event(
 75        &mut self,
 76        _registry: &Entity<LanguageModelRegistry>,
 77        event: &language_model::Event,
 78        window: &mut Window,
 79        cx: &mut Context<Self>,
 80    ) {
 81        match event {
 82            language_model::Event::ProviderStateChanged
 83            | language_model::Event::AddedProvider(_)
 84            | language_model::Event::RemovedProvider(_) => {
 85                self.picker.update(cx, |this, cx| {
 86                    let query = this.query(cx);
 87                    this.delegate.all_models = Arc::new(Self::all_models(cx));
 88                    // Update matches will automatically drop the previous task
 89                    // if we get a provider event again
 90                    this.update_matches(query, window, cx)
 91                });
 92            }
 93            _ => {}
 94        }
 95    }
 96
 97    /// Authenticates all providers in the [`LanguageModelRegistry`].
 98    ///
 99    /// We do this so that we can populate the language selector with all of the
100    /// models from the configured providers.
101    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
102        let authenticate_all_providers = LanguageModelRegistry::global(cx)
103            .read(cx)
104            .providers()
105            .iter()
106            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
107            .collect::<Vec<_>>();
108
109        cx.spawn(async move |_cx| {
110            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
111                if let Err(err) = authenticate_task.await {
112                    if matches!(err, AuthenticateError::CredentialsNotFound) {
113                        // Since we're authenticating these providers in the
114                        // background for the purposes of populating the
115                        // language selector, we don't care about providers
116                        // where the credentials are not found.
117                    } else {
118                        // Some providers have noisy failure states that we
119                        // don't want to spam the logs with every time the
120                        // language model selector is initialized.
121                        //
122                        // Ideally these should have more clear failure modes
123                        // that we know are safe to ignore here, like what we do
124                        // with `CredentialsNotFound` above.
125                        match provider_id.0.as_ref() {
126                            "lmstudio" | "ollama" => {
127                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
128                                //
129                                // These fail noisily, so we don't log them.
130                            }
131                            "copilot_chat" => {
132                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
133                            }
134                            _ => {
135                                log::error!(
136                                    "Failed to authenticate provider: {}: {err}",
137                                    provider_name.0
138                                );
139                            }
140                        }
141                    }
142                }
143            }
144        })
145    }
146
147    fn all_models(cx: &App) -> GroupedModels {
148        let mut recommended = Vec::new();
149        let mut recommended_set = HashSet::default();
150        for provider in LanguageModelRegistry::global(cx)
151            .read(cx)
152            .providers()
153            .iter()
154        {
155            let models = provider.recommended_models(cx);
156            recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
157            recommended.extend(
158                provider
159                    .recommended_models(cx)
160                    .into_iter()
161                    .map(move |model| ModelInfo {
162                        model: model.clone(),
163                        icon: provider.icon(),
164                    }),
165            );
166        }
167
168        let other_models = LanguageModelRegistry::global(cx)
169            .read(cx)
170            .providers()
171            .iter()
172            .map(|provider| {
173                (
174                    provider.id(),
175                    provider
176                        .provided_models(cx)
177                        .into_iter()
178                        .filter_map(|model| {
179                            let not_included =
180                                !recommended_set.contains(&(model.provider_id(), model.id()));
181                            not_included.then(|| ModelInfo {
182                                model: model.clone(),
183                                icon: provider.icon(),
184                            })
185                        })
186                        .collect::<Vec<_>>(),
187                )
188            })
189            .collect::<IndexMap<_, _>>();
190
191        GroupedModels {
192            recommended,
193            other: other_models,
194        }
195    }
196
197    fn get_active_model_index(entries: &[LanguageModelPickerEntry], cx: &App) -> usize {
198        let active_model = LanguageModelRegistry::read_global(cx).default_model();
199        entries
200            .iter()
201            .position(|entry| {
202                if let LanguageModelPickerEntry::Model(model) = entry {
203                    active_model
204                        .as_ref()
205                        .map(|active_model| {
206                            active_model.model.id() == model.model.id()
207                                && active_model.model.provider_id() == model.model.provider_id()
208                        })
209                        .unwrap_or_default()
210                } else {
211                    false
212                }
213            })
214            .unwrap_or(0)
215    }
216}
217
/// The selector emits [`DismissEvent`] when its picker is dismissed or confirms a model.
impl EventEmitter<DismissEvent> for LanguageModelSelector {}
219
220impl Focusable for LanguageModelSelector {
221    fn focus_handle(&self, cx: &App) -> FocusHandle {
222        self.picker.focus_handle(cx)
223    }
224}
225
226impl Render for LanguageModelSelector {
227    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
228        self.picker.clone()
229    }
230}
231
/// Element that renders a [`PopoverMenu`] whose menu content is a
/// [`LanguageModelSelector`].
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T, TT>
where
    T: PopoverTrigger + ButtonCommon,
    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
    /// The selector shown as the popover's menu.
    language_model_selector: Entity<LanguageModelSelector>,
    /// Element used as the popover trigger.
    trigger: T,
    /// Builds the tooltip view attached to the trigger.
    tooltip: TT,
    /// Optional handle forwarded to the popover menu (set via `with_handle`).
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
    /// Corner the popover is anchored to.
    anchor: Corner,
}
244
245impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
246where
247    T: PopoverTrigger + ButtonCommon,
248    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
249{
250    pub fn new(
251        language_model_selector: Entity<LanguageModelSelector>,
252        trigger: T,
253        tooltip: TT,
254        anchor: Corner,
255    ) -> Self {
256        Self {
257            language_model_selector,
258            trigger,
259            tooltip,
260            handle: None,
261            anchor,
262        }
263    }
264
265    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
266        self.handle = Some(handle);
267        self
268    }
269}
270
271impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
272where
273    T: PopoverTrigger + ButtonCommon,
274    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
275{
276    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
277        let language_model_selector = self.language_model_selector.clone();
278
279        PopoverMenu::new("model-switcher")
280            .menu(move |_window, _cx| Some(language_model_selector.clone()))
281            .trigger_with_tooltip(self.trigger, self.tooltip)
282            .anchor(self.anchor)
283            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
284            .offset(gpui::Point {
285                x: px(0.0),
286                y: px(-2.0),
287            })
288    }
289}
290
/// A model together with the icon of the provider that supplied it.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// Icon of the model's provider, shown in the picker row.
    icon: IconName,
}
296
/// [`PickerDelegate`] backing [`LanguageModelSelector`].
pub struct LanguageModelPickerDelegate {
    /// The owning selector; used to emit [`DismissEvent`] when the picker is dismissed.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Called with the chosen model when an entry is confirmed.
    on_model_changed: OnModelChanged,
    /// Every known model, grouped; queries are filtered against this.
    all_models: Arc<GroupedModels>,
    /// Rows currently shown (separators and models) after filtering.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
}
304
/// Models split into a "Recommended" group and per-provider groups of the rest.
struct GroupedModels {
    /// Models the providers report as recommended, in provider order.
    recommended: Vec<ModelInfo>,
    /// Remaining models keyed by provider id, in provider insertion order.
    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
309
310impl GroupedModels {
311    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
312        let mut entries = Vec::new();
313
314        if !self.recommended.is_empty() {
315            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
316            entries.extend(
317                self.recommended
318                    .iter()
319                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
320            );
321        }
322
323        for models in self.other.values() {
324            if models.is_empty() {
325                continue;
326            }
327            entries.push(LanguageModelPickerEntry::Separator(
328                models[0].model.provider_name().0,
329            ));
330            entries.extend(
331                models
332                    .iter()
333                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
334            );
335        }
336        entries
337    }
338}
339
/// A single row in the model picker.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable group header carrying its title.
    Separator(SharedString),
}
344
345impl PickerDelegate for LanguageModelPickerDelegate {
346    type ListItem = AnyElement;
347
348    fn match_count(&self) -> usize {
349        self.filtered_entries.len()
350    }
351
352    fn selected_index(&self) -> usize {
353        self.selected_index
354    }
355
356    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
357        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
358        cx.notify();
359    }
360
361    fn can_select(
362        &mut self,
363        ix: usize,
364        _window: &mut Window,
365        _cx: &mut Context<Picker<Self>>,
366    ) -> bool {
367        match self.filtered_entries.get(ix) {
368            Some(LanguageModelPickerEntry::Model(_)) => true,
369            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
370        }
371    }
372
373    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
374        "Select a model…".into()
375    }
376
377    fn update_matches(
378        &mut self,
379        query: String,
380        window: &mut Window,
381        cx: &mut Context<Picker<Self>>,
382    ) -> Task<()> {
383        let all_models = self.all_models.clone();
384        let current_index = self.selected_index;
385
386        let language_model_registry = LanguageModelRegistry::global(cx);
387
388        let configured_providers = language_model_registry
389            .read(cx)
390            .providers()
391            .iter()
392            .filter(|provider| provider.is_authenticated(cx))
393            .map(|provider| provider.id())
394            .collect::<Vec<_>>();
395
396        cx.spawn_in(window, async move |this, cx| {
397            let filtered_models = cx
398                .background_spawn(async move {
399                    let matches = |info: &ModelInfo| {
400                        info.model
401                            .name()
402                            .0
403                            .to_lowercase()
404                            .contains(&query.to_lowercase())
405                    };
406
407                    let recommended_models = all_models
408                        .recommended
409                        .iter()
410                        .filter(|r| {
411                            configured_providers.contains(&r.model.provider_id()) && matches(r)
412                        })
413                        .cloned()
414                        .collect();
415                    let mut other_models = IndexMap::default();
416                    for (provider_id, models) in &all_models.other {
417                        if configured_providers.contains(&provider_id) {
418                            other_models.insert(
419                                provider_id.clone(),
420                                models
421                                    .iter()
422                                    .filter(|m| matches(m))
423                                    .cloned()
424                                    .collect::<Vec<_>>(),
425                            );
426                        }
427                    }
428                    GroupedModels {
429                        recommended: recommended_models,
430                        other: other_models,
431                    }
432                })
433                .await;
434
435            this.update_in(cx, |this, window, cx| {
436                this.delegate.filtered_entries = filtered_models.entries();
437                // Preserve selection focus
438                let new_index = if current_index >= this.delegate.filtered_entries.len() {
439                    0
440                } else {
441                    current_index
442                };
443                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
444                cx.notify();
445            })
446            .ok();
447        })
448    }
449
450    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
451        if let Some(LanguageModelPickerEntry::Model(model_info)) =
452            self.filtered_entries.get(self.selected_index)
453        {
454            let model = model_info.model.clone();
455            (self.on_model_changed)(model.clone(), cx);
456
457            let current_index = self.selected_index;
458            self.set_selected_index(current_index, window, cx);
459
460            cx.emit(DismissEvent);
461        }
462    }
463
464    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
465        self.language_model_selector
466            .update(cx, |_this, cx| cx.emit(DismissEvent))
467            .ok();
468    }
469
470    fn render_match(
471        &self,
472        ix: usize,
473        selected: bool,
474        _: &mut Window,
475        cx: &mut Context<Picker<Self>>,
476    ) -> Option<Self::ListItem> {
477        match self.filtered_entries.get(ix)? {
478            LanguageModelPickerEntry::Separator(title) => Some(
479                div()
480                    .px_2()
481                    .pb_1()
482                    .when(ix > 1, |this| {
483                        this.mt_1()
484                            .pt_2()
485                            .border_t_1()
486                            .border_color(cx.theme().colors().border_variant)
487                    })
488                    .child(
489                        Label::new(title)
490                            .size(LabelSize::XSmall)
491                            .color(Color::Muted),
492                    )
493                    .into_any_element(),
494            ),
495            LanguageModelPickerEntry::Model(model_info) => {
496                let active_model = LanguageModelRegistry::read_global(cx).default_model();
497
498                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
499                let active_model_id = active_model.map(|m| m.model.id());
500
501                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
502                    && Some(model_info.model.id()) == active_model_id;
503
504                let model_icon_color = if is_selected {
505                    Color::Accent
506                } else {
507                    Color::Muted
508                };
509
510                Some(
511                    ListItem::new(ix)
512                        .inset(true)
513                        .spacing(ListItemSpacing::Sparse)
514                        .toggle_state(selected)
515                        .start_slot(
516                            Icon::new(model_info.icon)
517                                .color(model_icon_color)
518                                .size(IconSize::Small),
519                        )
520                        .child(
521                            h_flex()
522                                .w_full()
523                                .pl_0p5()
524                                .gap_1p5()
525                                .w(px(240.))
526                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
527                        )
528                        .end_slot(div().pr_3().when(is_selected, |this| {
529                            this.child(
530                                Icon::new(IconName::Check)
531                                    .color(Color::Accent)
532                                    .size(IconSize::Small),
533                            )
534                        }))
535                        .into_any_element(),
536                )
537            }
538        }
539    }
540
541    fn render_footer(
542        &self,
543        _: &mut Window,
544        cx: &mut Context<Picker<Self>>,
545    ) -> Option<gpui::AnyElement> {
546        use feature_flags::FeatureFlagAppExt;
547
548        let plan = proto::Plan::ZedPro;
549
550        Some(
551            h_flex()
552                .w_full()
553                .border_t_1()
554                .border_color(cx.theme().colors().border_variant)
555                .p_1()
556                .gap_4()
557                .justify_between()
558                .when(cx.has_flag::<ZedPro>(), |this| {
559                    this.child(match plan {
560                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
561                            .icon(IconName::ZedAssistant)
562                            .icon_size(IconSize::Small)
563                            .icon_color(Color::Muted)
564                            .icon_position(IconPosition::Start)
565                            .on_click(|_, window, cx| {
566                                window
567                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
568                            }),
569                        Plan::Free | Plan::ZedProTrial => Button::new(
570                            "try-pro",
571                            if plan == Plan::ZedProTrial {
572                                "Upgrade to Pro"
573                            } else {
574                                "Try Pro"
575                            },
576                        )
577                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
578                    })
579                })
580                .child(
581                    Button::new("configure", "Configure")
582                        .icon(IconName::Settings)
583                        .icon_size(IconSize::Small)
584                        .icon_color(Color::Muted)
585                        .icon_position(IconPosition::Start)
586                        .on_click(|_, window, cx| {
587                            let configure_action = if cx.has_flag::<Assistant2FeatureFlag>() {
588                                zed_actions::agent::OpenConfiguration.boxed_clone()
589                            } else {
590                                zed_actions::assistant::ShowConfiguration.boxed_clone()
591                            };
592
593                            window.dispatch_action(configure_action, cx);
594                        }),
595                )
596                .into_any(),
597        )
598    }
599}