language_model_selector.rs

  1use std::sync::Arc;
  2
  3use feature_flags::ZedPro;
  4use gpui::{
  5    Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
  6    Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
  7};
  8use language_model::{
  9    AuthenticateError, LanguageModel, LanguageModelAvailability, LanguageModelRegistry,
 10};
 11use picker::{Picker, PickerDelegate};
 12use proto::Plan;
 13use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
 14use workspace::ShowConfiguration;
 15
// Declares the `ToggleModelSelector` action in the `assistant` namespace, with
// `assistant2::ToggleModelSelector` listed as a deprecated alias (presumably so
// keymaps using the old name keep working — confirm against the macro).
action_with_deprecated_aliases!(
    assistant,
    ToggleModelSelector,
    ["assistant2::ToggleModelSelector"]
);
 21
 22const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 23
 24type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
 25
/// A picker-backed view for choosing a language model from the providers
/// registered in the global [`LanguageModelRegistry`].
pub struct LanguageModelSelector {
    /// The picker entity that renders and filters the list of models.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// The task used to update the picker's matches when there is a change to
    /// the language model registry.
    update_matches_task: Option<Task<()>>,
    /// Task returned by `authenticate_all_providers`; held only to keep it
    /// alive for the selector's lifetime (presumably dropping a gpui `Task`
    /// cancels it — confirm).
    _authenticate_all_providers_task: Task<()>,
    /// Subscriptions to registry events and to the picker's events, held so
    /// they remain active while the selector exists.
    _subscriptions: Vec<Subscription>,
}
 34
 35impl LanguageModelSelector {
 36    pub fn new(
 37        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
 38        window: &mut Window,
 39        cx: &mut Context<Self>,
 40    ) -> Self {
 41        let on_model_changed = Arc::new(on_model_changed);
 42
 43        let all_models = Self::all_models(cx);
 44        let delegate = LanguageModelPickerDelegate {
 45            language_model_selector: cx.entity().downgrade(),
 46            on_model_changed: on_model_changed.clone(),
 47            all_models: all_models.clone(),
 48            filtered_models: all_models,
 49            selected_index: Self::get_active_model_index(cx),
 50        };
 51
 52        let picker = cx.new(|cx| {
 53            Picker::uniform_list(delegate, window, cx)
 54                .show_scrollbar(true)
 55                .width(rems(20.))
 56                .max_height(Some(rems(20.).into()))
 57        });
 58
 59        let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
 60
 61        LanguageModelSelector {
 62            picker,
 63            update_matches_task: None,
 64            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
 65            _subscriptions: vec![
 66                cx.subscribe_in(
 67                    &LanguageModelRegistry::global(cx),
 68                    window,
 69                    Self::handle_language_model_registry_event,
 70                ),
 71                subscription,
 72            ],
 73        }
 74    }
 75
 76    fn handle_language_model_registry_event(
 77        &mut self,
 78        _registry: &Entity<LanguageModelRegistry>,
 79        event: &language_model::Event,
 80        window: &mut Window,
 81        cx: &mut Context<Self>,
 82    ) {
 83        match event {
 84            language_model::Event::ProviderStateChanged
 85            | language_model::Event::AddedProvider(_)
 86            | language_model::Event::RemovedProvider(_) => {
 87                let task = self.picker.update(cx, |this, cx| {
 88                    let query = this.query(cx);
 89                    this.delegate.all_models = Self::all_models(cx);
 90                    this.delegate.update_matches(query, window, cx)
 91                });
 92                self.update_matches_task = Some(task);
 93            }
 94            _ => {}
 95        }
 96    }
 97
 98    /// Authenticates all providers in the [`LanguageModelRegistry`].
 99    ///
100    /// We do this so that we can populate the language selector with all of the
101    /// models from the configured providers.
102    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
103        let authenticate_all_providers = LanguageModelRegistry::global(cx)
104            .read(cx)
105            .providers()
106            .iter()
107            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
108            .collect::<Vec<_>>();
109
110        cx.spawn(async move |_cx| {
111            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
112                if let Err(err) = authenticate_task.await {
113                    if matches!(err, AuthenticateError::CredentialsNotFound) {
114                        // Since we're authenticating these providers in the
115                        // background for the purposes of populating the
116                        // language selector, we don't care about providers
117                        // where the credentials are not found.
118                    } else {
119                        // Some providers have noisy failure states that we
120                        // don't want to spam the logs with every time the
121                        // language model selector is initialized.
122                        //
123                        // Ideally these should have more clear failure modes
124                        // that we know are safe to ignore here, like what we do
125                        // with `CredentialsNotFound` above.
126                        match provider_id.0.as_ref() {
127                            "lmstudio" | "ollama" => {
128                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
129                                //
130                                // These fail noisily, so we don't log them.
131                            }
132                            "copilot_chat" => {
133                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
134                            }
135                            _ => {
136                                log::error!(
137                                    "Failed to authenticate provider: {}: {err}",
138                                    provider_name.0
139                                );
140                            }
141                        }
142                    }
143                }
144            }
145        })
146    }
147
148    fn all_models(cx: &App) -> Vec<ModelInfo> {
149        LanguageModelRegistry::global(cx)
150            .read(cx)
151            .providers()
152            .iter()
153            .flat_map(|provider| {
154                let icon = provider.icon();
155
156                provider.provided_models(cx).into_iter().map(move |model| {
157                    let model = model.clone();
158                    let icon = model.icon().unwrap_or(icon);
159
160                    ModelInfo {
161                        model: model.clone(),
162                        icon,
163                        availability: model.availability(),
164                    }
165                })
166            })
167            .collect::<Vec<_>>()
168    }
169
170    fn get_active_model_index(cx: &App) -> usize {
171        let active_model = LanguageModelRegistry::read_global(cx).active_model();
172        Self::all_models(cx)
173            .iter()
174            .position(|model_info| {
175                Some(model_info.model.id()) == active_model.as_ref().map(|model| model.id())
176            })
177            .unwrap_or(0)
178    }
179}
180
181impl EventEmitter<DismissEvent> for LanguageModelSelector {}
182
183impl Focusable for LanguageModelSelector {
184    fn focus_handle(&self, cx: &App) -> FocusHandle {
185        self.picker.focus_handle(cx)
186    }
187}
188
189impl Render for LanguageModelSelector {
190    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
191        self.picker.clone()
192    }
193}
194
/// A [`PopoverMenu`] whose trigger opens a [`LanguageModelSelector`].
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T, TT>
where
    T: PopoverTrigger + ButtonCommon,
    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
    /// The selector shown as the menu's content.
    language_model_selector: Entity<LanguageModelSelector>,
    /// The element that opens the menu.
    trigger: T,
    /// Builds the tooltip view for the trigger.
    tooltip: TT,
    /// Optional handle forwarded to the underlying popover menu.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
    /// The corner the menu is anchored to.
    anchor: Corner,
}
207
208impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
209where
210    T: PopoverTrigger + ButtonCommon,
211    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
212{
213    pub fn new(
214        language_model_selector: Entity<LanguageModelSelector>,
215        trigger: T,
216        tooltip: TT,
217        anchor: Corner,
218    ) -> Self {
219        Self {
220            language_model_selector,
221            trigger,
222            tooltip,
223            handle: None,
224            anchor,
225        }
226    }
227
228    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
229        self.handle = Some(handle);
230        self
231    }
232}
233
234impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
235where
236    T: PopoverTrigger + ButtonCommon,
237    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
238{
239    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
240        let language_model_selector = self.language_model_selector.clone();
241
242        PopoverMenu::new("model-switcher")
243            .menu(move |_window, _cx| Some(language_model_selector.clone()))
244            .trigger_with_tooltip(self.trigger, self.tooltip)
245            .anchor(self.anchor)
246            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
247            .offset(gpui::Point {
248                x: px(0.0),
249                y: px(-2.0),
250            })
251    }
252}
253
/// A model as listed in the picker, with its display icon and availability
/// captured when the model list is built.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself; passed to the `on_model_changed` callback on confirm.
    model: Arc<dyn LanguageModel>,
    /// The model's own icon if it has one, otherwise its provider's icon.
    icon: IconName,
    /// Plan requirement; `RequiresPlan(Plan::ZedPro)` shows a "Pro" badge when
    /// the `ZedPro` feature flag is enabled.
    availability: LanguageModelAvailability,
}
260
/// [`PickerDelegate`] that backs the picker inside [`LanguageModelSelector`].
pub struct LanguageModelPickerDelegate {
    /// The owning selector, held weakly; used to re-emit `DismissEvent` when
    /// the picker is dismissed.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Called with the chosen model when an entry is confirmed.
    on_model_changed: OnModelChanged,
    /// Every model from every registered provider, unfiltered.
    all_models: Vec<ModelInfo>,
    /// The subset of `all_models` currently displayed, after provider and
    /// query filtering.
    filtered_models: Vec<ModelInfo>,
    /// Index into `filtered_models` of the highlighted entry.
    selected_index: usize,
}
268
269impl PickerDelegate for LanguageModelPickerDelegate {
270    type ListItem = ListItem;
271
272    fn match_count(&self) -> usize {
273        self.filtered_models.len()
274    }
275
276    fn selected_index(&self) -> usize {
277        self.selected_index
278    }
279
280    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
281        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
282        cx.notify();
283    }
284
285    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
286        "Select a model...".into()
287    }
288
289    fn update_matches(
290        &mut self,
291        query: String,
292        window: &mut Window,
293        cx: &mut Context<Picker<Self>>,
294    ) -> Task<()> {
295        let all_models = self.all_models.clone();
296        let current_index = self.selected_index;
297
298        let language_model_registry = LanguageModelRegistry::global(cx);
299
300        let configured_providers = language_model_registry
301            .read(cx)
302            .providers()
303            .iter()
304            .filter(|provider| provider.is_authenticated(cx))
305            .map(|provider| provider.id())
306            .collect::<Vec<_>>();
307
308        cx.spawn_in(window, async move |this, cx| {
309            let filtered_models = cx
310                .background_spawn(async move {
311                    let displayed_models = if configured_providers.is_empty() {
312                        all_models
313                    } else {
314                        all_models
315                            .into_iter()
316                            .filter(|model_info| {
317                                configured_providers.contains(&model_info.model.provider_id())
318                            })
319                            .collect::<Vec<_>>()
320                    };
321
322                    if query.is_empty() {
323                        displayed_models
324                    } else {
325                        displayed_models
326                            .into_iter()
327                            .filter(|model_info| {
328                                model_info
329                                    .model
330                                    .name()
331                                    .0
332                                    .to_lowercase()
333                                    .contains(&query.to_lowercase())
334                            })
335                            .collect()
336                    }
337                })
338                .await;
339
340            this.update_in(cx, |this, window, cx| {
341                this.delegate.filtered_models = filtered_models;
342                // Preserve selection focus
343                let new_index = if current_index >= this.delegate.filtered_models.len() {
344                    0
345                } else {
346                    current_index
347                };
348                this.delegate.set_selected_index(new_index, window, cx);
349                cx.notify();
350            })
351            .ok();
352        })
353    }
354
355    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
356        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
357            let model = model_info.model.clone();
358            (self.on_model_changed)(model.clone(), cx);
359
360            let current_index = self.selected_index;
361            self.set_selected_index(current_index, window, cx);
362
363            cx.emit(DismissEvent);
364        }
365    }
366
367    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
368        self.language_model_selector
369            .update(cx, |_this, cx| cx.emit(DismissEvent))
370            .ok();
371    }
372
373    fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
374        let configured_models_count = LanguageModelRegistry::global(cx)
375            .read(cx)
376            .providers()
377            .iter()
378            .filter(|provider| provider.is_authenticated(cx))
379            .count();
380
381        if configured_models_count > 0 {
382            Some(
383                Label::new("Configured Models")
384                    .size(LabelSize::Small)
385                    .color(Color::Muted)
386                    .mt_1()
387                    .mb_0p5()
388                    .ml_2()
389                    .into_any_element(),
390            )
391        } else {
392            None
393        }
394    }
395
396    fn render_match(
397        &self,
398        ix: usize,
399        selected: bool,
400        _: &mut Window,
401        cx: &mut Context<Picker<Self>>,
402    ) -> Option<Self::ListItem> {
403        use feature_flags::FeatureFlagAppExt;
404        let show_badges = cx.has_flag::<ZedPro>();
405
406        let model_info = self.filtered_models.get(ix)?;
407        let provider_name: String = model_info.model.provider_name().0.clone().into();
408
409        let active_provider_id = LanguageModelRegistry::read_global(cx)
410            .active_provider()
411            .map(|m| m.id());
412
413        let active_model_id = LanguageModelRegistry::read_global(cx)
414            .active_model()
415            .map(|m| m.id());
416
417        let is_selected = Some(model_info.model.provider_id()) == active_provider_id
418            && Some(model_info.model.id()) == active_model_id;
419
420        let model_icon_color = if is_selected {
421            Color::Accent
422        } else {
423            Color::Muted
424        };
425
426        Some(
427            ListItem::new(ix)
428                .inset(true)
429                .spacing(ListItemSpacing::Sparse)
430                .toggle_state(selected)
431                .start_slot(
432                    Icon::new(model_info.icon)
433                        .color(model_icon_color)
434                        .size(IconSize::Small),
435                )
436                .child(
437                    h_flex()
438                        .w_full()
439                        .items_center()
440                        .gap_1p5()
441                        .pl_0p5()
442                        .w(px(240.))
443                        .child(
444                            div()
445                                .max_w_40()
446                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
447                        )
448                        .child(
449                            h_flex()
450                                .gap_0p5()
451                                .child(
452                                    Label::new(provider_name)
453                                        .size(LabelSize::XSmall)
454                                        .color(Color::Muted),
455                                )
456                                .children(match model_info.availability {
457                                    LanguageModelAvailability::Public => None,
458                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
459                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
460                                        show_badges.then(|| {
461                                            Label::new("Pro")
462                                                .size(LabelSize::XSmall)
463                                                .color(Color::Muted)
464                                        })
465                                    }
466                                }),
467                        ),
468                )
469                .end_slot(div().pr_3().when(is_selected, |this| {
470                    this.child(
471                        Icon::new(IconName::Check)
472                            .color(Color::Accent)
473                            .size(IconSize::Small),
474                    )
475                })),
476        )
477    }
478
479    fn render_footer(
480        &self,
481        _: &mut Window,
482        cx: &mut Context<Picker<Self>>,
483    ) -> Option<gpui::AnyElement> {
484        use feature_flags::FeatureFlagAppExt;
485
486        let plan = proto::Plan::ZedPro;
487        let is_trial = false;
488
489        Some(
490            h_flex()
491                .w_full()
492                .border_t_1()
493                .border_color(cx.theme().colors().border_variant)
494                .p_1()
495                .gap_4()
496                .justify_between()
497                .when(cx.has_flag::<ZedPro>(), |this| {
498                    this.child(match plan {
499                        // Already a Zed Pro subscriber
500                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
501                            .icon(IconName::ZedAssistant)
502                            .icon_size(IconSize::Small)
503                            .icon_color(Color::Muted)
504                            .icon_position(IconPosition::Start)
505                            .on_click(|_, window, cx| {
506                                window
507                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
508                            }),
509                        // Free user
510                        Plan::Free => Button::new(
511                            "try-pro",
512                            if is_trial {
513                                "Upgrade to Pro"
514                            } else {
515                                "Try Pro"
516                            },
517                        )
518                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
519                    })
520                })
521                .child(
522                    Button::new("configure", "Configure")
523                        .icon(IconName::Settings)
524                        .icon_size(IconSize::Small)
525                        .icon_color(Color::Muted)
526                        .icon_position(IconPosition::Start)
527                        .on_click(|_, window, cx| {
528                            window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
529                        }),
530                )
531                .into_any(),
532        )
533    }
534}