language_model_selector.rs

  1use std::sync::Arc;
  2
  3use feature_flags::{Assistant2FeatureFlag, ZedPro};
  4use gpui::{
  5    Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
  6    Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
  7};
  8use language_model::{
  9    AuthenticateError, LanguageModel, LanguageModelAvailability, LanguageModelRegistry,
 10};
 11use picker::{Picker, PickerDelegate};
 12use proto::Plan;
 13use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
 14
 15action_with_deprecated_aliases!(
 16    assistant,
 17    ToggleModelSelector,
 18    ["assistant2::ToggleModelSelector"]
 19);
 20
 21const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 22
 23type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
 24
 25pub struct LanguageModelSelector {
 26    picker: Entity<Picker<LanguageModelPickerDelegate>>,
 27    /// The task used to update the picker's matches when there is a change to
 28    /// the language model registry.
 29    update_matches_task: Option<Task<()>>,
 30    _authenticate_all_providers_task: Task<()>,
 31    _subscriptions: Vec<Subscription>,
 32}
 33
 34impl LanguageModelSelector {
 35    pub fn new(
 36        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
 37        window: &mut Window,
 38        cx: &mut Context<Self>,
 39    ) -> Self {
 40        let on_model_changed = Arc::new(on_model_changed);
 41
 42        let all_models = Self::all_models(cx);
 43        let delegate = LanguageModelPickerDelegate {
 44            language_model_selector: cx.entity().downgrade(),
 45            on_model_changed: on_model_changed.clone(),
 46            all_models: all_models.clone(),
 47            filtered_models: all_models,
 48            selected_index: Self::get_active_model_index(cx),
 49        };
 50
 51        let picker = cx.new(|cx| {
 52            Picker::uniform_list(delegate, window, cx)
 53                .show_scrollbar(true)
 54                .width(rems(20.))
 55                .max_height(Some(rems(20.).into()))
 56        });
 57
 58        let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
 59
 60        LanguageModelSelector {
 61            picker,
 62            update_matches_task: None,
 63            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
 64            _subscriptions: vec![
 65                cx.subscribe_in(
 66                    &LanguageModelRegistry::global(cx),
 67                    window,
 68                    Self::handle_language_model_registry_event,
 69                ),
 70                subscription,
 71            ],
 72        }
 73    }
 74
 75    fn handle_language_model_registry_event(
 76        &mut self,
 77        _registry: &Entity<LanguageModelRegistry>,
 78        event: &language_model::Event,
 79        window: &mut Window,
 80        cx: &mut Context<Self>,
 81    ) {
 82        match event {
 83            language_model::Event::ProviderStateChanged
 84            | language_model::Event::AddedProvider(_)
 85            | language_model::Event::RemovedProvider(_) => {
 86                let task = self.picker.update(cx, |this, cx| {
 87                    let query = this.query(cx);
 88                    this.delegate.all_models = Self::all_models(cx);
 89                    this.delegate.update_matches(query, window, cx)
 90                });
 91                self.update_matches_task = Some(task);
 92            }
 93            _ => {}
 94        }
 95    }
 96
 97    /// Authenticates all providers in the [`LanguageModelRegistry`].
 98    ///
 99    /// We do this so that we can populate the language selector with all of the
100    /// models from the configured providers.
101    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
102        let authenticate_all_providers = LanguageModelRegistry::global(cx)
103            .read(cx)
104            .providers()
105            .iter()
106            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
107            .collect::<Vec<_>>();
108
109        cx.spawn(async move |_cx| {
110            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
111                if let Err(err) = authenticate_task.await {
112                    if matches!(err, AuthenticateError::CredentialsNotFound) {
113                        // Since we're authenticating these providers in the
114                        // background for the purposes of populating the
115                        // language selector, we don't care about providers
116                        // where the credentials are not found.
117                    } else {
118                        // Some providers have noisy failure states that we
119                        // don't want to spam the logs with every time the
120                        // language model selector is initialized.
121                        //
122                        // Ideally these should have more clear failure modes
123                        // that we know are safe to ignore here, like what we do
124                        // with `CredentialsNotFound` above.
125                        match provider_id.0.as_ref() {
126                            "lmstudio" | "ollama" => {
127                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
128                                //
129                                // These fail noisily, so we don't log them.
130                            }
131                            "copilot_chat" => {
132                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
133                            }
134                            _ => {
135                                log::error!(
136                                    "Failed to authenticate provider: {}: {err}",
137                                    provider_name.0
138                                );
139                            }
140                        }
141                    }
142                }
143            }
144        })
145    }
146
147    fn all_models(cx: &App) -> Vec<ModelInfo> {
148        LanguageModelRegistry::global(cx)
149            .read(cx)
150            .providers()
151            .iter()
152            .flat_map(|provider| {
153                let icon = provider.icon();
154
155                provider.provided_models(cx).into_iter().map(move |model| {
156                    let model = model.clone();
157                    let icon = model.icon().unwrap_or(icon);
158
159                    ModelInfo {
160                        model: model.clone(),
161                        icon,
162                        availability: model.availability(),
163                    }
164                })
165            })
166            .collect::<Vec<_>>()
167    }
168
169    fn get_active_model_index(cx: &App) -> usize {
170        let active_model = LanguageModelRegistry::read_global(cx).default_model();
171        Self::all_models(cx)
172            .iter()
173            .position(|model_info| {
174                Some(model_info.model.id()) == active_model.as_ref().map(|model| model.model.id())
175            })
176            .unwrap_or(0)
177    }
178}
179
180impl EventEmitter<DismissEvent> for LanguageModelSelector {}
181
182impl Focusable for LanguageModelSelector {
183    fn focus_handle(&self, cx: &App) -> FocusHandle {
184        self.picker.focus_handle(cx)
185    }
186}
187
188impl Render for LanguageModelSelector {
189    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
190        self.picker.clone()
191    }
192}
193
194#[derive(IntoElement)]
195pub struct LanguageModelSelectorPopoverMenu<T, TT>
196where
197    T: PopoverTrigger + ButtonCommon,
198    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
199{
200    language_model_selector: Entity<LanguageModelSelector>,
201    trigger: T,
202    tooltip: TT,
203    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
204    anchor: Corner,
205}
206
207impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
208where
209    T: PopoverTrigger + ButtonCommon,
210    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
211{
212    pub fn new(
213        language_model_selector: Entity<LanguageModelSelector>,
214        trigger: T,
215        tooltip: TT,
216        anchor: Corner,
217    ) -> Self {
218        Self {
219            language_model_selector,
220            trigger,
221            tooltip,
222            handle: None,
223            anchor,
224        }
225    }
226
227    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
228        self.handle = Some(handle);
229        self
230    }
231}
232
233impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
234where
235    T: PopoverTrigger + ButtonCommon,
236    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
237{
238    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
239        let language_model_selector = self.language_model_selector.clone();
240
241        PopoverMenu::new("model-switcher")
242            .menu(move |_window, _cx| Some(language_model_selector.clone()))
243            .trigger_with_tooltip(self.trigger, self.tooltip)
244            .anchor(self.anchor)
245            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
246            .offset(gpui::Point {
247                x: px(0.0),
248                y: px(-2.0),
249            })
250    }
251}
252
253#[derive(Clone)]
254struct ModelInfo {
255    model: Arc<dyn LanguageModel>,
256    icon: IconName,
257    availability: LanguageModelAvailability,
258}
259
260pub struct LanguageModelPickerDelegate {
261    language_model_selector: WeakEntity<LanguageModelSelector>,
262    on_model_changed: OnModelChanged,
263    all_models: Vec<ModelInfo>,
264    filtered_models: Vec<ModelInfo>,
265    selected_index: usize,
266}
267
268impl PickerDelegate for LanguageModelPickerDelegate {
269    type ListItem = ListItem;
270
271    fn match_count(&self) -> usize {
272        self.filtered_models.len()
273    }
274
275    fn selected_index(&self) -> usize {
276        self.selected_index
277    }
278
279    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
280        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
281        cx.notify();
282    }
283
284    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
285        "Select a model...".into()
286    }
287
288    fn update_matches(
289        &mut self,
290        query: String,
291        window: &mut Window,
292        cx: &mut Context<Picker<Self>>,
293    ) -> Task<()> {
294        let all_models = self.all_models.clone();
295        let current_index = self.selected_index;
296
297        let language_model_registry = LanguageModelRegistry::global(cx);
298
299        let configured_providers = language_model_registry
300            .read(cx)
301            .providers()
302            .iter()
303            .filter(|provider| provider.is_authenticated(cx))
304            .map(|provider| provider.id())
305            .collect::<Vec<_>>();
306
307        cx.spawn_in(window, async move |this, cx| {
308            let filtered_models = cx
309                .background_spawn(async move {
310                    let displayed_models = if configured_providers.is_empty() {
311                        all_models
312                    } else {
313                        all_models
314                            .into_iter()
315                            .filter(|model_info| {
316                                configured_providers.contains(&model_info.model.provider_id())
317                            })
318                            .collect::<Vec<_>>()
319                    };
320
321                    if query.is_empty() {
322                        displayed_models
323                    } else {
324                        displayed_models
325                            .into_iter()
326                            .filter(|model_info| {
327                                model_info
328                                    .model
329                                    .name()
330                                    .0
331                                    .to_lowercase()
332                                    .contains(&query.to_lowercase())
333                            })
334                            .collect()
335                    }
336                })
337                .await;
338
339            this.update_in(cx, |this, window, cx| {
340                this.delegate.filtered_models = filtered_models;
341                // Preserve selection focus
342                let new_index = if current_index >= this.delegate.filtered_models.len() {
343                    0
344                } else {
345                    current_index
346                };
347                this.delegate.set_selected_index(new_index, window, cx);
348                cx.notify();
349            })
350            .ok();
351        })
352    }
353
354    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
355        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
356            let model = model_info.model.clone();
357            (self.on_model_changed)(model.clone(), cx);
358
359            let current_index = self.selected_index;
360            self.set_selected_index(current_index, window, cx);
361
362            cx.emit(DismissEvent);
363        }
364    }
365
366    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
367        self.language_model_selector
368            .update(cx, |_this, cx| cx.emit(DismissEvent))
369            .ok();
370    }
371
372    fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
373        let configured_models_count = LanguageModelRegistry::global(cx)
374            .read(cx)
375            .providers()
376            .iter()
377            .filter(|provider| provider.is_authenticated(cx))
378            .count();
379
380        if configured_models_count > 0 {
381            Some(
382                Label::new("Configured Models")
383                    .size(LabelSize::Small)
384                    .color(Color::Muted)
385                    .mt_1()
386                    .mb_0p5()
387                    .ml_2()
388                    .into_any_element(),
389            )
390        } else {
391            None
392        }
393    }
394
395    fn render_match(
396        &self,
397        ix: usize,
398        selected: bool,
399        _: &mut Window,
400        cx: &mut Context<Picker<Self>>,
401    ) -> Option<Self::ListItem> {
402        use feature_flags::FeatureFlagAppExt;
403        let show_badges = cx.has_flag::<ZedPro>();
404
405        let model_info = self.filtered_models.get(ix)?;
406        let provider_name: String = model_info.model.provider_name().0.clone().into();
407
408        let active_model = LanguageModelRegistry::read_global(cx).default_model();
409
410        let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
411        let active_model_id = active_model.map(|m| m.model.id());
412
413        let is_selected = Some(model_info.model.provider_id()) == active_provider_id
414            && Some(model_info.model.id()) == active_model_id;
415
416        let model_icon_color = if is_selected {
417            Color::Accent
418        } else {
419            Color::Muted
420        };
421
422        Some(
423            ListItem::new(ix)
424                .inset(true)
425                .spacing(ListItemSpacing::Sparse)
426                .toggle_state(selected)
427                .start_slot(
428                    Icon::new(model_info.icon)
429                        .color(model_icon_color)
430                        .size(IconSize::Small),
431                )
432                .child(
433                    h_flex()
434                        .w_full()
435                        .items_center()
436                        .gap_1p5()
437                        .pl_0p5()
438                        .w(px(240.))
439                        .child(
440                            div()
441                                .max_w_40()
442                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
443                        )
444                        .child(
445                            h_flex()
446                                .gap_0p5()
447                                .child(
448                                    Label::new(provider_name)
449                                        .size(LabelSize::XSmall)
450                                        .color(Color::Muted),
451                                )
452                                .children(match model_info.availability {
453                                    LanguageModelAvailability::Public => None,
454                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
455                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
456                                        show_badges.then(|| {
457                                            Label::new("Pro")
458                                                .size(LabelSize::XSmall)
459                                                .color(Color::Muted)
460                                        })
461                                    }
462                                }),
463                        ),
464                )
465                .end_slot(div().pr_3().when(is_selected, |this| {
466                    this.child(
467                        Icon::new(IconName::Check)
468                            .color(Color::Accent)
469                            .size(IconSize::Small),
470                    )
471                })),
472        )
473    }
474
475    fn render_footer(
476        &self,
477        _: &mut Window,
478        cx: &mut Context<Picker<Self>>,
479    ) -> Option<gpui::AnyElement> {
480        use feature_flags::FeatureFlagAppExt;
481
482        let plan = proto::Plan::ZedPro;
483        let is_trial = false;
484
485        Some(
486            h_flex()
487                .w_full()
488                .border_t_1()
489                .border_color(cx.theme().colors().border_variant)
490                .p_1()
491                .gap_4()
492                .justify_between()
493                .when(cx.has_flag::<ZedPro>(), |this| {
494                    this.child(match plan {
495                        // Already a Zed Pro subscriber
496                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
497                            .icon(IconName::ZedAssistant)
498                            .icon_size(IconSize::Small)
499                            .icon_color(Color::Muted)
500                            .icon_position(IconPosition::Start)
501                            .on_click(|_, window, cx| {
502                                window
503                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
504                            }),
505                        // Free user
506                        Plan::Free => Button::new(
507                            "try-pro",
508                            if is_trial {
509                                "Upgrade to Pro"
510                            } else {
511                                "Try Pro"
512                            },
513                        )
514                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
515                    })
516                })
517                .child(
518                    Button::new("configure", "Configure")
519                        .icon(IconName::Settings)
520                        .icon_size(IconSize::Small)
521                        .icon_color(Color::Muted)
522                        .icon_position(IconPosition::Start)
523                        .on_click(|_, window, cx| {
524                            let configure_action = if cx.has_flag::<Assistant2FeatureFlag>() {
525                                zed_actions::agent::OpenConfiguration.boxed_clone()
526                            } else {
527                                zed_actions::assistant::ShowConfiguration.boxed_clone()
528                            };
529
530                            window.dispatch_action(configure_action, cx);
531                        }),
532                )
533                .into_any(),
534        )
535    }
536}