language_model_selector.rs

  1use std::sync::Arc;
  2
  3use feature_flags::ZedPro;
  4use gpui::{
  5    Action, AnyElement, AnyView, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable,
  6    Subscription, Task, WeakEntity,
  7};
  8use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
  9use picker::{Picker, PickerDelegate};
 10use proto::Plan;
 11use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 12use workspace::ShowConfiguration;
 13
 14const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 15
 16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
 17
/// View that lets the user pick a language model from the providers registered in the
/// global [`LanguageModelRegistry`]; all rendering and filtering is done by `picker`.
pub struct LanguageModelSelector {
    /// Picker that renders the model list and filters it by the typed query.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// The task used to update the picker's matches when there is a change to
    /// the language model registry.
    // NOTE(review): stored (never read) presumably so the in-flight task is not dropped.
    update_matches_task: Option<Task<()>>,
    // Registry subscription created in `new`; held (never read) presumably so it stays
    // active for as long as the selector lives.
    _subscriptions: Vec<Subscription>,
}
 25
 26impl LanguageModelSelector {
 27    pub fn new(
 28        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
 29        window: &mut Window,
 30        cx: &mut Context<Self>,
 31    ) -> Self {
 32        let on_model_changed = Arc::new(on_model_changed);
 33
 34        let all_models = Self::all_models(cx);
 35        let delegate = LanguageModelPickerDelegate {
 36            language_model_selector: cx.entity().downgrade(),
 37            on_model_changed: on_model_changed.clone(),
 38            all_models: all_models.clone(),
 39            filtered_models: all_models,
 40            selected_index: 0,
 41        };
 42
 43        let picker = cx.new(|cx| {
 44            Picker::uniform_list(delegate, window, cx)
 45                .show_scrollbar(true)
 46                .max_height(Some(rems(20.).into()))
 47        });
 48
 49        LanguageModelSelector {
 50            picker,
 51            update_matches_task: None,
 52            _subscriptions: vec![cx.subscribe_in(
 53                &LanguageModelRegistry::global(cx),
 54                window,
 55                Self::handle_language_model_registry_event,
 56            )],
 57        }
 58    }
 59
 60    fn handle_language_model_registry_event(
 61        &mut self,
 62        _registry: &Entity<LanguageModelRegistry>,
 63        event: &language_model::Event,
 64        window: &mut Window,
 65        cx: &mut Context<Self>,
 66    ) {
 67        match event {
 68            language_model::Event::ProviderStateChanged
 69            | language_model::Event::AddedProvider(_)
 70            | language_model::Event::RemovedProvider(_) => {
 71                let task = self.picker.update(cx, |this, cx| {
 72                    let query = this.query(cx);
 73                    this.delegate.all_models = Self::all_models(cx);
 74                    this.delegate.update_matches(query, window, cx)
 75                });
 76                self.update_matches_task = Some(task);
 77            }
 78            _ => {}
 79        }
 80    }
 81
 82    fn all_models(cx: &App) -> Vec<ModelInfo> {
 83        LanguageModelRegistry::global(cx)
 84            .read(cx)
 85            .providers()
 86            .iter()
 87            .flat_map(|provider| {
 88                let icon = provider.icon();
 89
 90                provider.provided_models(cx).into_iter().map(move |model| {
 91                    let model = model.clone();
 92                    let icon = model.icon().unwrap_or(icon);
 93
 94                    ModelInfo {
 95                        model: model.clone(),
 96                        icon,
 97                        availability: model.availability(),
 98                    }
 99                })
100            })
101            .collect::<Vec<_>>()
102    }
103}
104
105impl EventEmitter<DismissEvent> for LanguageModelSelector {}
106
107impl Focusable for LanguageModelSelector {
108    fn focus_handle(&self, cx: &App) -> FocusHandle {
109        self.picker.focus_handle(cx)
110    }
111}
112
113impl Render for LanguageModelSelector {
114    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
115        self.picker.clone()
116    }
117}
118
/// Popover menu element that shows a [`LanguageModelSelector`] from a trigger button.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T, TT>
where
    T: PopoverTrigger + ButtonCommon,
    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
    /// Selector entity returned as the popover's menu content.
    language_model_selector: Entity<LanguageModelSelector>,
    /// Element used as the popover trigger.
    trigger: T,
    /// Builds the trigger's tooltip view.
    tooltip: TT,
    /// Optional handle attached to the popover via `with_handle`.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
}
130
131impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
132where
133    T: PopoverTrigger + ButtonCommon,
134    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
135{
136    pub fn new(
137        language_model_selector: Entity<LanguageModelSelector>,
138        trigger: T,
139        tooltip: TT,
140    ) -> Self {
141        Self {
142            language_model_selector,
143            trigger,
144            tooltip,
145            handle: None,
146        }
147    }
148
149    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
150        self.handle = Some(handle);
151        self
152    }
153}
154
155impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
156where
157    T: PopoverTrigger + ButtonCommon,
158    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
159{
160    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
161        let language_model_selector = self.language_model_selector.clone();
162
163        PopoverMenu::new("model-switcher")
164            .menu(move |_window, _cx| Some(language_model_selector.clone()))
165            .trigger_with_tooltip(self.trigger, self.tooltip)
166            .anchor(gpui::Corner::BottomRight)
167            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
168            .offset(gpui::Point {
169                x: px(0.0),
170                y: px(-2.0),
171            })
172    }
173}
174
/// One picker row: a model plus the display metadata resolved for it in `all_models`.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// Row icon: the model's own icon, falling back to its provider's icon.
    icon: IconName,
    /// Plan requirement; `render_match` uses it to decide whether to show a "Pro" badge.
    availability: LanguageModelAvailability,
}
181
/// [`PickerDelegate`] that backs the picker inside [`LanguageModelSelector`].
pub struct LanguageModelPickerDelegate {
    /// Owning selector; `dismissed` emits `DismissEvent` on it.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Called with the chosen model from `confirm`.
    on_model_changed: OnModelChanged,
    /// Every model of every registered provider; rebuilt on registry changes.
    all_models: Vec<ModelInfo>,
    /// The subset of `all_models` currently displayed, in display order.
    filtered_models: Vec<ModelInfo>,
    /// Index into `filtered_models` of the highlighted row.
    selected_index: usize,
}
189
190impl PickerDelegate for LanguageModelPickerDelegate {
191    type ListItem = ListItem;
192
193    fn match_count(&self) -> usize {
194        self.filtered_models.len()
195    }
196
197    fn selected_index(&self) -> usize {
198        self.selected_index
199    }
200
201    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
202        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
203        cx.notify();
204    }
205
206    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
207        "Select a model...".into()
208    }
209
210    fn update_matches(
211        &mut self,
212        query: String,
213        window: &mut Window,
214        cx: &mut Context<Picker<Self>>,
215    ) -> Task<()> {
216        let all_models = self.all_models.clone();
217        let current_index = self.selected_index;
218
219        let llm_registry = LanguageModelRegistry::global(cx);
220
221        let configured_providers = llm_registry
222            .read(cx)
223            .providers()
224            .iter()
225            .filter(|provider| provider.is_authenticated(cx))
226            .map(|provider| provider.id())
227            .collect::<Vec<_>>();
228
229        cx.spawn_in(window, |this, mut cx| async move {
230            let filtered_models = cx
231                .background_spawn(async move {
232                    let displayed_models = if configured_providers.is_empty() {
233                        all_models
234                    } else {
235                        all_models
236                            .into_iter()
237                            .filter(|model_info| {
238                                configured_providers.contains(&model_info.model.provider_id())
239                            })
240                            .collect::<Vec<_>>()
241                    };
242
243                    if query.is_empty() {
244                        displayed_models
245                    } else {
246                        displayed_models
247                            .into_iter()
248                            .filter(|model_info| {
249                                model_info
250                                    .model
251                                    .name()
252                                    .0
253                                    .to_lowercase()
254                                    .contains(&query.to_lowercase())
255                            })
256                            .collect()
257                    }
258                })
259                .await;
260
261            this.update_in(&mut cx, |this, window, cx| {
262                this.delegate.filtered_models = filtered_models;
263                // Preserve selection focus
264                let new_index = if current_index >= this.delegate.filtered_models.len() {
265                    0
266                } else {
267                    current_index
268                };
269                this.delegate.set_selected_index(new_index, window, cx);
270                cx.notify();
271            })
272            .ok();
273        })
274    }
275
276    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
277        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
278            let model = model_info.model.clone();
279            (self.on_model_changed)(model.clone(), cx);
280
281            let current_index = self.selected_index;
282            self.set_selected_index(current_index, window, cx);
283
284            cx.emit(DismissEvent);
285        }
286    }
287
288    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
289        self.language_model_selector
290            .update(cx, |_this, cx| cx.emit(DismissEvent))
291            .ok();
292    }
293
294    fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
295        let configured_models_count = LanguageModelRegistry::global(cx)
296            .read(cx)
297            .providers()
298            .iter()
299            .filter(|provider| provider.is_authenticated(cx))
300            .count();
301
302        if configured_models_count > 0 {
303            Some(
304                Label::new("Configured Models")
305                    .size(LabelSize::Small)
306                    .color(Color::Muted)
307                    .mt_1()
308                    .mb_0p5()
309                    .ml_2()
310                    .into_any_element(),
311            )
312        } else {
313            None
314        }
315    }
316
317    fn render_match(
318        &self,
319        ix: usize,
320        selected: bool,
321        _: &mut Window,
322        cx: &mut Context<Picker<Self>>,
323    ) -> Option<Self::ListItem> {
324        use feature_flags::FeatureFlagAppExt;
325        let show_badges = cx.has_flag::<ZedPro>();
326
327        let model_info = self.filtered_models.get(ix)?;
328        let provider_name: String = model_info.model.provider_name().0.clone().into();
329
330        let active_provider_id = LanguageModelRegistry::read_global(cx)
331            .active_provider()
332            .map(|m| m.id());
333
334        let active_model_id = LanguageModelRegistry::read_global(cx)
335            .active_model()
336            .map(|m| m.id());
337
338        let is_selected = Some(model_info.model.provider_id()) == active_provider_id
339            && Some(model_info.model.id()) == active_model_id;
340
341        let model_icon_color = if is_selected {
342            Color::Accent
343        } else {
344            Color::Muted
345        };
346
347        Some(
348            ListItem::new(ix)
349                .inset(true)
350                .spacing(ListItemSpacing::Sparse)
351                .toggle_state(selected)
352                .start_slot(
353                    Icon::new(model_info.icon)
354                        .color(model_icon_color)
355                        .size(IconSize::Small),
356                )
357                .child(
358                    h_flex()
359                        .w_full()
360                        .items_center()
361                        .gap_1p5()
362                        .pl_0p5()
363                        .min_w(px(240.))
364                        .child(
365                            div().max_w_40().child(
366                                Label::new(model_info.model.name().0.clone()).text_ellipsis(),
367                            ),
368                        )
369                        .child(
370                            h_flex()
371                                .gap_0p5()
372                                .child(
373                                    Label::new(provider_name)
374                                        .size(LabelSize::XSmall)
375                                        .color(Color::Muted),
376                                )
377                                .children(match model_info.availability {
378                                    LanguageModelAvailability::Public => None,
379                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
380                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
381                                        show_badges.then(|| {
382                                            Label::new("Pro")
383                                                .size(LabelSize::XSmall)
384                                                .color(Color::Muted)
385                                        })
386                                    }
387                                }),
388                        ),
389                )
390                .end_slot(div().when(is_selected, |this| {
391                    this.child(
392                        Icon::new(IconName::Check)
393                            .color(Color::Accent)
394                            .size(IconSize::Small),
395                    )
396                })),
397        )
398    }
399
400    fn render_footer(
401        &self,
402        _: &mut Window,
403        cx: &mut Context<Picker<Self>>,
404    ) -> Option<gpui::AnyElement> {
405        use feature_flags::FeatureFlagAppExt;
406
407        let plan = proto::Plan::ZedPro;
408        let is_trial = false;
409
410        Some(
411            h_flex()
412                .w_full()
413                .border_t_1()
414                .border_color(cx.theme().colors().border_variant)
415                .p_1()
416                .gap_4()
417                .justify_between()
418                .when(cx.has_flag::<ZedPro>(), |this| {
419                    this.child(match plan {
420                        // Already a Zed Pro subscriber
421                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
422                            .icon(IconName::ZedAssistant)
423                            .icon_size(IconSize::Small)
424                            .icon_color(Color::Muted)
425                            .icon_position(IconPosition::Start)
426                            .on_click(|_, window, cx| {
427                                window
428                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
429                            }),
430                        // Free user
431                        Plan::Free => Button::new(
432                            "try-pro",
433                            if is_trial {
434                                "Upgrade to Pro"
435                            } else {
436                                "Try Pro"
437                            },
438                        )
439                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
440                    })
441                })
442                .child(
443                    Button::new("configure", "Configure")
444                        .icon(IconName::Settings)
445                        .icon_size(IconSize::Small)
446                        .icon_color(Color::Muted)
447                        .icon_position(IconPosition::Start)
448                        .on_click(|_, window, cx| {
449                            window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
450                        }),
451                )
452                .into_any(),
453        )
454    }
455}