language_model_selector.rs

  1use std::sync::Arc;
  2
  3use feature_flags::ZedPro;
  4use gpui::{
  5    Action, AnyElement, AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Task,
  6    View, WeakView,
  7};
  8use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
  9use picker::{Picker, PickerDelegate};
 10use proto::Plan;
 11use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 12use workspace::ShowConfiguration;
 13
 14const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 15
 16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &AppContext) + 'static>;
 17
 18pub struct LanguageModelSelector {
 19    picker: View<Picker<LanguageModelPickerDelegate>>,
 20}
 21
 22impl LanguageModelSelector {
 23    pub fn new(
 24        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &AppContext) + 'static,
 25        cx: &mut ViewContext<Self>,
 26    ) -> Self {
 27        let on_model_changed = Arc::new(on_model_changed);
 28
 29        let all_models = LanguageModelRegistry::global(cx)
 30            .read(cx)
 31            .providers()
 32            .iter()
 33            .flat_map(|provider| {
 34                let icon = provider.icon();
 35
 36                provider.provided_models(cx).into_iter().map(move |model| {
 37                    let model = model.clone();
 38                    let icon = model.icon().unwrap_or(icon);
 39
 40                    ModelInfo {
 41                        model: model.clone(),
 42                        icon,
 43                        availability: model.availability(),
 44                    }
 45                })
 46            })
 47            .collect::<Vec<_>>();
 48
 49        let delegate = LanguageModelPickerDelegate {
 50            language_model_selector: cx.view().downgrade(),
 51            on_model_changed: on_model_changed.clone(),
 52            all_models: all_models.clone(),
 53            filtered_models: all_models,
 54            selected_index: 0,
 55        };
 56
 57        let picker =
 58            cx.new_view(|cx| Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into())));
 59
 60        LanguageModelSelector { picker }
 61    }
 62}
 63
 64impl EventEmitter<DismissEvent> for LanguageModelSelector {}
 65
 66impl FocusableView for LanguageModelSelector {
 67    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
 68        self.picker.focus_handle(cx)
 69    }
 70}
 71
 72impl Render for LanguageModelSelector {
 73    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
 74        self.picker.clone()
 75    }
 76}
 77
 78#[derive(IntoElement)]
 79pub struct LanguageModelSelectorPopoverMenu<T>
 80where
 81    T: PopoverTrigger,
 82{
 83    language_model_selector: View<LanguageModelSelector>,
 84    trigger: T,
 85    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
 86}
 87
 88impl<T: PopoverTrigger> LanguageModelSelectorPopoverMenu<T> {
 89    pub fn new(language_model_selector: View<LanguageModelSelector>, trigger: T) -> Self {
 90        Self {
 91            language_model_selector,
 92            trigger,
 93            handle: None,
 94        }
 95    }
 96
 97    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
 98        self.handle = Some(handle);
 99        self
100    }
101}
102
103impl<T: PopoverTrigger> RenderOnce for LanguageModelSelectorPopoverMenu<T> {
104    fn render(self, _cx: &mut WindowContext) -> impl IntoElement {
105        let language_model_selector = self.language_model_selector.clone();
106
107        PopoverMenu::new("model-switcher")
108            .menu(move |_cx| Some(language_model_selector.clone()))
109            .trigger(self.trigger)
110            .attach(gpui::Corner::BottomLeft)
111            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
112    }
113}
114
115#[derive(Clone)]
116struct ModelInfo {
117    model: Arc<dyn LanguageModel>,
118    icon: IconName,
119    availability: LanguageModelAvailability,
120}
121
122pub struct LanguageModelPickerDelegate {
123    language_model_selector: WeakView<LanguageModelSelector>,
124    on_model_changed: OnModelChanged,
125    all_models: Vec<ModelInfo>,
126    filtered_models: Vec<ModelInfo>,
127    selected_index: usize,
128}
129
130impl PickerDelegate for LanguageModelPickerDelegate {
131    type ListItem = ListItem;
132
133    fn match_count(&self) -> usize {
134        self.filtered_models.len()
135    }
136
137    fn selected_index(&self) -> usize {
138        self.selected_index
139    }
140
141    fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
142        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
143        cx.notify();
144    }
145
146    fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
147        "Select a model...".into()
148    }
149
150    fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
151        let all_models = self.all_models.clone();
152
153        let llm_registry = LanguageModelRegistry::global(cx);
154
155        let configured_models: Vec<_> = llm_registry
156            .read(cx)
157            .providers()
158            .iter()
159            .filter(|provider| provider.is_authenticated(cx))
160            .map(|provider| provider.id())
161            .collect();
162
163        cx.spawn(|this, mut cx| async move {
164            let filtered_models = cx
165                .background_executor()
166                .spawn(async move {
167                    let displayed_models = if configured_models.is_empty() {
168                        all_models
169                    } else {
170                        all_models
171                            .into_iter()
172                            .filter(|model_info| {
173                                configured_models.contains(&model_info.model.provider_id())
174                            })
175                            .collect::<Vec<_>>()
176                    };
177
178                    if query.is_empty() {
179                        displayed_models
180                    } else {
181                        displayed_models
182                            .into_iter()
183                            .filter(|model_info| {
184                                model_info
185                                    .model
186                                    .name()
187                                    .0
188                                    .to_lowercase()
189                                    .contains(&query.to_lowercase())
190                            })
191                            .collect()
192                    }
193                })
194                .await;
195
196            this.update(&mut cx, |this, cx| {
197                this.delegate.filtered_models = filtered_models;
198                this.delegate.set_selected_index(0, cx);
199                cx.notify();
200            })
201            .ok();
202        })
203    }
204
205    fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
206        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
207            let model = model_info.model.clone();
208            (self.on_model_changed)(model.clone(), cx);
209
210            cx.emit(DismissEvent);
211        }
212    }
213
214    fn dismissed(&mut self, cx: &mut ViewContext<Picker<Self>>) {
215        self.language_model_selector
216            .update(cx, |_this, cx| cx.emit(DismissEvent))
217            .ok();
218    }
219
220    fn render_header(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<AnyElement> {
221        let configured_models_count = LanguageModelRegistry::global(cx)
222            .read(cx)
223            .providers()
224            .iter()
225            .filter(|provider| provider.is_authenticated(cx))
226            .count();
227
228        if configured_models_count > 0 {
229            Some(
230                Label::new("Configured Models")
231                    .size(LabelSize::Small)
232                    .color(Color::Muted)
233                    .mt_1()
234                    .mb_0p5()
235                    .ml_3()
236                    .into_any_element(),
237            )
238        } else {
239            None
240        }
241    }
242
243    fn render_match(
244        &self,
245        ix: usize,
246        selected: bool,
247        cx: &mut ViewContext<Picker<Self>>,
248    ) -> Option<Self::ListItem> {
249        use feature_flags::FeatureFlagAppExt;
250        let show_badges = cx.has_flag::<ZedPro>();
251
252        let model_info = self.filtered_models.get(ix)?;
253        let provider_name: String = model_info.model.provider_name().0.clone().into();
254
255        let active_provider_id = LanguageModelRegistry::read_global(cx)
256            .active_provider()
257            .map(|m| m.id());
258
259        let active_model_id = LanguageModelRegistry::read_global(cx)
260            .active_model()
261            .map(|m| m.id());
262
263        let is_selected = Some(model_info.model.provider_id()) == active_provider_id
264            && Some(model_info.model.id()) == active_model_id;
265
266        Some(
267            ListItem::new(ix)
268                .inset(true)
269                .spacing(ListItemSpacing::Sparse)
270                .toggle_state(selected)
271                .start_slot(
272                    div().pr_0p5().child(
273                        Icon::new(model_info.icon)
274                            .color(Color::Muted)
275                            .size(IconSize::Medium),
276                    ),
277                )
278                .child(
279                    h_flex()
280                        .w_full()
281                        .items_center()
282                        .gap_1p5()
283                        .min_w(px(200.))
284                        .child(Label::new(model_info.model.name().0.clone()))
285                        .child(
286                            h_flex()
287                                .gap_0p5()
288                                .child(
289                                    Label::new(provider_name)
290                                        .size(LabelSize::XSmall)
291                                        .color(Color::Muted),
292                                )
293                                .children(match model_info.availability {
294                                    LanguageModelAvailability::Public => None,
295                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
296                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
297                                        show_badges.then(|| {
298                                            Label::new("Pro")
299                                                .size(LabelSize::XSmall)
300                                                .color(Color::Muted)
301                                        })
302                                    }
303                                }),
304                        ),
305                )
306                .end_slot(div().when(is_selected, |this| {
307                    this.child(
308                        Icon::new(IconName::Check)
309                            .color(Color::Accent)
310                            .size(IconSize::Small),
311                    )
312                })),
313        )
314    }
315
316    fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
317        use feature_flags::FeatureFlagAppExt;
318
319        let plan = proto::Plan::ZedPro;
320        let is_trial = false;
321
322        Some(
323            h_flex()
324                .w_full()
325                .border_t_1()
326                .border_color(cx.theme().colors().border_variant)
327                .p_1()
328                .gap_4()
329                .justify_between()
330                .when(cx.has_flag::<ZedPro>(), |this| {
331                    this.child(match plan {
332                        // Already a Zed Pro subscriber
333                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
334                            .icon(IconName::ZedAssistant)
335                            .icon_size(IconSize::Small)
336                            .icon_color(Color::Muted)
337                            .icon_position(IconPosition::Start)
338                            .on_click(|_, cx| {
339                                cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
340                            }),
341                        // Free user
342                        Plan::Free => Button::new(
343                            "try-pro",
344                            if is_trial {
345                                "Upgrade to Pro"
346                            } else {
347                                "Try Pro"
348                            },
349                        )
350                        .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
351                    })
352                })
353                .child(
354                    Button::new("configure", "Configure")
355                        .icon(IconName::Settings)
356                        .icon_size(IconSize::Small)
357                        .icon_color(Color::Muted)
358                        .icon_position(IconPosition::Start)
359                        .on_click(|_, cx| {
360                            cx.dispatch_action(ShowConfiguration.boxed_clone());
361                        }),
362                )
363                .into_any(),
364        )
365    }
366}