language_model_selector.rs

  1use std::sync::Arc;
  2
  3use feature_flags::ZedPro;
  4use gpui::{
  5    Action, AnyElement, AnyView, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable,
  6    Subscription, Task, WeakEntity,
  7};
  8use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
  9use picker::{Picker, PickerDelegate};
 10use proto::Plan;
 11use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 12use workspace::ShowConfiguration;
 13
 14const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 15
 16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
 17
 18pub struct LanguageModelSelector {
 19    picker: Entity<Picker<LanguageModelPickerDelegate>>,
 20    /// The task used to update the picker's matches when there is a change to
 21    /// the language model registry.
 22    update_matches_task: Option<Task<()>>,
 23    _subscriptions: Vec<Subscription>,
 24}
 25
 26impl LanguageModelSelector {
 27    pub fn new(
 28        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
 29        window: &mut Window,
 30        cx: &mut Context<Self>,
 31    ) -> Self {
 32        let on_model_changed = Arc::new(on_model_changed);
 33
 34        let all_models = Self::all_models(cx);
 35        let delegate = LanguageModelPickerDelegate {
 36            language_model_selector: cx.entity().downgrade(),
 37            on_model_changed: on_model_changed.clone(),
 38            all_models: all_models.clone(),
 39            filtered_models: all_models,
 40            selected_index: 0,
 41        };
 42
 43        let picker = cx.new(|cx| {
 44            Picker::uniform_list(delegate, window, cx)
 45                .show_scrollbar(true)
 46                .max_height(Some(rems(20.).into()))
 47        });
 48
 49        LanguageModelSelector {
 50            picker,
 51            update_matches_task: None,
 52            _subscriptions: vec![cx.subscribe_in(
 53                &LanguageModelRegistry::global(cx),
 54                window,
 55                Self::handle_language_model_registry_event,
 56            )],
 57        }
 58    }
 59
 60    fn handle_language_model_registry_event(
 61        &mut self,
 62        _registry: &Entity<LanguageModelRegistry>,
 63        event: &language_model::Event,
 64        window: &mut Window,
 65        cx: &mut Context<Self>,
 66    ) {
 67        match event {
 68            language_model::Event::ProviderStateChanged
 69            | language_model::Event::AddedProvider(_)
 70            | language_model::Event::RemovedProvider(_) => {
 71                let task = self.picker.update(cx, |this, cx| {
 72                    let query = this.query(cx);
 73                    this.delegate.all_models = Self::all_models(cx);
 74                    this.delegate.update_matches(query, window, cx)
 75                });
 76                self.update_matches_task = Some(task);
 77            }
 78            _ => {}
 79        }
 80    }
 81
 82    fn all_models(cx: &App) -> Vec<ModelInfo> {
 83        LanguageModelRegistry::global(cx)
 84            .read(cx)
 85            .providers()
 86            .iter()
 87            .flat_map(|provider| {
 88                let icon = provider.icon();
 89
 90                provider.provided_models(cx).into_iter().map(move |model| {
 91                    let model = model.clone();
 92                    let icon = model.icon().unwrap_or(icon);
 93
 94                    ModelInfo {
 95                        model: model.clone(),
 96                        icon,
 97                        availability: model.availability(),
 98                    }
 99                })
100            })
101            .collect::<Vec<_>>()
102    }
103}
104
105impl EventEmitter<DismissEvent> for LanguageModelSelector {}
106
107impl Focusable for LanguageModelSelector {
108    fn focus_handle(&self, cx: &App) -> FocusHandle {
109        self.picker.focus_handle(cx)
110    }
111}
112
113impl Render for LanguageModelSelector {
114    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
115        self.picker.clone()
116    }
117}
118
119#[derive(IntoElement)]
120pub struct LanguageModelSelectorPopoverMenu<T, TT>
121where
122    T: PopoverTrigger + ButtonCommon,
123    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
124{
125    language_model_selector: Entity<LanguageModelSelector>,
126    trigger: T,
127    tooltip: TT,
128    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
129}
130
131impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
132where
133    T: PopoverTrigger + ButtonCommon,
134    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
135{
136    pub fn new(
137        language_model_selector: Entity<LanguageModelSelector>,
138        trigger: T,
139        tooltip: TT,
140    ) -> Self {
141        Self {
142            language_model_selector,
143            trigger,
144            tooltip,
145            handle: None,
146        }
147    }
148
149    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
150        self.handle = Some(handle);
151        self
152    }
153}
154
155impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
156where
157    T: PopoverTrigger + ButtonCommon,
158    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
159{
160    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
161        let language_model_selector = self.language_model_selector.clone();
162
163        PopoverMenu::new("model-switcher")
164            .menu(move |_window, _cx| Some(language_model_selector.clone()))
165            .trigger_with_tooltip(self.trigger, self.tooltip)
166            .anchor(gpui::Corner::BottomRight)
167            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
168            .offset(gpui::Point {
169                x: px(0.0),
170                y: px(-2.0),
171            })
172    }
173}
174
175#[derive(Clone)]
176struct ModelInfo {
177    model: Arc<dyn LanguageModel>,
178    icon: IconName,
179    availability: LanguageModelAvailability,
180}
181
182pub struct LanguageModelPickerDelegate {
183    language_model_selector: WeakEntity<LanguageModelSelector>,
184    on_model_changed: OnModelChanged,
185    all_models: Vec<ModelInfo>,
186    filtered_models: Vec<ModelInfo>,
187    selected_index: usize,
188}
189
190impl PickerDelegate for LanguageModelPickerDelegate {
191    type ListItem = ListItem;
192
193    fn match_count(&self) -> usize {
194        self.filtered_models.len()
195    }
196
197    fn selected_index(&self) -> usize {
198        self.selected_index
199    }
200
201    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
202        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
203        cx.notify();
204    }
205
206    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
207        "Select a model...".into()
208    }
209
210    fn update_matches(
211        &mut self,
212        query: String,
213        window: &mut Window,
214        cx: &mut Context<Picker<Self>>,
215    ) -> Task<()> {
216        let all_models = self.all_models.clone();
217        let current_index = self.selected_index;
218
219        let llm_registry = LanguageModelRegistry::global(cx);
220
221        let configured_providers = llm_registry
222            .read(cx)
223            .providers()
224            .iter()
225            .filter(|provider| provider.is_authenticated(cx))
226            .map(|provider| provider.id())
227            .collect::<Vec<_>>();
228
229        cx.spawn_in(window, |this, mut cx| async move {
230            let filtered_models = cx
231                .background_executor()
232                .spawn(async move {
233                    let displayed_models = if configured_providers.is_empty() {
234                        all_models
235                    } else {
236                        all_models
237                            .into_iter()
238                            .filter(|model_info| {
239                                configured_providers.contains(&model_info.model.provider_id())
240                            })
241                            .collect::<Vec<_>>()
242                    };
243
244                    if query.is_empty() {
245                        displayed_models
246                    } else {
247                        displayed_models
248                            .into_iter()
249                            .filter(|model_info| {
250                                model_info
251                                    .model
252                                    .name()
253                                    .0
254                                    .to_lowercase()
255                                    .contains(&query.to_lowercase())
256                            })
257                            .collect()
258                    }
259                })
260                .await;
261
262            this.update_in(&mut cx, |this, window, cx| {
263                this.delegate.filtered_models = filtered_models;
264                // Preserve selection focus
265                let new_index = if current_index >= this.delegate.filtered_models.len() {
266                    0
267                } else {
268                    current_index
269                };
270                this.delegate.set_selected_index(new_index, window, cx);
271                cx.notify();
272            })
273            .ok();
274        })
275    }
276
277    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
278        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
279            let model = model_info.model.clone();
280            (self.on_model_changed)(model.clone(), cx);
281
282            let current_index = self.selected_index;
283            self.set_selected_index(current_index, window, cx);
284
285            cx.emit(DismissEvent);
286        }
287    }
288
289    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
290        self.language_model_selector
291            .update(cx, |_this, cx| cx.emit(DismissEvent))
292            .ok();
293    }
294
295    fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
296        let configured_models_count = LanguageModelRegistry::global(cx)
297            .read(cx)
298            .providers()
299            .iter()
300            .filter(|provider| provider.is_authenticated(cx))
301            .count();
302
303        if configured_models_count > 0 {
304            Some(
305                Label::new("Configured Models")
306                    .size(LabelSize::Small)
307                    .color(Color::Muted)
308                    .mt_1()
309                    .mb_0p5()
310                    .ml_2()
311                    .into_any_element(),
312            )
313        } else {
314            None
315        }
316    }
317
318    fn render_match(
319        &self,
320        ix: usize,
321        selected: bool,
322        _: &mut Window,
323        cx: &mut Context<Picker<Self>>,
324    ) -> Option<Self::ListItem> {
325        use feature_flags::FeatureFlagAppExt;
326        let show_badges = cx.has_flag::<ZedPro>();
327
328        let model_info = self.filtered_models.get(ix)?;
329        let provider_name: String = model_info.model.provider_name().0.clone().into();
330
331        let active_provider_id = LanguageModelRegistry::read_global(cx)
332            .active_provider()
333            .map(|m| m.id());
334
335        let active_model_id = LanguageModelRegistry::read_global(cx)
336            .active_model()
337            .map(|m| m.id());
338
339        let is_selected = Some(model_info.model.provider_id()) == active_provider_id
340            && Some(model_info.model.id()) == active_model_id;
341
342        let model_icon_color = if is_selected {
343            Color::Accent
344        } else {
345            Color::Muted
346        };
347
348        Some(
349            ListItem::new(ix)
350                .inset(true)
351                .spacing(ListItemSpacing::Sparse)
352                .toggle_state(selected)
353                .start_slot(
354                    Icon::new(model_info.icon)
355                        .color(model_icon_color)
356                        .size(IconSize::Small),
357                )
358                .child(
359                    h_flex()
360                        .w_full()
361                        .items_center()
362                        .gap_1p5()
363                        .pl_0p5()
364                        .min_w(px(240.))
365                        .child(
366                            div().max_w_40().child(
367                                Label::new(model_info.model.name().0.clone()).text_ellipsis(),
368                            ),
369                        )
370                        .child(
371                            h_flex()
372                                .gap_0p5()
373                                .child(
374                                    Label::new(provider_name)
375                                        .size(LabelSize::XSmall)
376                                        .color(Color::Muted),
377                                )
378                                .children(match model_info.availability {
379                                    LanguageModelAvailability::Public => None,
380                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
381                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
382                                        show_badges.then(|| {
383                                            Label::new("Pro")
384                                                .size(LabelSize::XSmall)
385                                                .color(Color::Muted)
386                                        })
387                                    }
388                                }),
389                        ),
390                )
391                .end_slot(div().when(is_selected, |this| {
392                    this.child(
393                        Icon::new(IconName::Check)
394                            .color(Color::Accent)
395                            .size(IconSize::Small),
396                    )
397                })),
398        )
399    }
400
401    fn render_footer(
402        &self,
403        _: &mut Window,
404        cx: &mut Context<Picker<Self>>,
405    ) -> Option<gpui::AnyElement> {
406        use feature_flags::FeatureFlagAppExt;
407
408        let plan = proto::Plan::ZedPro;
409        let is_trial = false;
410
411        Some(
412            h_flex()
413                .w_full()
414                .border_t_1()
415                .border_color(cx.theme().colors().border_variant)
416                .p_1()
417                .gap_4()
418                .justify_between()
419                .when(cx.has_flag::<ZedPro>(), |this| {
420                    this.child(match plan {
421                        // Already a Zed Pro subscriber
422                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
423                            .icon(IconName::ZedAssistant)
424                            .icon_size(IconSize::Small)
425                            .icon_color(Color::Muted)
426                            .icon_position(IconPosition::Start)
427                            .on_click(|_, window, cx| {
428                                window
429                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
430                            }),
431                        // Free user
432                        Plan::Free => Button::new(
433                            "try-pro",
434                            if is_trial {
435                                "Upgrade to Pro"
436                            } else {
437                                "Try Pro"
438                            },
439                        )
440                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
441                    })
442                })
443                .child(
444                    Button::new("configure", "Configure")
445                        .icon(IconName::Settings)
446                        .icon_size(IconSize::Small)
447                        .icon_color(Color::Muted)
448                        .icon_position(IconPosition::Start)
449                        .on_click(|_, window, cx| {
450                            window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
451                        }),
452                )
453                .into_any(),
454        )
455    }
456}