language_model_selector.rs

  1use std::sync::Arc;
  2
  3use feature_flags::ZedPro;
  4use gpui::{
  5    Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
  6    Focusable, Subscription, Task, WeakEntity,
  7};
  8use language_model::{
  9    AuthenticateError, LanguageModel, LanguageModelAvailability, LanguageModelRegistry,
 10};
 11use picker::{Picker, PickerDelegate};
 12use proto::Plan;
 13use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 14use workspace::ShowConfiguration;
 15
 16const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 17
 18type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
 19
 20pub struct LanguageModelSelector {
 21    picker: Entity<Picker<LanguageModelPickerDelegate>>,
 22    /// The task used to update the picker's matches when there is a change to
 23    /// the language model registry.
 24    update_matches_task: Option<Task<()>>,
 25    _authenticate_all_providers_task: Task<()>,
 26    _subscriptions: Vec<Subscription>,
 27}
 28
 29impl LanguageModelSelector {
 30    pub fn new(
 31        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
 32        window: &mut Window,
 33        cx: &mut Context<Self>,
 34    ) -> Self {
 35        let on_model_changed = Arc::new(on_model_changed);
 36
 37        let all_models = Self::all_models(cx);
 38        let delegate = LanguageModelPickerDelegate {
 39            language_model_selector: cx.entity().downgrade(),
 40            on_model_changed: on_model_changed.clone(),
 41            all_models: all_models.clone(),
 42            filtered_models: all_models,
 43            selected_index: Self::get_active_model_index(cx),
 44        };
 45
 46        let picker = cx.new(|cx| {
 47            Picker::uniform_list(delegate, window, cx)
 48                .show_scrollbar(true)
 49                .width(rems(20.))
 50                .max_height(Some(rems(20.).into()))
 51        });
 52
 53        LanguageModelSelector {
 54            picker,
 55            update_matches_task: None,
 56            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
 57            _subscriptions: vec![cx.subscribe_in(
 58                &LanguageModelRegistry::global(cx),
 59                window,
 60                Self::handle_language_model_registry_event,
 61            )],
 62        }
 63    }
 64
 65    fn handle_language_model_registry_event(
 66        &mut self,
 67        _registry: &Entity<LanguageModelRegistry>,
 68        event: &language_model::Event,
 69        window: &mut Window,
 70        cx: &mut Context<Self>,
 71    ) {
 72        match event {
 73            language_model::Event::ProviderStateChanged
 74            | language_model::Event::AddedProvider(_)
 75            | language_model::Event::RemovedProvider(_) => {
 76                let task = self.picker.update(cx, |this, cx| {
 77                    let query = this.query(cx);
 78                    this.delegate.all_models = Self::all_models(cx);
 79                    this.delegate.update_matches(query, window, cx)
 80                });
 81                self.update_matches_task = Some(task);
 82            }
 83            _ => {}
 84        }
 85    }
 86
 87    /// Authenticates all providers in the [`LanguageModelRegistry`].
 88    ///
 89    /// We do this so that we can populate the language selector with all of the
 90    /// models from the configured providers.
 91    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
 92        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 93            .read(cx)
 94            .providers()
 95            .iter()
 96            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 97            .collect::<Vec<_>>();
 98
 99        cx.spawn(|_cx| async move {
100            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
101                if let Err(err) = authenticate_task.await {
102                    if matches!(err, AuthenticateError::CredentialsNotFound) {
103                        // Since we're authenticating these providers in the
104                        // background for the purposes of populating the
105                        // language selector, we don't care about providers
106                        // where the credentials are not found.
107                    } else {
108                        // Some providers have noisy failure states that we
109                        // don't want to spam the logs with every time the
110                        // language model selector is initialized.
111                        //
112                        // Ideally these should have more clear failure modes
113                        // that we know are safe to ignore here, like what we do
114                        // with `CredentialsNotFound` above.
115                        match provider_id.0.as_ref() {
116                            "lmstudio" | "ollama" => {
117                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
118                                //
119                                // These fail noisily, so we don't log them.
120                            }
121                            "copilot_chat" => {
122                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
123                            }
124                            _ => {
125                                log::error!(
126                                    "Failed to authenticate provider: {}: {err}",
127                                    provider_name.0
128                                );
129                            }
130                        }
131                    }
132                }
133            }
134        })
135    }
136
137    fn all_models(cx: &App) -> Vec<ModelInfo> {
138        LanguageModelRegistry::global(cx)
139            .read(cx)
140            .providers()
141            .iter()
142            .flat_map(|provider| {
143                let icon = provider.icon();
144
145                provider.provided_models(cx).into_iter().map(move |model| {
146                    let model = model.clone();
147                    let icon = model.icon().unwrap_or(icon);
148
149                    ModelInfo {
150                        model: model.clone(),
151                        icon,
152                        availability: model.availability(),
153                    }
154                })
155            })
156            .collect::<Vec<_>>()
157    }
158
159    fn get_active_model_index(cx: &App) -> usize {
160        let active_model = LanguageModelRegistry::read_global(cx).active_model();
161        Self::all_models(cx)
162            .iter()
163            .position(|model_info| {
164                Some(model_info.model.id()) == active_model.as_ref().map(|model| model.id())
165            })
166            .unwrap_or(0)
167    }
168}
169
170impl EventEmitter<DismissEvent> for LanguageModelSelector {}
171
172impl Focusable for LanguageModelSelector {
173    fn focus_handle(&self, cx: &App) -> FocusHandle {
174        self.picker.focus_handle(cx)
175    }
176}
177
178impl Render for LanguageModelSelector {
179    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
180        self.picker.clone()
181    }
182}
183
184#[derive(IntoElement)]
185pub struct LanguageModelSelectorPopoverMenu<T, TT>
186where
187    T: PopoverTrigger + ButtonCommon,
188    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
189{
190    language_model_selector: Entity<LanguageModelSelector>,
191    trigger: T,
192    tooltip: TT,
193    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
194    anchor: Corner,
195}
196
197impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
198where
199    T: PopoverTrigger + ButtonCommon,
200    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
201{
202    pub fn new(
203        language_model_selector: Entity<LanguageModelSelector>,
204        trigger: T,
205        tooltip: TT,
206        anchor: Corner,
207    ) -> Self {
208        Self {
209            language_model_selector,
210            trigger,
211            tooltip,
212            handle: None,
213            anchor,
214        }
215    }
216
217    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
218        self.handle = Some(handle);
219        self
220    }
221}
222
223impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
224where
225    T: PopoverTrigger + ButtonCommon,
226    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
227{
228    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
229        let language_model_selector = self.language_model_selector.clone();
230
231        PopoverMenu::new("model-switcher")
232            .menu(move |_window, _cx| Some(language_model_selector.clone()))
233            .trigger_with_tooltip(self.trigger, self.tooltip)
234            .anchor(self.anchor)
235            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
236            .offset(gpui::Point {
237                x: px(0.0),
238                y: px(-2.0),
239            })
240    }
241}
242
243#[derive(Clone)]
244struct ModelInfo {
245    model: Arc<dyn LanguageModel>,
246    icon: IconName,
247    availability: LanguageModelAvailability,
248}
249
250pub struct LanguageModelPickerDelegate {
251    language_model_selector: WeakEntity<LanguageModelSelector>,
252    on_model_changed: OnModelChanged,
253    all_models: Vec<ModelInfo>,
254    filtered_models: Vec<ModelInfo>,
255    selected_index: usize,
256}
257
258impl PickerDelegate for LanguageModelPickerDelegate {
259    type ListItem = ListItem;
260
261    fn match_count(&self) -> usize {
262        self.filtered_models.len()
263    }
264
265    fn selected_index(&self) -> usize {
266        self.selected_index
267    }
268
269    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
270        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
271        cx.notify();
272    }
273
274    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
275        "Select a model...".into()
276    }
277
278    fn update_matches(
279        &mut self,
280        query: String,
281        window: &mut Window,
282        cx: &mut Context<Picker<Self>>,
283    ) -> Task<()> {
284        let all_models = self.all_models.clone();
285        let current_index = self.selected_index;
286
287        let language_model_registry = LanguageModelRegistry::global(cx);
288
289        let configured_providers = language_model_registry
290            .read(cx)
291            .providers()
292            .iter()
293            .filter(|provider| provider.is_authenticated(cx))
294            .map(|provider| provider.id())
295            .collect::<Vec<_>>();
296
297        cx.spawn_in(window, |this, mut cx| async move {
298            let filtered_models = cx
299                .background_spawn(async move {
300                    let displayed_models = if configured_providers.is_empty() {
301                        all_models
302                    } else {
303                        all_models
304                            .into_iter()
305                            .filter(|model_info| {
306                                configured_providers.contains(&model_info.model.provider_id())
307                            })
308                            .collect::<Vec<_>>()
309                    };
310
311                    if query.is_empty() {
312                        displayed_models
313                    } else {
314                        displayed_models
315                            .into_iter()
316                            .filter(|model_info| {
317                                model_info
318                                    .model
319                                    .name()
320                                    .0
321                                    .to_lowercase()
322                                    .contains(&query.to_lowercase())
323                            })
324                            .collect()
325                    }
326                })
327                .await;
328
329            this.update_in(&mut cx, |this, window, cx| {
330                this.delegate.filtered_models = filtered_models;
331                // Preserve selection focus
332                let new_index = if current_index >= this.delegate.filtered_models.len() {
333                    0
334                } else {
335                    current_index
336                };
337                this.delegate.set_selected_index(new_index, window, cx);
338                cx.notify();
339            })
340            .ok();
341        })
342    }
343
344    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
345        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
346            let model = model_info.model.clone();
347            (self.on_model_changed)(model.clone(), cx);
348
349            let current_index = self.selected_index;
350            self.set_selected_index(current_index, window, cx);
351
352            cx.emit(DismissEvent);
353        }
354    }
355
356    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
357        self.language_model_selector
358            .update(cx, |_this, cx| cx.emit(DismissEvent))
359            .ok();
360    }
361
362    fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
363        let configured_models_count = LanguageModelRegistry::global(cx)
364            .read(cx)
365            .providers()
366            .iter()
367            .filter(|provider| provider.is_authenticated(cx))
368            .count();
369
370        if configured_models_count > 0 {
371            Some(
372                Label::new("Configured Models")
373                    .size(LabelSize::Small)
374                    .color(Color::Muted)
375                    .mt_1()
376                    .mb_0p5()
377                    .ml_2()
378                    .into_any_element(),
379            )
380        } else {
381            None
382        }
383    }
384
385    fn render_match(
386        &self,
387        ix: usize,
388        selected: bool,
389        _: &mut Window,
390        cx: &mut Context<Picker<Self>>,
391    ) -> Option<Self::ListItem> {
392        use feature_flags::FeatureFlagAppExt;
393        let show_badges = cx.has_flag::<ZedPro>();
394
395        let model_info = self.filtered_models.get(ix)?;
396        let provider_name: String = model_info.model.provider_name().0.clone().into();
397
398        let active_provider_id = LanguageModelRegistry::read_global(cx)
399            .active_provider()
400            .map(|m| m.id());
401
402        let active_model_id = LanguageModelRegistry::read_global(cx)
403            .active_model()
404            .map(|m| m.id());
405
406        let is_selected = Some(model_info.model.provider_id()) == active_provider_id
407            && Some(model_info.model.id()) == active_model_id;
408
409        let model_icon_color = if is_selected {
410            Color::Accent
411        } else {
412            Color::Muted
413        };
414
415        Some(
416            ListItem::new(ix)
417                .inset(true)
418                .spacing(ListItemSpacing::Sparse)
419                .toggle_state(selected)
420                .start_slot(
421                    Icon::new(model_info.icon)
422                        .color(model_icon_color)
423                        .size(IconSize::Small),
424                )
425                .child(
426                    h_flex()
427                        .w_full()
428                        .items_center()
429                        .gap_1p5()
430                        .pl_0p5()
431                        .w(px(240.))
432                        .child(
433                            div().max_w_40().child(
434                                Label::new(model_info.model.name().0.clone()).text_ellipsis(),
435                            ),
436                        )
437                        .child(
438                            h_flex()
439                                .gap_0p5()
440                                .child(
441                                    Label::new(provider_name)
442                                        .size(LabelSize::XSmall)
443                                        .color(Color::Muted),
444                                )
445                                .children(match model_info.availability {
446                                    LanguageModelAvailability::Public => None,
447                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
448                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
449                                        show_badges.then(|| {
450                                            Label::new("Pro")
451                                                .size(LabelSize::XSmall)
452                                                .color(Color::Muted)
453                                        })
454                                    }
455                                }),
456                        ),
457                )
458                .end_slot(div().pr_3().when(is_selected, |this| {
459                    this.child(
460                        Icon::new(IconName::Check)
461                            .color(Color::Accent)
462                            .size(IconSize::Small),
463                    )
464                })),
465        )
466    }
467
468    fn render_footer(
469        &self,
470        _: &mut Window,
471        cx: &mut Context<Picker<Self>>,
472    ) -> Option<gpui::AnyElement> {
473        use feature_flags::FeatureFlagAppExt;
474
475        let plan = proto::Plan::ZedPro;
476        let is_trial = false;
477
478        Some(
479            h_flex()
480                .w_full()
481                .border_t_1()
482                .border_color(cx.theme().colors().border_variant)
483                .p_1()
484                .gap_4()
485                .justify_between()
486                .when(cx.has_flag::<ZedPro>(), |this| {
487                    this.child(match plan {
488                        // Already a Zed Pro subscriber
489                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
490                            .icon(IconName::ZedAssistant)
491                            .icon_size(IconSize::Small)
492                            .icon_color(Color::Muted)
493                            .icon_position(IconPosition::Start)
494                            .on_click(|_, window, cx| {
495                                window
496                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
497                            }),
498                        // Free user
499                        Plan::Free => Button::new(
500                            "try-pro",
501                            if is_trial {
502                                "Upgrade to Pro"
503                            } else {
504                                "Try Pro"
505                            },
506                        )
507                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
508                    })
509                })
510                .child(
511                    Button::new("configure", "Configure")
512                        .icon(IconName::Settings)
513                        .icon_size(IconSize::Small)
514                        .icon_color(Color::Muted)
515                        .icon_position(IconPosition::Start)
516                        .on_click(|_, window, cx| {
517                            window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
518                        }),
519                )
520                .into_any(),
521        )
522    }
523}