language_model_selector.rs

  1use std::sync::Arc;
  2
  3use feature_flags::ZedPro;
  4use gpui::{Action, AnyElement, AppContext, DismissEvent, SharedString, Task};
  5use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
  6use picker::{Picker, PickerDelegate};
  7use proto::Plan;
  8use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
  9use workspace::ShowConfiguration;
 10
 11const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 12
 13type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &AppContext) + 'static>;
 14
 15#[derive(IntoElement)]
 16pub struct LanguageModelSelector<T: PopoverTrigger> {
 17    handle: Option<PopoverMenuHandle<Picker<LanguageModelPickerDelegate>>>,
 18    on_model_changed: OnModelChanged,
 19    trigger: T,
 20    info_text: Option<SharedString>,
 21}
 22
 23pub struct LanguageModelPickerDelegate {
 24    on_model_changed: OnModelChanged,
 25    all_models: Vec<ModelInfo>,
 26    filtered_models: Vec<ModelInfo>,
 27    selected_index: usize,
 28}
 29
 30#[derive(Clone)]
 31struct ModelInfo {
 32    model: Arc<dyn LanguageModel>,
 33    icon: IconName,
 34    availability: LanguageModelAvailability,
 35    is_selected: bool,
 36}
 37
 38impl<T: PopoverTrigger> LanguageModelSelector<T> {
 39    pub fn new(
 40        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &AppContext) + 'static,
 41        trigger: T,
 42    ) -> Self {
 43        LanguageModelSelector {
 44            handle: None,
 45            on_model_changed: Arc::new(on_model_changed),
 46            trigger,
 47            info_text: None,
 48        }
 49    }
 50
 51    pub fn with_handle(
 52        mut self,
 53        handle: PopoverMenuHandle<Picker<LanguageModelPickerDelegate>>,
 54    ) -> Self {
 55        self.handle = Some(handle);
 56        self
 57    }
 58
 59    pub fn info_text(mut self, text: impl Into<SharedString>) -> Self {
 60        self.info_text = Some(text.into());
 61        self
 62    }
 63}
 64
 65impl PickerDelegate for LanguageModelPickerDelegate {
 66    type ListItem = ListItem;
 67
 68    fn match_count(&self) -> usize {
 69        self.filtered_models.len()
 70    }
 71
 72    fn selected_index(&self) -> usize {
 73        self.selected_index
 74    }
 75
 76    fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
 77        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
 78        cx.notify();
 79    }
 80
 81    fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
 82        "Select a model...".into()
 83    }
 84
 85    fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
 86        let all_models = self.all_models.clone();
 87
 88        let llm_registry = LanguageModelRegistry::global(cx);
 89
 90        let configured_models: Vec<_> = llm_registry
 91            .read(cx)
 92            .providers()
 93            .iter()
 94            .filter(|provider| provider.is_authenticated(cx))
 95            .map(|provider| provider.id())
 96            .collect();
 97
 98        cx.spawn(|this, mut cx| async move {
 99            let filtered_models = cx
100                .background_executor()
101                .spawn(async move {
102                    let displayed_models = if configured_models.is_empty() {
103                        all_models
104                    } else {
105                        all_models
106                            .into_iter()
107                            .filter(|model_info| {
108                                configured_models.contains(&model_info.model.provider_id())
109                            })
110                            .collect::<Vec<_>>()
111                    };
112
113                    if query.is_empty() {
114                        displayed_models
115                    } else {
116                        displayed_models
117                            .into_iter()
118                            .filter(|model_info| {
119                                model_info
120                                    .model
121                                    .name()
122                                    .0
123                                    .to_lowercase()
124                                    .contains(&query.to_lowercase())
125                            })
126                            .collect()
127                    }
128                })
129                .await;
130
131            this.update(&mut cx, |this, cx| {
132                this.delegate.filtered_models = filtered_models;
133                this.delegate.set_selected_index(0, cx);
134                cx.notify();
135            })
136            .ok();
137        })
138    }
139
140    fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
141        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
142            let model = model_info.model.clone();
143            (self.on_model_changed)(model.clone(), cx);
144
145            // Update the selection status
146            let selected_model_id = model_info.model.id();
147            let selected_provider_id = model_info.model.provider_id();
148            for model in &mut self.all_models {
149                model.is_selected = model.model.id() == selected_model_id
150                    && model.model.provider_id() == selected_provider_id;
151            }
152            for model in &mut self.filtered_models {
153                model.is_selected = model.model.id() == selected_model_id
154                    && model.model.provider_id() == selected_provider_id;
155            }
156
157            cx.emit(DismissEvent);
158        }
159    }
160
161    fn dismissed(&mut self, _cx: &mut ViewContext<Picker<Self>>) {}
162
163    fn render_header(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<AnyElement> {
164        let configured_models_count = LanguageModelRegistry::global(cx)
165            .read(cx)
166            .providers()
167            .iter()
168            .filter(|provider| provider.is_authenticated(cx))
169            .count();
170
171        if configured_models_count > 0 {
172            Some(
173                Label::new("Configured Models")
174                    .size(LabelSize::Small)
175                    .color(Color::Muted)
176                    .mt_1()
177                    .mb_0p5()
178                    .ml_3()
179                    .into_any_element(),
180            )
181        } else {
182            None
183        }
184    }
185
186    fn render_match(
187        &self,
188        ix: usize,
189        selected: bool,
190        cx: &mut ViewContext<Picker<Self>>,
191    ) -> Option<Self::ListItem> {
192        use feature_flags::FeatureFlagAppExt;
193        let show_badges = cx.has_flag::<ZedPro>();
194
195        let model_info = self.filtered_models.get(ix)?;
196        let provider_name: String = model_info.model.provider_name().0.clone().into();
197
198        Some(
199            ListItem::new(ix)
200                .inset(true)
201                .spacing(ListItemSpacing::Sparse)
202                .selected(selected)
203                .start_slot(
204                    div().pr_0p5().child(
205                        Icon::new(model_info.icon)
206                            .color(Color::Muted)
207                            .size(IconSize::Medium),
208                    ),
209                )
210                .child(
211                    h_flex()
212                        .w_full()
213                        .items_center()
214                        .gap_1p5()
215                        .min_w(px(200.))
216                        .child(Label::new(model_info.model.name().0.clone()))
217                        .child(
218                            h_flex()
219                                .gap_0p5()
220                                .child(
221                                    Label::new(provider_name)
222                                        .size(LabelSize::XSmall)
223                                        .color(Color::Muted),
224                                )
225                                .children(match model_info.availability {
226                                    LanguageModelAvailability::Public => None,
227                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
228                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
229                                        show_badges.then(|| {
230                                            Label::new("Pro")
231                                                .size(LabelSize::XSmall)
232                                                .color(Color::Muted)
233                                        })
234                                    }
235                                }),
236                        ),
237                )
238                .end_slot(div().when(model_info.is_selected, |this| {
239                    this.child(
240                        Icon::new(IconName::Check)
241                            .color(Color::Accent)
242                            .size(IconSize::Small),
243                    )
244                })),
245        )
246    }
247
248    fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
249        use feature_flags::FeatureFlagAppExt;
250
251        let plan = proto::Plan::ZedPro;
252        let is_trial = false;
253
254        Some(
255            h_flex()
256                .w_full()
257                .border_t_1()
258                .border_color(cx.theme().colors().border_variant)
259                .p_1()
260                .gap_4()
261                .justify_between()
262                .when(cx.has_flag::<ZedPro>(), |this| {
263                    this.child(match plan {
264                        // Already a Zed Pro subscriber
265                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
266                            .icon(IconName::ZedAssistant)
267                            .icon_size(IconSize::Small)
268                            .icon_color(Color::Muted)
269                            .icon_position(IconPosition::Start)
270                            .on_click(|_, cx| {
271                                cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
272                            }),
273                        // Free user
274                        Plan::Free => Button::new(
275                            "try-pro",
276                            if is_trial {
277                                "Upgrade to Pro"
278                            } else {
279                                "Try Pro"
280                            },
281                        )
282                        .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
283                    })
284                })
285                .child(
286                    Button::new("configure", "Configure")
287                        .icon(IconName::Settings)
288                        .icon_size(IconSize::Small)
289                        .icon_color(Color::Muted)
290                        .icon_position(IconPosition::Start)
291                        .on_click(|_, cx| {
292                            cx.dispatch_action(ShowConfiguration.boxed_clone());
293                        }),
294                )
295                .into_any(),
296        )
297    }
298}
299
300impl<T: PopoverTrigger> RenderOnce for LanguageModelSelector<T> {
301    fn render(self, cx: &mut WindowContext) -> impl IntoElement {
302        let selected_provider = LanguageModelRegistry::read_global(cx)
303            .active_provider()
304            .map(|m| m.id());
305
306        let selected_model = LanguageModelRegistry::read_global(cx)
307            .active_model()
308            .map(|m| m.id());
309
310        let all_models = LanguageModelRegistry::global(cx)
311            .read(cx)
312            .providers()
313            .iter()
314            .flat_map(|provider| {
315                let provider_id = provider.id();
316                let icon = provider.icon();
317                let selected_model = selected_model.clone();
318                let selected_provider = selected_provider.clone();
319
320                provider.provided_models(cx).into_iter().map(move |model| {
321                    let model = model.clone();
322                    let icon = model.icon().unwrap_or(icon);
323
324                    ModelInfo {
325                        model: model.clone(),
326                        icon,
327                        availability: model.availability(),
328                        is_selected: selected_model.as_ref() == Some(&model.id())
329                            && selected_provider.as_ref() == Some(&provider_id),
330                    }
331                })
332            })
333            .collect::<Vec<_>>();
334
335        let delegate = LanguageModelPickerDelegate {
336            on_model_changed: self.on_model_changed.clone(),
337            all_models: all_models.clone(),
338            filtered_models: all_models,
339            selected_index: 0,
340        };
341
342        let picker_view = cx.new_view(|cx| {
343            let picker = Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()));
344            picker
345        });
346
347        PopoverMenu::new("model-switcher")
348            .menu(move |_cx| Some(picker_view.clone()))
349            .trigger(self.trigger)
350            .attach(gpui::AnchorCorner::BottomLeft)
351            .when_some(self.handle, |menu, handle| menu.with_handle(handle))
352    }
353}