language_model_selector.rs

  1use std::sync::Arc;
  2
  3use feature_flags::ZedPro;
  4use gpui::{
  5    action_with_deprecated_aliases, Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity,
  6    EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
  7};
  8use language_model::{
  9    AuthenticateError, LanguageModel, LanguageModelAvailability, LanguageModelRegistry,
 10};
 11use picker::{Picker, PickerDelegate};
 12use proto::Plan;
 13use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 14use workspace::ShowConfiguration;
 15
 16action_with_deprecated_aliases!(
 17    assistant,
 18    ToggleModelSelector,
 19    ["assistant2::ToggleModelSelector"]
 20);
 21
 22const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 23
 24type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
 25
 26pub struct LanguageModelSelector {
 27    picker: Entity<Picker<LanguageModelPickerDelegate>>,
 28    /// The task used to update the picker's matches when there is a change to
 29    /// the language model registry.
 30    update_matches_task: Option<Task<()>>,
 31    _authenticate_all_providers_task: Task<()>,
 32    _subscriptions: Vec<Subscription>,
 33}
 34
 35impl LanguageModelSelector {
 36    pub fn new(
 37        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
 38        window: &mut Window,
 39        cx: &mut Context<Self>,
 40    ) -> Self {
 41        let on_model_changed = Arc::new(on_model_changed);
 42
 43        let all_models = Self::all_models(cx);
 44        let delegate = LanguageModelPickerDelegate {
 45            language_model_selector: cx.entity().downgrade(),
 46            on_model_changed: on_model_changed.clone(),
 47            all_models: all_models.clone(),
 48            filtered_models: all_models,
 49            selected_index: Self::get_active_model_index(cx),
 50        };
 51
 52        let picker = cx.new(|cx| {
 53            Picker::uniform_list(delegate, window, cx)
 54                .show_scrollbar(true)
 55                .width(rems(20.))
 56                .max_height(Some(rems(20.).into()))
 57        });
 58
 59        LanguageModelSelector {
 60            picker,
 61            update_matches_task: None,
 62            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
 63            _subscriptions: vec![cx.subscribe_in(
 64                &LanguageModelRegistry::global(cx),
 65                window,
 66                Self::handle_language_model_registry_event,
 67            )],
 68        }
 69    }
 70
 71    fn handle_language_model_registry_event(
 72        &mut self,
 73        _registry: &Entity<LanguageModelRegistry>,
 74        event: &language_model::Event,
 75        window: &mut Window,
 76        cx: &mut Context<Self>,
 77    ) {
 78        match event {
 79            language_model::Event::ProviderStateChanged
 80            | language_model::Event::AddedProvider(_)
 81            | language_model::Event::RemovedProvider(_) => {
 82                let task = self.picker.update(cx, |this, cx| {
 83                    let query = this.query(cx);
 84                    this.delegate.all_models = Self::all_models(cx);
 85                    this.delegate.update_matches(query, window, cx)
 86                });
 87                self.update_matches_task = Some(task);
 88            }
 89            _ => {}
 90        }
 91    }
 92
 93    /// Authenticates all providers in the [`LanguageModelRegistry`].
 94    ///
 95    /// We do this so that we can populate the language selector with all of the
 96    /// models from the configured providers.
 97    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
 98        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 99            .read(cx)
100            .providers()
101            .iter()
102            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
103            .collect::<Vec<_>>();
104
105        cx.spawn(async move |_cx| {
106            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
107                if let Err(err) = authenticate_task.await {
108                    if matches!(err, AuthenticateError::CredentialsNotFound) {
109                        // Since we're authenticating these providers in the
110                        // background for the purposes of populating the
111                        // language selector, we don't care about providers
112                        // where the credentials are not found.
113                    } else {
114                        // Some providers have noisy failure states that we
115                        // don't want to spam the logs with every time the
116                        // language model selector is initialized.
117                        //
118                        // Ideally these should have more clear failure modes
119                        // that we know are safe to ignore here, like what we do
120                        // with `CredentialsNotFound` above.
121                        match provider_id.0.as_ref() {
122                            "lmstudio" | "ollama" => {
123                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
124                                //
125                                // These fail noisily, so we don't log them.
126                            }
127                            "copilot_chat" => {
128                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
129                            }
130                            _ => {
131                                log::error!(
132                                    "Failed to authenticate provider: {}: {err}",
133                                    provider_name.0
134                                );
135                            }
136                        }
137                    }
138                }
139            }
140        })
141    }
142
143    fn all_models(cx: &App) -> Vec<ModelInfo> {
144        LanguageModelRegistry::global(cx)
145            .read(cx)
146            .providers()
147            .iter()
148            .flat_map(|provider| {
149                let icon = provider.icon();
150
151                provider.provided_models(cx).into_iter().map(move |model| {
152                    let model = model.clone();
153                    let icon = model.icon().unwrap_or(icon);
154
155                    ModelInfo {
156                        model: model.clone(),
157                        icon,
158                        availability: model.availability(),
159                    }
160                })
161            })
162            .collect::<Vec<_>>()
163    }
164
165    fn get_active_model_index(cx: &App) -> usize {
166        let active_model = LanguageModelRegistry::read_global(cx).active_model();
167        Self::all_models(cx)
168            .iter()
169            .position(|model_info| {
170                Some(model_info.model.id()) == active_model.as_ref().map(|model| model.id())
171            })
172            .unwrap_or(0)
173    }
174}
175
176impl EventEmitter<DismissEvent> for LanguageModelSelector {}
177
178impl Focusable for LanguageModelSelector {
179    fn focus_handle(&self, cx: &App) -> FocusHandle {
180        self.picker.focus_handle(cx)
181    }
182}
183
184impl Render for LanguageModelSelector {
185    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
186        self.picker.clone()
187    }
188}
189
190#[derive(IntoElement)]
191pub struct LanguageModelSelectorPopoverMenu<T, TT>
192where
193    T: PopoverTrigger + ButtonCommon,
194    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
195{
196    language_model_selector: Entity<LanguageModelSelector>,
197    trigger: T,
198    tooltip: TT,
199    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
200    anchor: Corner,
201}
202
203impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
204where
205    T: PopoverTrigger + ButtonCommon,
206    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
207{
208    pub fn new(
209        language_model_selector: Entity<LanguageModelSelector>,
210        trigger: T,
211        tooltip: TT,
212        anchor: Corner,
213    ) -> Self {
214        Self {
215            language_model_selector,
216            trigger,
217            tooltip,
218            handle: None,
219            anchor,
220        }
221    }
222
223    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
224        self.handle = Some(handle);
225        self
226    }
227}
228
229impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
230where
231    T: PopoverTrigger + ButtonCommon,
232    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
233{
234    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
235        let language_model_selector = self.language_model_selector.clone();
236
237        PopoverMenu::new("model-switcher")
238            .menu(move |_window, _cx| Some(language_model_selector.clone()))
239            .trigger_with_tooltip(self.trigger, self.tooltip)
240            .anchor(self.anchor)
241            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
242            .offset(gpui::Point {
243                x: px(0.0),
244                y: px(-2.0),
245            })
246    }
247}
248
249#[derive(Clone)]
250struct ModelInfo {
251    model: Arc<dyn LanguageModel>,
252    icon: IconName,
253    availability: LanguageModelAvailability,
254}
255
256pub struct LanguageModelPickerDelegate {
257    language_model_selector: WeakEntity<LanguageModelSelector>,
258    on_model_changed: OnModelChanged,
259    all_models: Vec<ModelInfo>,
260    filtered_models: Vec<ModelInfo>,
261    selected_index: usize,
262}
263
264impl PickerDelegate for LanguageModelPickerDelegate {
265    type ListItem = ListItem;
266
267    fn match_count(&self) -> usize {
268        self.filtered_models.len()
269    }
270
271    fn selected_index(&self) -> usize {
272        self.selected_index
273    }
274
275    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
276        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
277        cx.notify();
278    }
279
280    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
281        "Select a model...".into()
282    }
283
284    fn update_matches(
285        &mut self,
286        query: String,
287        window: &mut Window,
288        cx: &mut Context<Picker<Self>>,
289    ) -> Task<()> {
290        let all_models = self.all_models.clone();
291        let current_index = self.selected_index;
292
293        let language_model_registry = LanguageModelRegistry::global(cx);
294
295        let configured_providers = language_model_registry
296            .read(cx)
297            .providers()
298            .iter()
299            .filter(|provider| provider.is_authenticated(cx))
300            .map(|provider| provider.id())
301            .collect::<Vec<_>>();
302
303        cx.spawn_in(window, async move |this, cx| {
304            let filtered_models = cx
305                .background_spawn(async move {
306                    let displayed_models = if configured_providers.is_empty() {
307                        all_models
308                    } else {
309                        all_models
310                            .into_iter()
311                            .filter(|model_info| {
312                                configured_providers.contains(&model_info.model.provider_id())
313                            })
314                            .collect::<Vec<_>>()
315                    };
316
317                    if query.is_empty() {
318                        displayed_models
319                    } else {
320                        displayed_models
321                            .into_iter()
322                            .filter(|model_info| {
323                                model_info
324                                    .model
325                                    .name()
326                                    .0
327                                    .to_lowercase()
328                                    .contains(&query.to_lowercase())
329                            })
330                            .collect()
331                    }
332                })
333                .await;
334
335            this.update_in(cx, |this, window, cx| {
336                this.delegate.filtered_models = filtered_models;
337                // Preserve selection focus
338                let new_index = if current_index >= this.delegate.filtered_models.len() {
339                    0
340                } else {
341                    current_index
342                };
343                this.delegate.set_selected_index(new_index, window, cx);
344                cx.notify();
345            })
346            .ok();
347        })
348    }
349
350    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
351        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
352            let model = model_info.model.clone();
353            (self.on_model_changed)(model.clone(), cx);
354
355            let current_index = self.selected_index;
356            self.set_selected_index(current_index, window, cx);
357
358            cx.emit(DismissEvent);
359        }
360    }
361
362    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
363        self.language_model_selector
364            .update(cx, |_this, cx| cx.emit(DismissEvent))
365            .ok();
366    }
367
368    fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
369        let configured_models_count = LanguageModelRegistry::global(cx)
370            .read(cx)
371            .providers()
372            .iter()
373            .filter(|provider| provider.is_authenticated(cx))
374            .count();
375
376        if configured_models_count > 0 {
377            Some(
378                Label::new("Configured Models")
379                    .size(LabelSize::Small)
380                    .color(Color::Muted)
381                    .mt_1()
382                    .mb_0p5()
383                    .ml_2()
384                    .into_any_element(),
385            )
386        } else {
387            None
388        }
389    }
390
391    fn render_match(
392        &self,
393        ix: usize,
394        selected: bool,
395        _: &mut Window,
396        cx: &mut Context<Picker<Self>>,
397    ) -> Option<Self::ListItem> {
398        use feature_flags::FeatureFlagAppExt;
399        let show_badges = cx.has_flag::<ZedPro>();
400
401        let model_info = self.filtered_models.get(ix)?;
402        let provider_name: String = model_info.model.provider_name().0.clone().into();
403
404        let active_provider_id = LanguageModelRegistry::read_global(cx)
405            .active_provider()
406            .map(|m| m.id());
407
408        let active_model_id = LanguageModelRegistry::read_global(cx)
409            .active_model()
410            .map(|m| m.id());
411
412        let is_selected = Some(model_info.model.provider_id()) == active_provider_id
413            && Some(model_info.model.id()) == active_model_id;
414
415        let model_icon_color = if is_selected {
416            Color::Accent
417        } else {
418            Color::Muted
419        };
420
421        Some(
422            ListItem::new(ix)
423                .inset(true)
424                .spacing(ListItemSpacing::Sparse)
425                .toggle_state(selected)
426                .start_slot(
427                    Icon::new(model_info.icon)
428                        .color(model_icon_color)
429                        .size(IconSize::Small),
430                )
431                .child(
432                    h_flex()
433                        .w_full()
434                        .items_center()
435                        .gap_1p5()
436                        .pl_0p5()
437                        .w(px(240.))
438                        .child(
439                            div()
440                                .max_w_40()
441                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
442                        )
443                        .child(
444                            h_flex()
445                                .gap_0p5()
446                                .child(
447                                    Label::new(provider_name)
448                                        .size(LabelSize::XSmall)
449                                        .color(Color::Muted),
450                                )
451                                .children(match model_info.availability {
452                                    LanguageModelAvailability::Public => None,
453                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
454                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
455                                        show_badges.then(|| {
456                                            Label::new("Pro")
457                                                .size(LabelSize::XSmall)
458                                                .color(Color::Muted)
459                                        })
460                                    }
461                                }),
462                        ),
463                )
464                .end_slot(div().pr_3().when(is_selected, |this| {
465                    this.child(
466                        Icon::new(IconName::Check)
467                            .color(Color::Accent)
468                            .size(IconSize::Small),
469                    )
470                })),
471        )
472    }
473
474    fn render_footer(
475        &self,
476        _: &mut Window,
477        cx: &mut Context<Picker<Self>>,
478    ) -> Option<gpui::AnyElement> {
479        use feature_flags::FeatureFlagAppExt;
480
481        let plan = proto::Plan::ZedPro;
482        let is_trial = false;
483
484        Some(
485            h_flex()
486                .w_full()
487                .border_t_1()
488                .border_color(cx.theme().colors().border_variant)
489                .p_1()
490                .gap_4()
491                .justify_between()
492                .when(cx.has_flag::<ZedPro>(), |this| {
493                    this.child(match plan {
494                        // Already a Zed Pro subscriber
495                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
496                            .icon(IconName::ZedAssistant)
497                            .icon_size(IconSize::Small)
498                            .icon_color(Color::Muted)
499                            .icon_position(IconPosition::Start)
500                            .on_click(|_, window, cx| {
501                                window
502                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
503                            }),
504                        // Free user
505                        Plan::Free => Button::new(
506                            "try-pro",
507                            if is_trial {
508                                "Upgrade to Pro"
509                            } else {
510                                "Try Pro"
511                            },
512                        )
513                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
514                    })
515                })
516                .child(
517                    Button::new("configure", "Configure")
518                        .icon(IconName::Settings)
519                        .icon_size(IconSize::Small)
520                        .icon_color(Color::Muted)
521                        .icon_position(IconPosition::Start)
522                        .on_click(|_, window, cx| {
523                            window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
524                        }),
525                )
526                .into_any(),
527        )
528    }
529}