language_model_selector.rs

  1use std::sync::Arc;
  2
  3use feature_flags::ZedPro;
  4use gpui::{
  5    action_with_deprecated_aliases, Action, AnyElement, App, Corner, DismissEvent, Entity,
  6    EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
  7};
  8use language_model::{
  9    AuthenticateError, LanguageModel, LanguageModelAvailability, LanguageModelRegistry,
 10};
 11use picker::{Picker, PickerDelegate};
 12use proto::Plan;
 13use ui::{
 14    prelude::*, ButtonLike, IconButtonShape, ListItem, ListItemSpacing, PopoverButton,
 15    PopoverMenuHandle, Tooltip, TriggerablePopover,
 16};
 17use workspace::ShowConfiguration;
 18
 19action_with_deprecated_aliases!(
 20    assistant,
 21    ToggleModelSelector,
 22    ["assistant2::ToggleModelSelector"]
 23);
 24
 25const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 26
 27type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
 28
 29pub struct LanguageModelSelector {
 30    picker: Entity<Picker<LanguageModelPickerDelegate>>,
 31    /// The task used to update the picker's matches when there is a change to
 32    /// the language model registry.
 33    update_matches_task: Option<Task<()>>,
 34    popover_menu_handle: PopoverMenuHandle<LanguageModelSelector>,
 35    _authenticate_all_providers_task: Task<()>,
 36    _subscriptions: Vec<Subscription>,
 37}
 38
 39impl LanguageModelSelector {
 40    pub fn new(
 41        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
 42        window: &mut Window,
 43        cx: &mut Context<Self>,
 44    ) -> Self {
 45        let on_model_changed = Arc::new(on_model_changed);
 46
 47        let all_models = Self::all_models(cx);
 48        let delegate = LanguageModelPickerDelegate {
 49            language_model_selector: cx.entity().downgrade(),
 50            on_model_changed: on_model_changed.clone(),
 51            all_models: all_models.clone(),
 52            filtered_models: all_models,
 53            selected_index: Self::get_active_model_index(cx),
 54        };
 55
 56        let picker = cx.new(|cx| {
 57            Picker::uniform_list(delegate, window, cx)
 58                .show_scrollbar(true)
 59                .width(rems(20.))
 60                .max_height(Some(rems(20.).into()))
 61        });
 62
 63        LanguageModelSelector {
 64            picker,
 65            update_matches_task: None,
 66            popover_menu_handle: PopoverMenuHandle::default(),
 67            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
 68            _subscriptions: vec![cx.subscribe_in(
 69                &LanguageModelRegistry::global(cx),
 70                window,
 71                Self::handle_language_model_registry_event,
 72            )],
 73        }
 74    }
 75
 76    pub fn toggle_model_selector(
 77        &mut self,
 78        _: &ToggleModelSelector,
 79        window: &mut Window,
 80        cx: &mut Context<Self>,
 81    ) {
 82        self.popover_menu_handle.toggle(window, cx);
 83    }
 84
 85    fn handle_language_model_registry_event(
 86        &mut self,
 87        _registry: &Entity<LanguageModelRegistry>,
 88        event: &language_model::Event,
 89        window: &mut Window,
 90        cx: &mut Context<Self>,
 91    ) {
 92        match event {
 93            language_model::Event::ProviderStateChanged
 94            | language_model::Event::AddedProvider(_)
 95            | language_model::Event::RemovedProvider(_) => {
 96                let task = self.picker.update(cx, |this, cx| {
 97                    let query = this.query(cx);
 98                    this.delegate.all_models = Self::all_models(cx);
 99                    this.delegate.update_matches(query, window, cx)
100                });
101                self.update_matches_task = Some(task);
102            }
103            _ => {}
104        }
105    }
106
107    /// Authenticates all providers in the [`LanguageModelRegistry`].
108    ///
109    /// We do this so that we can populate the language selector with all of the
110    /// models from the configured providers.
111    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
112        let authenticate_all_providers = LanguageModelRegistry::global(cx)
113            .read(cx)
114            .providers()
115            .iter()
116            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
117            .collect::<Vec<_>>();
118
119        cx.spawn(|_cx| async move {
120            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
121                if let Err(err) = authenticate_task.await {
122                    if matches!(err, AuthenticateError::CredentialsNotFound) {
123                        // Since we're authenticating these providers in the
124                        // background for the purposes of populating the
125                        // language selector, we don't care about providers
126                        // where the credentials are not found.
127                    } else {
128                        // Some providers have noisy failure states that we
129                        // don't want to spam the logs with every time the
130                        // language model selector is initialized.
131                        //
132                        // Ideally these should have more clear failure modes
133                        // that we know are safe to ignore here, like what we do
134                        // with `CredentialsNotFound` above.
135                        match provider_id.0.as_ref() {
136                            "lmstudio" | "ollama" => {
137                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
138                                //
139                                // These fail noisily, so we don't log them.
140                            }
141                            "copilot_chat" => {
142                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
143                            }
144                            _ => {
145                                log::error!(
146                                    "Failed to authenticate provider: {}: {err}",
147                                    provider_name.0
148                                );
149                            }
150                        }
151                    }
152                }
153            }
154        })
155    }
156
157    fn all_models(cx: &App) -> Vec<ModelInfo> {
158        LanguageModelRegistry::global(cx)
159            .read(cx)
160            .providers()
161            .iter()
162            .flat_map(|provider| {
163                let icon = provider.icon();
164
165                provider.provided_models(cx).into_iter().map(move |model| {
166                    let model = model.clone();
167                    let icon = model.icon().unwrap_or(icon);
168
169                    ModelInfo {
170                        model: model.clone(),
171                        icon,
172                        availability: model.availability(),
173                    }
174                })
175            })
176            .collect::<Vec<_>>()
177    }
178
179    fn get_active_model_index(cx: &App) -> usize {
180        let active_model = LanguageModelRegistry::read_global(cx).active_model();
181        Self::all_models(cx)
182            .iter()
183            .position(|model_info| {
184                Some(model_info.model.id()) == active_model.as_ref().map(|model| model.id())
185            })
186            .unwrap_or(0)
187    }
188}
189
190impl EventEmitter<DismissEvent> for LanguageModelSelector {}
191
192impl Focusable for LanguageModelSelector {
193    fn focus_handle(&self, cx: &App) -> FocusHandle {
194        self.picker.focus_handle(cx)
195    }
196}
197
198impl Render for LanguageModelSelector {
199    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
200        self.picker.clone()
201    }
202}
203
204impl TriggerablePopover for LanguageModelSelector {
205    fn menu_handle(
206        &mut self,
207        _window: &mut Window,
208        _cx: &mut gpui::Context<Self>,
209    ) -> PopoverMenuHandle<Self> {
210        self.popover_menu_handle.clone()
211    }
212}
213
214#[derive(Clone)]
215struct ModelInfo {
216    model: Arc<dyn LanguageModel>,
217    icon: IconName,
218    availability: LanguageModelAvailability,
219}
220
221pub struct LanguageModelPickerDelegate {
222    language_model_selector: WeakEntity<LanguageModelSelector>,
223    on_model_changed: OnModelChanged,
224    all_models: Vec<ModelInfo>,
225    filtered_models: Vec<ModelInfo>,
226    selected_index: usize,
227}
228
229impl PickerDelegate for LanguageModelPickerDelegate {
230    type ListItem = ListItem;
231
232    fn match_count(&self) -> usize {
233        self.filtered_models.len()
234    }
235
236    fn selected_index(&self) -> usize {
237        self.selected_index
238    }
239
240    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
241        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
242        cx.notify();
243    }
244
245    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
246        "Select a model...".into()
247    }
248
249    fn update_matches(
250        &mut self,
251        query: String,
252        window: &mut Window,
253        cx: &mut Context<Picker<Self>>,
254    ) -> Task<()> {
255        let all_models = self.all_models.clone();
256        let current_index = self.selected_index;
257
258        let language_model_registry = LanguageModelRegistry::global(cx);
259
260        let configured_providers = language_model_registry
261            .read(cx)
262            .providers()
263            .iter()
264            .filter(|provider| provider.is_authenticated(cx))
265            .map(|provider| provider.id())
266            .collect::<Vec<_>>();
267
268        cx.spawn_in(window, |this, mut cx| async move {
269            let filtered_models = cx
270                .background_spawn(async move {
271                    let displayed_models = if configured_providers.is_empty() {
272                        all_models
273                    } else {
274                        all_models
275                            .into_iter()
276                            .filter(|model_info| {
277                                configured_providers.contains(&model_info.model.provider_id())
278                            })
279                            .collect::<Vec<_>>()
280                    };
281
282                    if query.is_empty() {
283                        displayed_models
284                    } else {
285                        displayed_models
286                            .into_iter()
287                            .filter(|model_info| {
288                                model_info
289                                    .model
290                                    .name()
291                                    .0
292                                    .to_lowercase()
293                                    .contains(&query.to_lowercase())
294                            })
295                            .collect()
296                    }
297                })
298                .await;
299
300            this.update_in(&mut cx, |this, window, cx| {
301                this.delegate.filtered_models = filtered_models;
302                // Preserve selection focus
303                let new_index = if current_index >= this.delegate.filtered_models.len() {
304                    0
305                } else {
306                    current_index
307                };
308                this.delegate.set_selected_index(new_index, window, cx);
309                cx.notify();
310            })
311            .ok();
312        })
313    }
314
315    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
316        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
317            let model = model_info.model.clone();
318            (self.on_model_changed)(model.clone(), cx);
319
320            let current_index = self.selected_index;
321            self.set_selected_index(current_index, window, cx);
322
323            cx.emit(DismissEvent);
324        }
325    }
326
327    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
328        self.language_model_selector
329            .update(cx, |_this, cx| cx.emit(DismissEvent))
330            .ok();
331    }
332
333    fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
334        let configured_models_count = LanguageModelRegistry::global(cx)
335            .read(cx)
336            .providers()
337            .iter()
338            .filter(|provider| provider.is_authenticated(cx))
339            .count();
340
341        if configured_models_count > 0 {
342            Some(
343                Label::new("Configured Models")
344                    .size(LabelSize::Small)
345                    .color(Color::Muted)
346                    .mt_1()
347                    .mb_0p5()
348                    .ml_2()
349                    .into_any_element(),
350            )
351        } else {
352            None
353        }
354    }
355
356    fn render_match(
357        &self,
358        ix: usize,
359        selected: bool,
360        _: &mut Window,
361        cx: &mut Context<Picker<Self>>,
362    ) -> Option<Self::ListItem> {
363        use feature_flags::FeatureFlagAppExt;
364        let show_badges = cx.has_flag::<ZedPro>();
365
366        let model_info = self.filtered_models.get(ix)?;
367        let provider_name: String = model_info.model.provider_name().0.clone().into();
368
369        let active_provider_id = LanguageModelRegistry::read_global(cx)
370            .active_provider()
371            .map(|m| m.id());
372
373        let active_model_id = LanguageModelRegistry::read_global(cx)
374            .active_model()
375            .map(|m| m.id());
376
377        let is_selected = Some(model_info.model.provider_id()) == active_provider_id
378            && Some(model_info.model.id()) == active_model_id;
379
380        let model_icon_color = if is_selected {
381            Color::Accent
382        } else {
383            Color::Muted
384        };
385
386        Some(
387            ListItem::new(ix)
388                .inset(true)
389                .spacing(ListItemSpacing::Sparse)
390                .toggle_state(selected)
391                .start_slot(
392                    Icon::new(model_info.icon)
393                        .color(model_icon_color)
394                        .size(IconSize::Small),
395                )
396                .child(
397                    h_flex()
398                        .w_full()
399                        .items_center()
400                        .gap_1p5()
401                        .pl_0p5()
402                        .w(px(240.))
403                        .child(
404                            div().max_w_40().child(
405                                Label::new(model_info.model.name().0.clone()).text_ellipsis(),
406                            ),
407                        )
408                        .child(
409                            h_flex()
410                                .gap_0p5()
411                                .child(
412                                    Label::new(provider_name)
413                                        .size(LabelSize::XSmall)
414                                        .color(Color::Muted),
415                                )
416                                .children(match model_info.availability {
417                                    LanguageModelAvailability::Public => None,
418                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
419                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
420                                        show_badges.then(|| {
421                                            Label::new("Pro")
422                                                .size(LabelSize::XSmall)
423                                                .color(Color::Muted)
424                                        })
425                                    }
426                                }),
427                        ),
428                )
429                .end_slot(div().pr_3().when(is_selected, |this| {
430                    this.child(
431                        Icon::new(IconName::Check)
432                            .color(Color::Accent)
433                            .size(IconSize::Small),
434                    )
435                })),
436        )
437    }
438
439    fn render_footer(
440        &self,
441        _: &mut Window,
442        cx: &mut Context<Picker<Self>>,
443    ) -> Option<gpui::AnyElement> {
444        use feature_flags::FeatureFlagAppExt;
445
446        let plan = proto::Plan::ZedPro;
447        let is_trial = false;
448
449        Some(
450            h_flex()
451                .w_full()
452                .border_t_1()
453                .border_color(cx.theme().colors().border_variant)
454                .p_1()
455                .gap_4()
456                .justify_between()
457                .when(cx.has_flag::<ZedPro>(), |this| {
458                    this.child(match plan {
459                        // Already a Zed Pro subscriber
460                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
461                            .icon(IconName::ZedAssistant)
462                            .icon_size(IconSize::Small)
463                            .icon_color(Color::Muted)
464                            .icon_position(IconPosition::Start)
465                            .on_click(|_, window, cx| {
466                                window
467                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
468                            }),
469                        // Free user
470                        Plan::Free => Button::new(
471                            "try-pro",
472                            if is_trial {
473                                "Upgrade to Pro"
474                            } else {
475                                "Try Pro"
476                            },
477                        )
478                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
479                    })
480                })
481                .child(
482                    Button::new("configure", "Configure")
483                        .icon(IconName::Settings)
484                        .icon_size(IconSize::Small)
485                        .icon_color(Color::Muted)
486                        .icon_position(IconPosition::Start)
487                        .on_click(|_, window, cx| {
488                            window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
489                        }),
490                )
491                .into_any(),
492        )
493    }
494}
495
496pub struct InlineLanguageModelSelector {
497    selector: Entity<LanguageModelSelector>,
498}
499
500impl InlineLanguageModelSelector {
501    pub fn new(selector: Entity<LanguageModelSelector>) -> Self {
502        Self { selector }
503    }
504}
505
506impl RenderOnce for InlineLanguageModelSelector {
507    fn render(self, window: &mut Window, cx: &mut App) -> impl IntoElement {
508        PopoverButton::new(
509            self.selector,
510            gpui::Corner::TopRight,
511            IconButton::new("context", IconName::SettingsAlt)
512                .shape(IconButtonShape::Square)
513                .icon_size(IconSize::Small)
514                .icon_color(Color::Muted),
515            move |window, cx| {
516                Tooltip::with_meta(
517                    format!(
518                        "Using {}",
519                        LanguageModelRegistry::read_global(cx)
520                            .active_model()
521                            .map(|model| model.name().0)
522                            .unwrap_or_else(|| "No model selected".into()),
523                    ),
524                    None,
525                    "Change Model",
526                    window,
527                    cx,
528                )
529            },
530        )
531        .render(window, cx)
532    }
533}
534
535pub struct AssistantLanguageModelSelector {
536    focus_handle: FocusHandle,
537    selector: Entity<LanguageModelSelector>,
538}
539
540impl AssistantLanguageModelSelector {
541    pub fn new(focus_handle: FocusHandle, selector: Entity<LanguageModelSelector>) -> Self {
542        Self {
543            focus_handle,
544            selector,
545        }
546    }
547}
548
549impl RenderOnce for AssistantLanguageModelSelector {
550    fn render(self, window: &mut Window, cx: &mut App) -> impl IntoElement {
551        let active_model = LanguageModelRegistry::read_global(cx).active_model();
552        let focus_handle = self.focus_handle.clone();
553        let model_name = match active_model {
554            Some(model) => model.name().0,
555            _ => SharedString::from("No model selected"),
556        };
557
558        PopoverButton::new(
559            self.selector.clone(),
560            Corner::BottomRight,
561            ButtonLike::new("active-model")
562                .style(ButtonStyle::Subtle)
563                .child(
564                    h_flex()
565                        .gap_0p5()
566                        .child(
567                            Label::new(model_name)
568                                .size(LabelSize::Small)
569                                .color(Color::Muted),
570                        )
571                        .child(
572                            Icon::new(IconName::ChevronDown)
573                                .color(Color::Muted)
574                                .size(IconSize::XSmall),
575                        ),
576                ),
577            move |window, cx| {
578                Tooltip::for_action_in(
579                    "Change Model",
580                    &ToggleModelSelector,
581                    &focus_handle,
582                    window,
583                    cx,
584                )
585            },
586        )
587        .render(window, cx)
588    }
589}