language_model_selector.rs

  1use std::sync::Arc;
  2
  3use feature_flags::ZedPro;
  4use gpui::{
  5    Action, AnyElement, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable,
  6    Subscription, Task, WeakEntity,
  7};
  8use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
  9use picker::{Picker, PickerDelegate};
 10use proto::Plan;
 11use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 12use workspace::ShowConfiguration;
 13
/// Page opened by the footer's `Plan::Free` button ("Try Pro" / "Upgrade to Pro").
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Callback invoked with the model the user confirms in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
 17
/// A picker listing the models offered by the providers in the global
/// `LanguageModelRegistry`, reporting the user's choice through a callback.
pub struct LanguageModelSelector {
    /// The picker entity that renders the model list and handles filtering and selection.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// The task used to update the picker's matches when there is a change to
    /// the language model registry.
    // NOTE(review): presumably stored so the task is not dropped (and cancelled) early — confirm.
    update_matches_task: Option<Task<()>>,
    /// Holds the registry subscription created in `new`; presumably dropping a
    /// `Subscription` cancels it, hence it is kept for the selector's lifetime.
    _subscriptions: Vec<Subscription>,
}
 25
 26impl LanguageModelSelector {
 27    pub fn new(
 28        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
 29        window: &mut Window,
 30        cx: &mut Context<Self>,
 31    ) -> Self {
 32        let on_model_changed = Arc::new(on_model_changed);
 33
 34        let all_models = Self::all_models(cx);
 35        let delegate = LanguageModelPickerDelegate {
 36            language_model_selector: cx.entity().downgrade(),
 37            on_model_changed: on_model_changed.clone(),
 38            all_models: all_models.clone(),
 39            filtered_models: all_models,
 40            selected_index: 0,
 41        };
 42
 43        let picker = cx.new(|cx| {
 44            Picker::uniform_list(delegate, window, cx).max_height(Some(rems(20.).into()))
 45        });
 46
 47        LanguageModelSelector {
 48            picker,
 49            update_matches_task: None,
 50            _subscriptions: vec![cx.subscribe_in(
 51                &LanguageModelRegistry::global(cx),
 52                window,
 53                Self::handle_language_model_registry_event,
 54            )],
 55        }
 56    }
 57
 58    fn handle_language_model_registry_event(
 59        &mut self,
 60        _registry: &Entity<LanguageModelRegistry>,
 61        event: &language_model::Event,
 62        window: &mut Window,
 63        cx: &mut Context<Self>,
 64    ) {
 65        match event {
 66            language_model::Event::ProviderStateChanged
 67            | language_model::Event::AddedProvider(_)
 68            | language_model::Event::RemovedProvider(_) => {
 69                let task = self.picker.update(cx, |this, cx| {
 70                    let query = this.query(cx);
 71                    this.delegate.all_models = Self::all_models(cx);
 72                    this.delegate.update_matches(query, window, cx)
 73                });
 74                self.update_matches_task = Some(task);
 75            }
 76            _ => {}
 77        }
 78    }
 79
 80    fn all_models(cx: &App) -> Vec<ModelInfo> {
 81        LanguageModelRegistry::global(cx)
 82            .read(cx)
 83            .providers()
 84            .iter()
 85            .flat_map(|provider| {
 86                let icon = provider.icon();
 87
 88                provider.provided_models(cx).into_iter().map(move |model| {
 89                    let model = model.clone();
 90                    let icon = model.icon().unwrap_or(icon);
 91
 92                    ModelInfo {
 93                        model: model.clone(),
 94                        icon,
 95                        availability: model.availability(),
 96                    }
 97                })
 98            })
 99            .collect::<Vec<_>>()
100    }
101}
102
// The selector emits `DismissEvent` when its picker is dismissed
// (see `LanguageModelPickerDelegate::dismissed`).
impl EventEmitter<DismissEvent> for LanguageModelSelector {}
104
105impl Focusable for LanguageModelSelector {
106    fn focus_handle(&self, cx: &App) -> FocusHandle {
107        self.picker.focus_handle(cx)
108    }
109}
110
111impl Render for LanguageModelSelector {
112    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
113        self.picker.clone()
114    }
115}
116
/// A `PopoverMenu` whose content is a shared [`LanguageModelSelector`] and whose
/// trigger element is `trigger`.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T>
where
    T: PopoverTrigger,
{
    /// The selector entity returned as the popover's menu content.
    language_model_selector: Entity<LanguageModelSelector>,
    /// The element passed to `PopoverMenu::trigger`.
    trigger: T,
    /// Optional handle forwarded to the underlying `PopoverMenu`; set via `with_handle`.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
}
126
127impl<T: PopoverTrigger> LanguageModelSelectorPopoverMenu<T> {
128    pub fn new(language_model_selector: Entity<LanguageModelSelector>, trigger: T) -> Self {
129        Self {
130            language_model_selector,
131            trigger,
132            handle: None,
133        }
134    }
135
136    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
137        self.handle = Some(handle);
138        self
139    }
140}
141
142impl<T: PopoverTrigger> RenderOnce for LanguageModelSelectorPopoverMenu<T> {
143    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
144        let language_model_selector = self.language_model_selector.clone();
145
146        PopoverMenu::new("model-switcher")
147            .menu(move |_window, _cx| Some(language_model_selector.clone()))
148            .trigger(self.trigger)
149            .anchor(gpui::Corner::BottomRight)
150            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
151            .offset(gpui::Point {
152                x: px(0.0),
153                y: px(-2.0),
154            })
155    }
156}
157
/// A model offered by a registered provider, together with the display data the picker needs.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself; a confirmed row passes this to the `on_model_changed` callback.
    model: Arc<dyn LanguageModel>,
    /// Icon shown in the row: the model's own icon if it has one, otherwise its provider's.
    icon: IconName,
    /// Plan requirement; `RequiresPlan(Plan::ZedPro)` rows get a "Pro" badge when the
    /// `ZedPro` feature flag is enabled.
    availability: LanguageModelAvailability,
}
164
/// `PickerDelegate` backing [`LanguageModelSelector`]: owns the model list and the
/// currently filtered view of it.
pub struct LanguageModelPickerDelegate {
    /// Handle to the owning selector, used to emit `DismissEvent` on it when the picker is
    /// dismissed. Presumably weak because the selector owns the picker (avoids a cycle).
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Invoked with the confirmed model.
    on_model_changed: OnModelChanged,
    /// Every model from every registered provider, as collected by `LanguageModelSelector::all_models`.
    all_models: Vec<ModelInfo>,
    /// The subset of `all_models` currently displayed; recomputed by `update_matches`.
    filtered_models: Vec<ModelInfo>,
    /// Index into `filtered_models` of the highlighted row.
    selected_index: usize,
}
172
173impl PickerDelegate for LanguageModelPickerDelegate {
174    type ListItem = ListItem;
175
176    fn match_count(&self) -> usize {
177        self.filtered_models.len()
178    }
179
180    fn selected_index(&self) -> usize {
181        self.selected_index
182    }
183
184    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
185        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
186        cx.notify();
187    }
188
189    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
190        "Select a model...".into()
191    }
192
193    fn update_matches(
194        &mut self,
195        query: String,
196        window: &mut Window,
197        cx: &mut Context<Picker<Self>>,
198    ) -> Task<()> {
199        let all_models = self.all_models.clone();
200
201        let llm_registry = LanguageModelRegistry::global(cx);
202
203        let configured_providers = llm_registry
204            .read(cx)
205            .providers()
206            .iter()
207            .filter(|provider| provider.is_authenticated(cx))
208            .map(|provider| provider.id())
209            .collect::<Vec<_>>();
210
211        cx.spawn_in(window, |this, mut cx| async move {
212            let filtered_models = cx
213                .background_executor()
214                .spawn(async move {
215                    let displayed_models = if configured_providers.is_empty() {
216                        all_models
217                    } else {
218                        all_models
219                            .into_iter()
220                            .filter(|model_info| {
221                                configured_providers.contains(&model_info.model.provider_id())
222                            })
223                            .collect::<Vec<_>>()
224                    };
225
226                    if query.is_empty() {
227                        displayed_models
228                    } else {
229                        displayed_models
230                            .into_iter()
231                            .filter(|model_info| {
232                                model_info
233                                    .model
234                                    .name()
235                                    .0
236                                    .to_lowercase()
237                                    .contains(&query.to_lowercase())
238                            })
239                            .collect()
240                    }
241                })
242                .await;
243
244            this.update_in(&mut cx, |this, window, cx| {
245                this.delegate.filtered_models = filtered_models;
246                this.delegate.set_selected_index(0, window, cx);
247                cx.notify();
248            })
249            .ok();
250        })
251    }
252
253    fn confirm(&mut self, _secondary: bool, _: &mut Window, cx: &mut Context<Picker<Self>>) {
254        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
255            let model = model_info.model.clone();
256            (self.on_model_changed)(model.clone(), cx);
257
258            cx.emit(DismissEvent);
259        }
260    }
261
262    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
263        self.language_model_selector
264            .update(cx, |_this, cx| cx.emit(DismissEvent))
265            .ok();
266    }
267
268    fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
269        let configured_models_count = LanguageModelRegistry::global(cx)
270            .read(cx)
271            .providers()
272            .iter()
273            .filter(|provider| provider.is_authenticated(cx))
274            .count();
275
276        if configured_models_count > 0 {
277            Some(
278                Label::new("Configured Models")
279                    .size(LabelSize::Small)
280                    .color(Color::Muted)
281                    .mt_1()
282                    .mb_0p5()
283                    .ml_3()
284                    .into_any_element(),
285            )
286        } else {
287            None
288        }
289    }
290
291    fn render_match(
292        &self,
293        ix: usize,
294        selected: bool,
295        _: &mut Window,
296        cx: &mut Context<Picker<Self>>,
297    ) -> Option<Self::ListItem> {
298        use feature_flags::FeatureFlagAppExt;
299        let show_badges = cx.has_flag::<ZedPro>();
300
301        let model_info = self.filtered_models.get(ix)?;
302        let provider_name: String = model_info.model.provider_name().0.clone().into();
303
304        let active_provider_id = LanguageModelRegistry::read_global(cx)
305            .active_provider()
306            .map(|m| m.id());
307
308        let active_model_id = LanguageModelRegistry::read_global(cx)
309            .active_model()
310            .map(|m| m.id());
311
312        let is_selected = Some(model_info.model.provider_id()) == active_provider_id
313            && Some(model_info.model.id()) == active_model_id;
314
315        Some(
316            ListItem::new(ix)
317                .inset(true)
318                .spacing(ListItemSpacing::Sparse)
319                .toggle_state(selected)
320                .start_slot(
321                    div().pr_0p5().child(
322                        Icon::new(model_info.icon)
323                            .color(Color::Muted)
324                            .size(IconSize::Medium),
325                    ),
326                )
327                .child(
328                    h_flex()
329                        .w_full()
330                        .items_center()
331                        .gap_1p5()
332                        .min_w(px(200.))
333                        .child(Label::new(model_info.model.name().0.clone()))
334                        .child(
335                            h_flex()
336                                .gap_0p5()
337                                .child(
338                                    Label::new(provider_name)
339                                        .size(LabelSize::XSmall)
340                                        .color(Color::Muted),
341                                )
342                                .children(match model_info.availability {
343                                    LanguageModelAvailability::Public => None,
344                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
345                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
346                                        show_badges.then(|| {
347                                            Label::new("Pro")
348                                                .size(LabelSize::XSmall)
349                                                .color(Color::Muted)
350                                        })
351                                    }
352                                }),
353                        ),
354                )
355                .end_slot(div().when(is_selected, |this| {
356                    this.child(
357                        Icon::new(IconName::Check)
358                            .color(Color::Accent)
359                            .size(IconSize::Small),
360                    )
361                })),
362        )
363    }
364
365    fn render_footer(
366        &self,
367        _: &mut Window,
368        cx: &mut Context<Picker<Self>>,
369    ) -> Option<gpui::AnyElement> {
370        use feature_flags::FeatureFlagAppExt;
371
372        let plan = proto::Plan::ZedPro;
373        let is_trial = false;
374
375        Some(
376            h_flex()
377                .w_full()
378                .border_t_1()
379                .border_color(cx.theme().colors().border_variant)
380                .p_1()
381                .gap_4()
382                .justify_between()
383                .when(cx.has_flag::<ZedPro>(), |this| {
384                    this.child(match plan {
385                        // Already a Zed Pro subscriber
386                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
387                            .icon(IconName::ZedAssistant)
388                            .icon_size(IconSize::Small)
389                            .icon_color(Color::Muted)
390                            .icon_position(IconPosition::Start)
391                            .on_click(|_, window, cx| {
392                                window
393                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
394                            }),
395                        // Free user
396                        Plan::Free => Button::new(
397                            "try-pro",
398                            if is_trial {
399                                "Upgrade to Pro"
400                            } else {
401                                "Try Pro"
402                            },
403                        )
404                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
405                    })
406                })
407                .child(
408                    Button::new("configure", "Configure")
409                        .icon(IconName::Settings)
410                        .icon_size(IconSize::Small)
411                        .icon_color(Color::Muted)
412                        .icon_position(IconPosition::Start)
413                        .on_click(|_, window, cx| {
414                            window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
415                        }),
416                )
417                .into_any(),
418        )
419    }
420}