language_model_selector.rs

  1use std::sync::Arc;
  2
  3use feature_flags::ZedPro;
  4use gpui::{
  5    Action, AnyElement, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable,
  6    Subscription, Task, WeakEntity,
  7};
  8use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
  9use picker::{Picker, PickerDelegate};
 10use proto::Plan;
 11use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 12use workspace::ShowConfiguration;
 13
 14const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 15
 16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
 17
 18pub struct LanguageModelSelector {
 19    picker: Entity<Picker<LanguageModelPickerDelegate>>,
 20    /// The task used to update the picker's matches when there is a change to
 21    /// the language model registry.
 22    update_matches_task: Option<Task<()>>,
 23    _subscriptions: Vec<Subscription>,
 24}
 25
 26impl LanguageModelSelector {
 27    pub fn new(
 28        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
 29        window: &mut Window,
 30        cx: &mut Context<Self>,
 31    ) -> Self {
 32        let on_model_changed = Arc::new(on_model_changed);
 33
 34        let all_models = Self::all_models(cx);
 35        let delegate = LanguageModelPickerDelegate {
 36            language_model_selector: cx.entity().downgrade(),
 37            on_model_changed: on_model_changed.clone(),
 38            all_models: all_models.clone(),
 39            filtered_models: all_models,
 40            selected_index: 0,
 41        };
 42
 43        let picker = cx.new(|cx| {
 44            Picker::uniform_list(delegate, window, cx).max_height(Some(rems(20.).into()))
 45        });
 46
 47        LanguageModelSelector {
 48            picker,
 49            update_matches_task: None,
 50            _subscriptions: vec![cx.subscribe_in(
 51                &LanguageModelRegistry::global(cx),
 52                window,
 53                Self::handle_language_model_registry_event,
 54            )],
 55        }
 56    }
 57
 58    fn handle_language_model_registry_event(
 59        &mut self,
 60        _registry: &Entity<LanguageModelRegistry>,
 61        event: &language_model::Event,
 62        window: &mut Window,
 63        cx: &mut Context<Self>,
 64    ) {
 65        match event {
 66            language_model::Event::ProviderStateChanged
 67            | language_model::Event::AddedProvider(_)
 68            | language_model::Event::RemovedProvider(_) => {
 69                let task = self.picker.update(cx, |this, cx| {
 70                    let query = this.query(cx);
 71                    this.delegate.all_models = Self::all_models(cx);
 72                    this.delegate.update_matches(query, window, cx)
 73                });
 74                self.update_matches_task = Some(task);
 75            }
 76            _ => {}
 77        }
 78    }
 79
 80    fn all_models(cx: &App) -> Vec<ModelInfo> {
 81        LanguageModelRegistry::global(cx)
 82            .read(cx)
 83            .providers()
 84            .iter()
 85            .flat_map(|provider| {
 86                let icon = provider.icon();
 87
 88                provider.provided_models(cx).into_iter().map(move |model| {
 89                    let model = model.clone();
 90                    let icon = model.icon().unwrap_or(icon);
 91
 92                    ModelInfo {
 93                        model: model.clone(),
 94                        icon,
 95                        availability: model.availability(),
 96                    }
 97                })
 98            })
 99            .collect::<Vec<_>>()
100    }
101}
102
103impl EventEmitter<DismissEvent> for LanguageModelSelector {}
104
105impl Focusable for LanguageModelSelector {
106    fn focus_handle(&self, cx: &App) -> FocusHandle {
107        self.picker.focus_handle(cx)
108    }
109}
110
111impl Render for LanguageModelSelector {
112    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
113        self.picker.clone()
114    }
115}
116
117#[derive(IntoElement)]
118pub struct LanguageModelSelectorPopoverMenu<T>
119where
120    T: PopoverTrigger,
121{
122    language_model_selector: Entity<LanguageModelSelector>,
123    trigger: T,
124    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
125}
126
127impl<T: PopoverTrigger> LanguageModelSelectorPopoverMenu<T> {
128    pub fn new(language_model_selector: Entity<LanguageModelSelector>, trigger: T) -> Self {
129        Self {
130            language_model_selector,
131            trigger,
132            handle: None,
133        }
134    }
135
136    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
137        self.handle = Some(handle);
138        self
139    }
140}
141
142impl<T: PopoverTrigger> RenderOnce for LanguageModelSelectorPopoverMenu<T> {
143    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
144        let language_model_selector = self.language_model_selector.clone();
145
146        PopoverMenu::new("model-switcher")
147            .menu(move |_window, _cx| Some(language_model_selector.clone()))
148            .trigger(self.trigger)
149            .anchor(gpui::Corner::BottomRight)
150            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
151            .offset(gpui::Point {
152                x: px(0.0),
153                y: px(-2.0),
154            })
155    }
156}
157
158#[derive(Clone)]
159struct ModelInfo {
160    model: Arc<dyn LanguageModel>,
161    icon: IconName,
162    availability: LanguageModelAvailability,
163}
164
165pub struct LanguageModelPickerDelegate {
166    language_model_selector: WeakEntity<LanguageModelSelector>,
167    on_model_changed: OnModelChanged,
168    all_models: Vec<ModelInfo>,
169    filtered_models: Vec<ModelInfo>,
170    selected_index: usize,
171}
172
173impl PickerDelegate for LanguageModelPickerDelegate {
174    type ListItem = ListItem;
175
176    fn match_count(&self) -> usize {
177        self.filtered_models.len()
178    }
179
180    fn selected_index(&self) -> usize {
181        self.selected_index
182    }
183
184    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
185        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
186        cx.notify();
187    }
188
189    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
190        "Select a model...".into()
191    }
192
193    fn update_matches(
194        &mut self,
195        query: String,
196        window: &mut Window,
197        cx: &mut Context<Picker<Self>>,
198    ) -> Task<()> {
199        let all_models = self.all_models.clone();
200        let current_index = self.selected_index;
201
202        let llm_registry = LanguageModelRegistry::global(cx);
203
204        let configured_providers = llm_registry
205            .read(cx)
206            .providers()
207            .iter()
208            .filter(|provider| provider.is_authenticated(cx))
209            .map(|provider| provider.id())
210            .collect::<Vec<_>>();
211
212        cx.spawn_in(window, |this, mut cx| async move {
213            let filtered_models = cx
214                .background_executor()
215                .spawn(async move {
216                    let displayed_models = if configured_providers.is_empty() {
217                        all_models
218                    } else {
219                        all_models
220                            .into_iter()
221                            .filter(|model_info| {
222                                configured_providers.contains(&model_info.model.provider_id())
223                            })
224                            .collect::<Vec<_>>()
225                    };
226
227                    if query.is_empty() {
228                        displayed_models
229                    } else {
230                        displayed_models
231                            .into_iter()
232                            .filter(|model_info| {
233                                model_info
234                                    .model
235                                    .name()
236                                    .0
237                                    .to_lowercase()
238                                    .contains(&query.to_lowercase())
239                            })
240                            .collect()
241                    }
242                })
243                .await;
244
245            this.update_in(&mut cx, |this, window, cx| {
246                this.delegate.filtered_models = filtered_models;
247                // Preserve selection focus
248                let new_index = if current_index >= this.delegate.filtered_models.len() {
249                    0
250                } else {
251                    current_index
252                };
253                this.delegate.set_selected_index(new_index, window, cx);
254                cx.notify();
255            })
256            .ok();
257        })
258    }
259
260    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
261        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
262            let model = model_info.model.clone();
263            (self.on_model_changed)(model.clone(), cx);
264
265            let current_index = self.selected_index;
266            self.set_selected_index(current_index, window, cx);
267
268            cx.emit(DismissEvent);
269        }
270    }
271
272    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
273        self.language_model_selector
274            .update(cx, |_this, cx| cx.emit(DismissEvent))
275            .ok();
276    }
277
278    fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
279        let configured_models_count = LanguageModelRegistry::global(cx)
280            .read(cx)
281            .providers()
282            .iter()
283            .filter(|provider| provider.is_authenticated(cx))
284            .count();
285
286        if configured_models_count > 0 {
287            Some(
288                Label::new("Configured Models")
289                    .size(LabelSize::Small)
290                    .color(Color::Muted)
291                    .mt_1()
292                    .mb_0p5()
293                    .ml_3()
294                    .into_any_element(),
295            )
296        } else {
297            None
298        }
299    }
300
301    fn render_match(
302        &self,
303        ix: usize,
304        selected: bool,
305        _: &mut Window,
306        cx: &mut Context<Picker<Self>>,
307    ) -> Option<Self::ListItem> {
308        use feature_flags::FeatureFlagAppExt;
309        let show_badges = cx.has_flag::<ZedPro>();
310
311        let model_info = self.filtered_models.get(ix)?;
312        let provider_name: String = model_info.model.provider_name().0.clone().into();
313
314        let active_provider_id = LanguageModelRegistry::read_global(cx)
315            .active_provider()
316            .map(|m| m.id());
317
318        let active_model_id = LanguageModelRegistry::read_global(cx)
319            .active_model()
320            .map(|m| m.id());
321
322        let is_selected = Some(model_info.model.provider_id()) == active_provider_id
323            && Some(model_info.model.id()) == active_model_id;
324
325        Some(
326            ListItem::new(ix)
327                .inset(true)
328                .spacing(ListItemSpacing::Sparse)
329                .toggle_state(selected)
330                .start_slot(
331                    div().pr_0p5().child(
332                        Icon::new(model_info.icon)
333                            .color(Color::Muted)
334                            .size(IconSize::Medium),
335                    ),
336                )
337                .child(
338                    h_flex()
339                        .w_full()
340                        .items_center()
341                        .gap_1p5()
342                        .min_w(px(200.))
343                        .child(Label::new(model_info.model.name().0.clone()))
344                        .child(
345                            h_flex()
346                                .gap_0p5()
347                                .child(
348                                    Label::new(provider_name)
349                                        .size(LabelSize::XSmall)
350                                        .color(Color::Muted),
351                                )
352                                .children(match model_info.availability {
353                                    LanguageModelAvailability::Public => None,
354                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
355                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
356                                        show_badges.then(|| {
357                                            Label::new("Pro")
358                                                .size(LabelSize::XSmall)
359                                                .color(Color::Muted)
360                                        })
361                                    }
362                                }),
363                        ),
364                )
365                .end_slot(div().when(is_selected, |this| {
366                    this.child(
367                        Icon::new(IconName::Check)
368                            .color(Color::Accent)
369                            .size(IconSize::Small),
370                    )
371                })),
372        )
373    }
374
375    fn render_footer(
376        &self,
377        _: &mut Window,
378        cx: &mut Context<Picker<Self>>,
379    ) -> Option<gpui::AnyElement> {
380        use feature_flags::FeatureFlagAppExt;
381
382        let plan = proto::Plan::ZedPro;
383        let is_trial = false;
384
385        Some(
386            h_flex()
387                .w_full()
388                .border_t_1()
389                .border_color(cx.theme().colors().border_variant)
390                .p_1()
391                .gap_4()
392                .justify_between()
393                .when(cx.has_flag::<ZedPro>(), |this| {
394                    this.child(match plan {
395                        // Already a Zed Pro subscriber
396                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
397                            .icon(IconName::ZedAssistant)
398                            .icon_size(IconSize::Small)
399                            .icon_color(Color::Muted)
400                            .icon_position(IconPosition::Start)
401                            .on_click(|_, window, cx| {
402                                window
403                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
404                            }),
405                        // Free user
406                        Plan::Free => Button::new(
407                            "try-pro",
408                            if is_trial {
409                                "Upgrade to Pro"
410                            } else {
411                                "Try Pro"
412                            },
413                        )
414                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
415                    })
416                })
417                .child(
418                    Button::new("configure", "Configure")
419                        .icon(IconName::Settings)
420                        .icon_size(IconSize::Small)
421                        .icon_color(Color::Muted)
422                        .icon_position(IconPosition::Start)
423                        .on_click(|_, window, cx| {
424                            window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
425                        }),
426                )
427                .into_any(),
428        )
429    }
430}