model_selector.rs

  1use feature_flags::ZedPro;
  2use gpui::Action;
  3use gpui::DismissEvent;
  4
  5use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
  6use proto::Plan;
  7use workspace::ShowConfiguration;
  8
  9use std::sync::Arc;
 10use ui::ListItemSpacing;
 11
 12use crate::assistant_settings::AssistantSettings;
 13use fs::Fs;
 14use gpui::SharedString;
 15use gpui::Task;
 16use picker::{Picker, PickerDelegate};
 17use settings::update_settings_file;
 18use ui::{prelude::*, ListItem, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 19
/// Page opened by the "Try Pro" / "Upgrade to Pro" button in the picker footer.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 21
/// Popover element that, when `trigger` is activated, shows a searchable list of
/// the language models offered by every registered provider.
#[derive(IntoElement)]
pub struct ModelSelector<T: PopoverTrigger> {
    /// Optional handle forwarded to the underlying `PopoverMenu` (set via `with_handle`).
    handle: Option<PopoverMenuHandle<Picker<ModelPickerDelegate>>>,
    /// Filesystem handed to the picker delegate, which uses it to write the chosen
    /// model into the settings file.
    fs: Arc<dyn Fs>,
    /// Element that opens the popover.
    trigger: T,
    /// Optional informational text (set via `with_info_text`).
    // NOTE(review): this field is never read in this file — confirm whether it is
    // meant to be rendered somewhere.
    info_text: Option<SharedString>,
}
 29
/// `PickerDelegate` that lists language models, filters them by name and saves the
/// confirmed choice to the assistant settings.
pub struct ModelPickerDelegate {
    /// Filesystem used to persist the confirmed model to the settings file.
    fs: Arc<dyn Fs>,
    /// Every model offered by the registered providers.
    all_models: Vec<ModelInfo>,
    /// Subset of `all_models` matching the current query; this is the list the
    /// picker displays and indexes into.
    filtered_models: Vec<ModelInfo>,
    /// Index into `filtered_models` of the highlighted row.
    selected_index: usize,
}
 36
/// Data for one row of the model picker.
#[derive(Clone)]
struct ModelInfo {
    /// The model this row represents.
    model: Arc<dyn LanguageModel>,
    /// Icon shown at the start of the row: the model's own icon when it has one,
    /// otherwise its provider's icon.
    icon: IconName,
    /// Plan requirement of the model; `RequiresPlan(Plan::ZedPro)` shows a "Pro" badge.
    availability: LanguageModelAvailability,
    /// Whether this row is the currently active model; rendered as a check mark.
    is_selected: bool,
}
 44
 45impl<T: PopoverTrigger> ModelSelector<T> {
 46    pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
 47        ModelSelector {
 48            handle: None,
 49            fs,
 50            trigger,
 51            info_text: None,
 52        }
 53    }
 54
 55    pub fn with_handle(mut self, handle: PopoverMenuHandle<Picker<ModelPickerDelegate>>) -> Self {
 56        self.handle = Some(handle);
 57        self
 58    }
 59
 60    pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
 61        self.info_text = Some(text.into());
 62        self
 63    }
 64}
 65
 66impl PickerDelegate for ModelPickerDelegate {
 67    type ListItem = ListItem;
 68
 69    fn match_count(&self) -> usize {
 70        self.filtered_models.len()
 71    }
 72
 73    fn selected_index(&self) -> usize {
 74        self.selected_index
 75    }
 76
 77    fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
 78        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
 79        cx.notify();
 80    }
 81
 82    fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
 83        "Select a model...".into()
 84    }
 85
 86    fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
 87        let all_models = self.all_models.clone();
 88        cx.spawn(|this, mut cx| async move {
 89            let filtered_models = cx
 90                .background_executor()
 91                .spawn(async move {
 92                    if query.is_empty() {
 93                        all_models
 94                    } else {
 95                        all_models
 96                            .into_iter()
 97                            .filter(|model_info| {
 98                                model_info
 99                                    .model
100                                    .name()
101                                    .0
102                                    .to_lowercase()
103                                    .contains(&query.to_lowercase())
104                            })
105                            .collect()
106                    }
107                })
108                .await;
109
110            this.update(&mut cx, |this, cx| {
111                this.delegate.filtered_models = filtered_models;
112                this.delegate.set_selected_index(0, cx);
113                cx.notify();
114            })
115            .ok();
116        })
117    }
118
119    fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
120        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
121            let model = model_info.model.clone();
122            update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings, _| {
123                settings.set_model(model.clone())
124            });
125
126            // Update the selection status
127            let selected_model_id = model_info.model.id();
128            let selected_provider_id = model_info.model.provider_id();
129            for model in &mut self.all_models {
130                model.is_selected = model.model.id() == selected_model_id
131                    && model.model.provider_id() == selected_provider_id;
132            }
133            for model in &mut self.filtered_models {
134                model.is_selected = model.model.id() == selected_model_id
135                    && model.model.provider_id() == selected_provider_id;
136            }
137
138            cx.emit(DismissEvent);
139        }
140    }
141
142    fn dismissed(&mut self, _cx: &mut ViewContext<Picker<Self>>) {}
143
144    fn render_match(
145        &self,
146        ix: usize,
147        selected: bool,
148        cx: &mut ViewContext<Picker<Self>>,
149    ) -> Option<Self::ListItem> {
150        use feature_flags::FeatureFlagAppExt;
151        let model_info = self.filtered_models.get(ix)?;
152        let show_badges = cx.has_flag::<ZedPro>();
153        let provider_name: String = model_info.model.provider_name().0.into();
154
155        Some(
156            ListItem::new(ix)
157                .inset(true)
158                .spacing(ListItemSpacing::Sparse)
159                .selected(selected)
160                .start_slot(
161                    div().pr_0p5().child(
162                        Icon::new(model_info.icon)
163                            .color(Color::Muted)
164                            .size(IconSize::Medium),
165                    ),
166                )
167                .child(
168                    h_flex().w_full().justify_between().min_w(px(200.)).child(
169                        h_flex()
170                            .gap_1p5()
171                            .child(Label::new(model_info.model.name().0.clone()))
172                            .child(
173                                Label::new(provider_name)
174                                    .size(LabelSize::XSmall)
175                                    .color(Color::Muted),
176                            )
177                            .children(match model_info.availability {
178                                LanguageModelAvailability::Public => None,
179                                LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
180                                LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
181                                    show_badges.then(|| {
182                                        Label::new("Pro")
183                                            .size(LabelSize::XSmall)
184                                            .color(Color::Muted)
185                                    })
186                                }
187                            }),
188                    ),
189                )
190                .end_slot(div().when(model_info.is_selected, |this| {
191                    this.child(
192                        Icon::new(IconName::Check)
193                            .color(Color::Accent)
194                            .size(IconSize::Small),
195                    )
196                })),
197        )
198    }
199
200    fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
201        use feature_flags::FeatureFlagAppExt;
202
203        let plan = proto::Plan::ZedPro;
204        let is_trial = false;
205
206        Some(
207            h_flex()
208                .w_full()
209                .border_t_1()
210                .border_color(cx.theme().colors().border_variant)
211                .p_1()
212                .gap_4()
213                .justify_between()
214                .when(cx.has_flag::<ZedPro>(), |this| {
215                    this.child(match plan {
216                        // Already a zed pro subscriber
217                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
218                            .icon(IconName::ZedAssistant)
219                            .icon_size(IconSize::Small)
220                            .icon_color(Color::Muted)
221                            .icon_position(IconPosition::Start)
222                            .on_click(|_, cx| {
223                                cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
224                            }),
225                        // Free user
226                        Plan::Free => Button::new(
227                            "try-pro",
228                            if is_trial {
229                                "Upgrade to Pro"
230                            } else {
231                                "Try Pro"
232                            },
233                        )
234                        .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
235                    })
236                })
237                .child(
238                    Button::new("configure", "Configure")
239                        .icon(IconName::Settings)
240                        .icon_size(IconSize::Small)
241                        .icon_color(Color::Muted)
242                        .icon_position(IconPosition::Start)
243                        .on_click(|_, cx| {
244                            cx.dispatch_action(ShowConfiguration.boxed_clone());
245                        }),
246                )
247                .into_any(),
248        )
249    }
250}
251
252impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
253    fn render(self, cx: &mut WindowContext) -> impl IntoElement {
254        let selected_provider = LanguageModelRegistry::read_global(cx)
255            .active_provider()
256            .map(|m| m.id());
257        let selected_model = LanguageModelRegistry::read_global(cx)
258            .active_model()
259            .map(|m| m.id());
260
261        let all_models = LanguageModelRegistry::global(cx)
262            .read(cx)
263            .providers()
264            .iter()
265            .flat_map(|provider| {
266                let provider_id = provider.id();
267                let icon = provider.icon();
268                let selected_model = selected_model.clone();
269                let selected_provider = selected_provider.clone();
270
271                provider.provided_models(cx).into_iter().map(move |model| {
272                    let model = model.clone();
273                    let icon = model.icon().unwrap_or(icon);
274
275                    ModelInfo {
276                        model: model.clone(),
277                        icon,
278                        availability: model.availability(),
279                        is_selected: selected_model.as_ref() == Some(&model.id())
280                            && selected_provider.as_ref() == Some(&provider_id),
281                    }
282                })
283            })
284            .collect::<Vec<_>>();
285
286        let delegate = ModelPickerDelegate {
287            fs: self.fs.clone(),
288            all_models: all_models.clone(),
289            filtered_models: all_models,
290            selected_index: 0,
291        };
292
293        let picker_view = cx.new_view(|cx| {
294            let picker = Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()));
295            picker
296        });
297
298        PopoverMenu::new("model-switcher")
299            .menu(move |_cx| Some(picker_view.clone()))
300            .trigger(self.trigger)
301            .attach(gpui::AnchorCorner::BottomLeft)
302            .when_some(self.handle, |menu, handle| menu.with_handle(handle))
303    }
304}