language_model_selector.rs

  1use std::sync::Arc;
  2
  3use feature_flags::ZedPro;
  4use gpui::{
  5    Action, AnyElement, AnyView, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable,
  6    Subscription, Task, WeakEntity,
  7};
  8use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
  9use picker::{Picker, PickerDelegate};
 10use proto::Plan;
 11use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 12use workspace::ShowConfiguration;
 13
 14const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 15
 16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
 17
 18pub struct LanguageModelSelector {
 19    picker: Entity<Picker<LanguageModelPickerDelegate>>,
 20    /// The task used to update the picker's matches when there is a change to
 21    /// the language model registry.
 22    update_matches_task: Option<Task<()>>,
 23    _subscriptions: Vec<Subscription>,
 24}
 25
 26impl LanguageModelSelector {
 27    pub fn new(
 28        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
 29        window: &mut Window,
 30        cx: &mut Context<Self>,
 31    ) -> Self {
 32        let on_model_changed = Arc::new(on_model_changed);
 33
 34        let all_models = Self::all_models(cx);
 35        let delegate = LanguageModelPickerDelegate {
 36            language_model_selector: cx.entity().downgrade(),
 37            on_model_changed: on_model_changed.clone(),
 38            all_models: all_models.clone(),
 39            filtered_models: all_models,
 40            selected_index: 0,
 41        };
 42
 43        let picker = cx.new(|cx| {
 44            Picker::uniform_list(delegate, window, cx).max_height(Some(rems(20.).into()))
 45        });
 46
 47        LanguageModelSelector {
 48            picker,
 49            update_matches_task: None,
 50            _subscriptions: vec![cx.subscribe_in(
 51                &LanguageModelRegistry::global(cx),
 52                window,
 53                Self::handle_language_model_registry_event,
 54            )],
 55        }
 56    }
 57
 58    fn handle_language_model_registry_event(
 59        &mut self,
 60        _registry: &Entity<LanguageModelRegistry>,
 61        event: &language_model::Event,
 62        window: &mut Window,
 63        cx: &mut Context<Self>,
 64    ) {
 65        match event {
 66            language_model::Event::ProviderStateChanged
 67            | language_model::Event::AddedProvider(_)
 68            | language_model::Event::RemovedProvider(_) => {
 69                let task = self.picker.update(cx, |this, cx| {
 70                    let query = this.query(cx);
 71                    this.delegate.all_models = Self::all_models(cx);
 72                    this.delegate.update_matches(query, window, cx)
 73                });
 74                self.update_matches_task = Some(task);
 75            }
 76            _ => {}
 77        }
 78    }
 79
 80    fn all_models(cx: &App) -> Vec<ModelInfo> {
 81        LanguageModelRegistry::global(cx)
 82            .read(cx)
 83            .providers()
 84            .iter()
 85            .flat_map(|provider| {
 86                let icon = provider.icon();
 87
 88                provider.provided_models(cx).into_iter().map(move |model| {
 89                    let model = model.clone();
 90                    let icon = model.icon().unwrap_or(icon);
 91
 92                    ModelInfo {
 93                        model: model.clone(),
 94                        icon,
 95                        availability: model.availability(),
 96                    }
 97                })
 98            })
 99            .collect::<Vec<_>>()
100    }
101}
102
103impl EventEmitter<DismissEvent> for LanguageModelSelector {}
104
105impl Focusable for LanguageModelSelector {
106    fn focus_handle(&self, cx: &App) -> FocusHandle {
107        self.picker.focus_handle(cx)
108    }
109}
110
111impl Render for LanguageModelSelector {
112    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
113        self.picker.clone()
114    }
115}
116
117#[derive(IntoElement)]
118pub struct LanguageModelSelectorPopoverMenu<T, TT>
119where
120    T: PopoverTrigger + ButtonCommon,
121    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
122{
123    language_model_selector: Entity<LanguageModelSelector>,
124    trigger: T,
125    tooltip: TT,
126    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
127}
128
129impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
130where
131    T: PopoverTrigger + ButtonCommon,
132    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
133{
134    pub fn new(
135        language_model_selector: Entity<LanguageModelSelector>,
136        trigger: T,
137        tooltip: TT,
138    ) -> Self {
139        Self {
140            language_model_selector,
141            trigger,
142            tooltip,
143            handle: None,
144        }
145    }
146
147    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
148        self.handle = Some(handle);
149        self
150    }
151}
152
153impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
154where
155    T: PopoverTrigger + ButtonCommon,
156    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
157{
158    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
159        let language_model_selector = self.language_model_selector.clone();
160
161        PopoverMenu::new("model-switcher")
162            .menu(move |_window, _cx| Some(language_model_selector.clone()))
163            .trigger_with_tooltip(self.trigger, self.tooltip)
164            .anchor(gpui::Corner::BottomRight)
165            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
166            .offset(gpui::Point {
167                x: px(0.0),
168                y: px(-2.0),
169            })
170    }
171}
172
173#[derive(Clone)]
174struct ModelInfo {
175    model: Arc<dyn LanguageModel>,
176    icon: IconName,
177    availability: LanguageModelAvailability,
178}
179
180pub struct LanguageModelPickerDelegate {
181    language_model_selector: WeakEntity<LanguageModelSelector>,
182    on_model_changed: OnModelChanged,
183    all_models: Vec<ModelInfo>,
184    filtered_models: Vec<ModelInfo>,
185    selected_index: usize,
186}
187
188impl PickerDelegate for LanguageModelPickerDelegate {
189    type ListItem = ListItem;
190
191    fn match_count(&self) -> usize {
192        self.filtered_models.len()
193    }
194
195    fn selected_index(&self) -> usize {
196        self.selected_index
197    }
198
199    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
200        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
201        cx.notify();
202    }
203
204    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
205        "Select a model...".into()
206    }
207
208    fn update_matches(
209        &mut self,
210        query: String,
211        window: &mut Window,
212        cx: &mut Context<Picker<Self>>,
213    ) -> Task<()> {
214        let all_models = self.all_models.clone();
215        let current_index = self.selected_index;
216
217        let llm_registry = LanguageModelRegistry::global(cx);
218
219        let configured_providers = llm_registry
220            .read(cx)
221            .providers()
222            .iter()
223            .filter(|provider| provider.is_authenticated(cx))
224            .map(|provider| provider.id())
225            .collect::<Vec<_>>();
226
227        cx.spawn_in(window, |this, mut cx| async move {
228            let filtered_models = cx
229                .background_executor()
230                .spawn(async move {
231                    let displayed_models = if configured_providers.is_empty() {
232                        all_models
233                    } else {
234                        all_models
235                            .into_iter()
236                            .filter(|model_info| {
237                                configured_providers.contains(&model_info.model.provider_id())
238                            })
239                            .collect::<Vec<_>>()
240                    };
241
242                    if query.is_empty() {
243                        displayed_models
244                    } else {
245                        displayed_models
246                            .into_iter()
247                            .filter(|model_info| {
248                                model_info
249                                    .model
250                                    .name()
251                                    .0
252                                    .to_lowercase()
253                                    .contains(&query.to_lowercase())
254                            })
255                            .collect()
256                    }
257                })
258                .await;
259
260            this.update_in(&mut cx, |this, window, cx| {
261                this.delegate.filtered_models = filtered_models;
262                // Preserve selection focus
263                let new_index = if current_index >= this.delegate.filtered_models.len() {
264                    0
265                } else {
266                    current_index
267                };
268                this.delegate.set_selected_index(new_index, window, cx);
269                cx.notify();
270            })
271            .ok();
272        })
273    }
274
275    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
276        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
277            let model = model_info.model.clone();
278            (self.on_model_changed)(model.clone(), cx);
279
280            let current_index = self.selected_index;
281            self.set_selected_index(current_index, window, cx);
282
283            cx.emit(DismissEvent);
284        }
285    }
286
287    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
288        self.language_model_selector
289            .update(cx, |_this, cx| cx.emit(DismissEvent))
290            .ok();
291    }
292
293    fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
294        let configured_models_count = LanguageModelRegistry::global(cx)
295            .read(cx)
296            .providers()
297            .iter()
298            .filter(|provider| provider.is_authenticated(cx))
299            .count();
300
301        if configured_models_count > 0 {
302            Some(
303                Label::new("Configured Models")
304                    .size(LabelSize::Small)
305                    .color(Color::Muted)
306                    .mt_1()
307                    .mb_0p5()
308                    .ml_3()
309                    .into_any_element(),
310            )
311        } else {
312            None
313        }
314    }
315
316    fn render_match(
317        &self,
318        ix: usize,
319        selected: bool,
320        _: &mut Window,
321        cx: &mut Context<Picker<Self>>,
322    ) -> Option<Self::ListItem> {
323        use feature_flags::FeatureFlagAppExt;
324        let show_badges = cx.has_flag::<ZedPro>();
325
326        let model_info = self.filtered_models.get(ix)?;
327        let provider_name: String = model_info.model.provider_name().0.clone().into();
328
329        let active_provider_id = LanguageModelRegistry::read_global(cx)
330            .active_provider()
331            .map(|m| m.id());
332
333        let active_model_id = LanguageModelRegistry::read_global(cx)
334            .active_model()
335            .map(|m| m.id());
336
337        let is_selected = Some(model_info.model.provider_id()) == active_provider_id
338            && Some(model_info.model.id()) == active_model_id;
339
340        Some(
341            ListItem::new(ix)
342                .inset(true)
343                .spacing(ListItemSpacing::Sparse)
344                .toggle_state(selected)
345                .start_slot(
346                    div().pr_0p5().child(
347                        Icon::new(model_info.icon)
348                            .color(Color::Muted)
349                            .size(IconSize::Medium),
350                    ),
351                )
352                .child(
353                    h_flex()
354                        .w_full()
355                        .items_center()
356                        .gap_1p5()
357                        .min_w(px(200.))
358                        .child(Label::new(model_info.model.name().0.clone()))
359                        .child(
360                            h_flex()
361                                .gap_0p5()
362                                .child(
363                                    Label::new(provider_name)
364                                        .size(LabelSize::XSmall)
365                                        .color(Color::Muted),
366                                )
367                                .children(match model_info.availability {
368                                    LanguageModelAvailability::Public => None,
369                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
370                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
371                                        show_badges.then(|| {
372                                            Label::new("Pro")
373                                                .size(LabelSize::XSmall)
374                                                .color(Color::Muted)
375                                        })
376                                    }
377                                }),
378                        ),
379                )
380                .end_slot(div().when(is_selected, |this| {
381                    this.child(
382                        Icon::new(IconName::Check)
383                            .color(Color::Accent)
384                            .size(IconSize::Small),
385                    )
386                })),
387        )
388    }
389
390    fn render_footer(
391        &self,
392        _: &mut Window,
393        cx: &mut Context<Picker<Self>>,
394    ) -> Option<gpui::AnyElement> {
395        use feature_flags::FeatureFlagAppExt;
396
397        let plan = proto::Plan::ZedPro;
398        let is_trial = false;
399
400        Some(
401            h_flex()
402                .w_full()
403                .border_t_1()
404                .border_color(cx.theme().colors().border_variant)
405                .p_1()
406                .gap_4()
407                .justify_between()
408                .when(cx.has_flag::<ZedPro>(), |this| {
409                    this.child(match plan {
410                        // Already a Zed Pro subscriber
411                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
412                            .icon(IconName::ZedAssistant)
413                            .icon_size(IconSize::Small)
414                            .icon_color(Color::Muted)
415                            .icon_position(IconPosition::Start)
416                            .on_click(|_, window, cx| {
417                                window
418                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
419                            }),
420                        // Free user
421                        Plan::Free => Button::new(
422                            "try-pro",
423                            if is_trial {
424                                "Upgrade to Pro"
425                            } else {
426                                "Try Pro"
427                            },
428                        )
429                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
430                    })
431                })
432                .child(
433                    Button::new("configure", "Configure")
434                        .icon(IconName::Settings)
435                        .icon_size(IconSize::Small)
436                        .icon_color(Color::Muted)
437                        .icon_position(IconPosition::Start)
438                        .on_click(|_, window, cx| {
439                            window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
440                        }),
441                )
442                .into_any(),
443        )
444    }
445}