language_model_selector.rs

  1use std::sync::Arc;
  2
  3use feature_flags::ZedPro;
  4use gpui::{
  5    Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
  6    Focusable, Subscription, Task, WeakEntity,
  7};
  8use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
  9use picker::{Picker, PickerDelegate};
 10use proto::Plan;
 11use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 12use workspace::ShowConfiguration;
 13
 14const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 15
 16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
 17
 18pub struct LanguageModelSelector {
 19    picker: Entity<Picker<LanguageModelPickerDelegate>>,
 20    /// The task used to update the picker's matches when there is a change to
 21    /// the language model registry.
 22    update_matches_task: Option<Task<()>>,
 23    _subscriptions: Vec<Subscription>,
 24}
 25
 26impl LanguageModelSelector {
 27    pub fn new(
 28        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
 29        window: &mut Window,
 30        cx: &mut Context<Self>,
 31    ) -> Self {
 32        let on_model_changed = Arc::new(on_model_changed);
 33
 34        let all_models = Self::all_models(cx);
 35        let delegate = LanguageModelPickerDelegate {
 36            language_model_selector: cx.entity().downgrade(),
 37            on_model_changed: on_model_changed.clone(),
 38            all_models: all_models.clone(),
 39            filtered_models: all_models,
 40            selected_index: Self::get_active_model_index(cx),
 41        };
 42
 43        let picker = cx.new(|cx| {
 44            Picker::uniform_list(delegate, window, cx)
 45                .show_scrollbar(true)
 46                .width(rems(20.))
 47                .max_height(Some(rems(20.).into()))
 48        });
 49
 50        LanguageModelSelector {
 51            picker,
 52            update_matches_task: None,
 53            _subscriptions: vec![cx.subscribe_in(
 54                &LanguageModelRegistry::global(cx),
 55                window,
 56                Self::handle_language_model_registry_event,
 57            )],
 58        }
 59    }
 60
 61    fn handle_language_model_registry_event(
 62        &mut self,
 63        _registry: &Entity<LanguageModelRegistry>,
 64        event: &language_model::Event,
 65        window: &mut Window,
 66        cx: &mut Context<Self>,
 67    ) {
 68        match event {
 69            language_model::Event::ProviderStateChanged
 70            | language_model::Event::AddedProvider(_)
 71            | language_model::Event::RemovedProvider(_) => {
 72                let task = self.picker.update(cx, |this, cx| {
 73                    let query = this.query(cx);
 74                    this.delegate.all_models = Self::all_models(cx);
 75                    this.delegate.update_matches(query, window, cx)
 76                });
 77                self.update_matches_task = Some(task);
 78            }
 79            _ => {}
 80        }
 81    }
 82
 83    fn all_models(cx: &App) -> Vec<ModelInfo> {
 84        LanguageModelRegistry::global(cx)
 85            .read(cx)
 86            .providers()
 87            .iter()
 88            .flat_map(|provider| {
 89                let icon = provider.icon();
 90
 91                provider.provided_models(cx).into_iter().map(move |model| {
 92                    let model = model.clone();
 93                    let icon = model.icon().unwrap_or(icon);
 94
 95                    ModelInfo {
 96                        model: model.clone(),
 97                        icon,
 98                        availability: model.availability(),
 99                    }
100                })
101            })
102            .collect::<Vec<_>>()
103    }
104
105    fn get_active_model_index(cx: &App) -> usize {
106        let active_model = LanguageModelRegistry::read_global(cx).active_model();
107        Self::all_models(cx)
108            .iter()
109            .position(|model_info| {
110                Some(model_info.model.id()) == active_model.as_ref().map(|model| model.id())
111            })
112            .unwrap_or(0)
113    }
114}
115
116impl EventEmitter<DismissEvent> for LanguageModelSelector {}
117
118impl Focusable for LanguageModelSelector {
119    fn focus_handle(&self, cx: &App) -> FocusHandle {
120        self.picker.focus_handle(cx)
121    }
122}
123
124impl Render for LanguageModelSelector {
125    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
126        self.picker.clone()
127    }
128}
129
130#[derive(IntoElement)]
131pub struct LanguageModelSelectorPopoverMenu<T, TT>
132where
133    T: PopoverTrigger + ButtonCommon,
134    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
135{
136    language_model_selector: Entity<LanguageModelSelector>,
137    trigger: T,
138    tooltip: TT,
139    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
140    anchor: Corner,
141}
142
143impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
144where
145    T: PopoverTrigger + ButtonCommon,
146    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
147{
148    pub fn new(
149        language_model_selector: Entity<LanguageModelSelector>,
150        trigger: T,
151        tooltip: TT,
152        anchor: Corner,
153    ) -> Self {
154        Self {
155            language_model_selector,
156            trigger,
157            tooltip,
158            handle: None,
159            anchor,
160        }
161    }
162
163    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
164        self.handle = Some(handle);
165        self
166    }
167}
168
169impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
170where
171    T: PopoverTrigger + ButtonCommon,
172    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
173{
174    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
175        let language_model_selector = self.language_model_selector.clone();
176
177        PopoverMenu::new("model-switcher")
178            .menu(move |_window, _cx| Some(language_model_selector.clone()))
179            .trigger_with_tooltip(self.trigger, self.tooltip)
180            .anchor(self.anchor)
181            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
182            .offset(gpui::Point {
183                x: px(0.0),
184                y: px(-2.0),
185            })
186    }
187}
188
189#[derive(Clone)]
190struct ModelInfo {
191    model: Arc<dyn LanguageModel>,
192    icon: IconName,
193    availability: LanguageModelAvailability,
194}
195
196pub struct LanguageModelPickerDelegate {
197    language_model_selector: WeakEntity<LanguageModelSelector>,
198    on_model_changed: OnModelChanged,
199    all_models: Vec<ModelInfo>,
200    filtered_models: Vec<ModelInfo>,
201    selected_index: usize,
202}
203
204impl PickerDelegate for LanguageModelPickerDelegate {
205    type ListItem = ListItem;
206
207    fn match_count(&self) -> usize {
208        self.filtered_models.len()
209    }
210
211    fn selected_index(&self) -> usize {
212        self.selected_index
213    }
214
215    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
216        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
217        cx.notify();
218    }
219
220    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
221        "Select a model...".into()
222    }
223
224    fn update_matches(
225        &mut self,
226        query: String,
227        window: &mut Window,
228        cx: &mut Context<Picker<Self>>,
229    ) -> Task<()> {
230        let all_models = self.all_models.clone();
231        let current_index = self.selected_index;
232
233        let language_model_registry = LanguageModelRegistry::global(cx);
234
235        let configured_providers = language_model_registry
236            .read(cx)
237            .providers()
238            .iter()
239            .filter(|provider| provider.is_authenticated(cx))
240            .map(|provider| provider.id())
241            .collect::<Vec<_>>();
242
243        cx.spawn_in(window, |this, mut cx| async move {
244            let filtered_models = cx
245                .background_spawn(async move {
246                    let displayed_models = if configured_providers.is_empty() {
247                        all_models
248                    } else {
249                        all_models
250                            .into_iter()
251                            .filter(|model_info| {
252                                configured_providers.contains(&model_info.model.provider_id())
253                            })
254                            .collect::<Vec<_>>()
255                    };
256
257                    if query.is_empty() {
258                        displayed_models
259                    } else {
260                        displayed_models
261                            .into_iter()
262                            .filter(|model_info| {
263                                model_info
264                                    .model
265                                    .name()
266                                    .0
267                                    .to_lowercase()
268                                    .contains(&query.to_lowercase())
269                            })
270                            .collect()
271                    }
272                })
273                .await;
274
275            this.update_in(&mut cx, |this, window, cx| {
276                this.delegate.filtered_models = filtered_models;
277                // Preserve selection focus
278                let new_index = if current_index >= this.delegate.filtered_models.len() {
279                    0
280                } else {
281                    current_index
282                };
283                this.delegate.set_selected_index(new_index, window, cx);
284                cx.notify();
285            })
286            .ok();
287        })
288    }
289
290    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
291        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
292            let model = model_info.model.clone();
293            (self.on_model_changed)(model.clone(), cx);
294
295            let current_index = self.selected_index;
296            self.set_selected_index(current_index, window, cx);
297
298            cx.emit(DismissEvent);
299        }
300    }
301
302    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
303        self.language_model_selector
304            .update(cx, |_this, cx| cx.emit(DismissEvent))
305            .ok();
306    }
307
308    fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
309        let configured_models_count = LanguageModelRegistry::global(cx)
310            .read(cx)
311            .providers()
312            .iter()
313            .filter(|provider| provider.is_authenticated(cx))
314            .count();
315
316        if configured_models_count > 0 {
317            Some(
318                Label::new("Configured Models")
319                    .size(LabelSize::Small)
320                    .color(Color::Muted)
321                    .mt_1()
322                    .mb_0p5()
323                    .ml_2()
324                    .into_any_element(),
325            )
326        } else {
327            None
328        }
329    }
330
331    fn render_match(
332        &self,
333        ix: usize,
334        selected: bool,
335        _: &mut Window,
336        cx: &mut Context<Picker<Self>>,
337    ) -> Option<Self::ListItem> {
338        use feature_flags::FeatureFlagAppExt;
339        let show_badges = cx.has_flag::<ZedPro>();
340
341        let model_info = self.filtered_models.get(ix)?;
342        let provider_name: String = model_info.model.provider_name().0.clone().into();
343
344        let active_provider_id = LanguageModelRegistry::read_global(cx)
345            .active_provider()
346            .map(|m| m.id());
347
348        let active_model_id = LanguageModelRegistry::read_global(cx)
349            .active_model()
350            .map(|m| m.id());
351
352        let is_selected = Some(model_info.model.provider_id()) == active_provider_id
353            && Some(model_info.model.id()) == active_model_id;
354
355        let model_icon_color = if is_selected {
356            Color::Accent
357        } else {
358            Color::Muted
359        };
360
361        Some(
362            ListItem::new(ix)
363                .inset(true)
364                .spacing(ListItemSpacing::Sparse)
365                .toggle_state(selected)
366                .start_slot(
367                    Icon::new(model_info.icon)
368                        .color(model_icon_color)
369                        .size(IconSize::Small),
370                )
371                .child(
372                    h_flex()
373                        .w_full()
374                        .items_center()
375                        .gap_1p5()
376                        .pl_0p5()
377                        .w(px(240.))
378                        .child(
379                            div().max_w_40().child(
380                                Label::new(model_info.model.name().0.clone()).text_ellipsis(),
381                            ),
382                        )
383                        .child(
384                            h_flex()
385                                .gap_0p5()
386                                .child(
387                                    Label::new(provider_name)
388                                        .size(LabelSize::XSmall)
389                                        .color(Color::Muted),
390                                )
391                                .children(match model_info.availability {
392                                    LanguageModelAvailability::Public => None,
393                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
394                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
395                                        show_badges.then(|| {
396                                            Label::new("Pro")
397                                                .size(LabelSize::XSmall)
398                                                .color(Color::Muted)
399                                        })
400                                    }
401                                }),
402                        ),
403                )
404                .end_slot(div().pr_3().when(is_selected, |this| {
405                    this.child(
406                        Icon::new(IconName::Check)
407                            .color(Color::Accent)
408                            .size(IconSize::Small),
409                    )
410                })),
411        )
412    }
413
414    fn render_footer(
415        &self,
416        _: &mut Window,
417        cx: &mut Context<Picker<Self>>,
418    ) -> Option<gpui::AnyElement> {
419        use feature_flags::FeatureFlagAppExt;
420
421        let plan = proto::Plan::ZedPro;
422        let is_trial = false;
423
424        Some(
425            h_flex()
426                .w_full()
427                .border_t_1()
428                .border_color(cx.theme().colors().border_variant)
429                .p_1()
430                .gap_4()
431                .justify_between()
432                .when(cx.has_flag::<ZedPro>(), |this| {
433                    this.child(match plan {
434                        // Already a Zed Pro subscriber
435                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
436                            .icon(IconName::ZedAssistant)
437                            .icon_size(IconSize::Small)
438                            .icon_color(Color::Muted)
439                            .icon_position(IconPosition::Start)
440                            .on_click(|_, window, cx| {
441                                window
442                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
443                            }),
444                        // Free user
445                        Plan::Free => Button::new(
446                            "try-pro",
447                            if is_trial {
448                                "Upgrade to Pro"
449                            } else {
450                                "Try Pro"
451                            },
452                        )
453                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
454                    })
455                })
456                .child(
457                    Button::new("configure", "Configure")
458                        .icon(IconName::Settings)
459                        .icon_size(IconSize::Small)
460                        .icon_color(Color::Muted)
461                        .icon_position(IconPosition::Start)
462                        .on_click(|_, window, cx| {
463                            window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
464                        }),
465                )
466                .into_any(),
467        )
468    }
469}