model_selector.rs

  1use feature_flags::ZedPro;
  2use gpui::DismissEvent;
  3use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
  4use proto::Plan;
  5
  6use std::sync::Arc;
  7use ui::ListItemSpacing;
  8
  9use crate::assistant_settings::AssistantSettings;
 10use crate::ShowConfiguration;
 11use fs::Fs;
 12use gpui::Action;
 13use gpui::SharedString;
 14use gpui::Task;
 15use picker::{Picker, PickerDelegate};
 16use settings::update_settings_file;
 17use ui::{prelude::*, ListItem, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 18
/// Page opened by the footer's "Try Pro" / "Upgrade to Pro" button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 20
/// Element that pairs `trigger` with a popover menu containing a searchable
/// picker over the models offered by the registered language-model providers.
#[derive(IntoElement)]
pub struct ModelSelector<T: PopoverTrigger> {
    /// Optional popover menu handle, set via `with_handle`.
    handle: Option<PopoverMenuHandle<Picker<ModelPickerDelegate>>>,
    /// Filesystem handed to the picker delegate, which uses it to write the
    /// chosen model into the settings file.
    fs: Arc<dyn Fs>,
    /// Element the popover menu is attached to.
    trigger: T,
    /// Optional informational text, set via `with_info_text`.
    // NOTE(review): nothing in this file reads this field — confirm whether it
    // is meant to be rendered somewhere (e.g. in the picker).
    info_text: Option<SharedString>,
}
 28
/// [`PickerDelegate`] that backs the model selector's popover list.
pub struct ModelPickerDelegate {
    /// Filesystem used by `confirm` to persist the chosen model via
    /// `update_settings_file`.
    fs: Arc<dyn Fs>,
    /// Every model from every provider, collected when the selector rendered.
    all_models: Vec<ModelInfo>,
    /// Subset of `all_models` whose names match the current query.
    filtered_models: Vec<ModelInfo>,
    /// Index into `filtered_models` of the highlighted row.
    selected_index: usize,
}
 35
/// One row of the model picker: a model plus what is needed to render it.
#[derive(Clone)]
struct ModelInfo {
    /// The language model this row represents.
    model: Arc<dyn LanguageModel>,
    /// Icon of the provider offering `model`, shown in the row's start slot.
    provider_icon: IconName,
    /// Availability of `model`. `RequiresPlan(Plan::ZedPro)` rows get a "Pro"
    /// badge when the `ZedPro` feature flag is enabled.
    availability: LanguageModelAvailability,
    /// Whether this row is the selected model; selected rows show a check mark.
    is_selected: bool,
}
 43
 44impl<T: PopoverTrigger> ModelSelector<T> {
 45    pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
 46        ModelSelector {
 47            handle: None,
 48            fs,
 49            trigger,
 50            info_text: None,
 51        }
 52    }
 53
 54    pub fn with_handle(mut self, handle: PopoverMenuHandle<Picker<ModelPickerDelegate>>) -> Self {
 55        self.handle = Some(handle);
 56        self
 57    }
 58
 59    pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
 60        self.info_text = Some(text.into());
 61        self
 62    }
 63}
 64
 65impl PickerDelegate for ModelPickerDelegate {
 66    type ListItem = ListItem;
 67
 68    fn match_count(&self) -> usize {
 69        self.filtered_models.len()
 70    }
 71
 72    fn selected_index(&self) -> usize {
 73        self.selected_index
 74    }
 75
 76    fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
 77        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
 78        cx.notify();
 79    }
 80
 81    fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
 82        "Select a model...".into()
 83    }
 84
 85    fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
 86        let all_models = self.all_models.clone();
 87        cx.spawn(|this, mut cx| async move {
 88            let filtered_models = cx
 89                .background_executor()
 90                .spawn(async move {
 91                    if query.is_empty() {
 92                        all_models
 93                    } else {
 94                        all_models
 95                            .into_iter()
 96                            .filter(|model_info| {
 97                                model_info
 98                                    .model
 99                                    .name()
100                                    .0
101                                    .to_lowercase()
102                                    .contains(&query.to_lowercase())
103                            })
104                            .collect()
105                    }
106                })
107                .await;
108
109            this.update(&mut cx, |this, cx| {
110                this.delegate.filtered_models = filtered_models;
111                this.delegate.set_selected_index(0, cx);
112                cx.notify();
113            })
114            .ok();
115        })
116    }
117
118    fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
119        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
120            let model = model_info.model.clone();
121            update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings, _| {
122                settings.set_model(model.clone())
123            });
124
125            // Update the selection status
126            let selected_model_id = model_info.model.id();
127            let selected_provider_id = model_info.model.provider_id();
128            for model in &mut self.all_models {
129                model.is_selected = model.model.id() == selected_model_id
130                    && model.model.provider_id() == selected_provider_id;
131            }
132            for model in &mut self.filtered_models {
133                model.is_selected = model.model.id() == selected_model_id
134                    && model.model.provider_id() == selected_provider_id;
135            }
136
137            cx.emit(DismissEvent);
138        }
139    }
140
141    fn dismissed(&mut self, _cx: &mut ViewContext<Picker<Self>>) {}
142
143    fn render_match(
144        &self,
145        ix: usize,
146        selected: bool,
147        cx: &mut ViewContext<Picker<Self>>,
148    ) -> Option<Self::ListItem> {
149        use feature_flags::FeatureFlagAppExt;
150        let model_info = self.filtered_models.get(ix)?;
151        let show_badges = cx.has_flag::<ZedPro>();
152        Some(
153            ListItem::new(ix)
154                .inset(true)
155                .spacing(ListItemSpacing::Sparse)
156                .selected(selected)
157                .start_slot(
158                    div().pr_1().child(
159                        Icon::new(model_info.provider_icon)
160                            .color(Color::Muted)
161                            .size(IconSize::Medium),
162                    ),
163                )
164                .child(
165                    h_flex()
166                        .w_full()
167                        .justify_between()
168                        .font_buffer(cx)
169                        .min_w(px(200.))
170                        .child(
171                            h_flex()
172                                .gap_2()
173                                .child(Label::new(model_info.model.name().0.clone()))
174                                .children(match model_info.availability {
175                                    LanguageModelAvailability::Public => None,
176                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
177                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
178                                        show_badges.then(|| {
179                                            Label::new("Pro")
180                                                .size(LabelSize::XSmall)
181                                                .color(Color::Muted)
182                                        })
183                                    }
184                                }),
185                        )
186                        .child(div().when(model_info.is_selected, |this| {
187                            this.child(
188                                Icon::new(IconName::Check)
189                                    .color(Color::Accent)
190                                    .size(IconSize::Small),
191                            )
192                        })),
193                ),
194        )
195    }
196
197    fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
198        use feature_flags::FeatureFlagAppExt;
199
200        let plan = proto::Plan::ZedPro;
201        let is_trial = false;
202
203        Some(
204            h_flex()
205                .w_full()
206                .border_t_1()
207                .border_color(cx.theme().colors().border)
208                .p_1()
209                .gap_4()
210                .justify_between()
211                .when(cx.has_flag::<ZedPro>(), |this| {
212                    this.child(match plan {
213                        // Already a zed pro subscriber
214                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
215                            .icon(IconName::ZedAssistant)
216                            .icon_size(IconSize::Small)
217                            .icon_color(Color::Muted)
218                            .icon_position(IconPosition::Start)
219                            .on_click(|_, cx| {
220                                cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
221                            }),
222                        // Free user
223                        Plan::Free => Button::new(
224                            "try-pro",
225                            if is_trial {
226                                "Upgrade to Pro"
227                            } else {
228                                "Try Pro"
229                            },
230                        )
231                        .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
232                    })
233                })
234                .child(
235                    Button::new("configure", "Configure")
236                        .icon(IconName::Settings)
237                        .icon_size(IconSize::Small)
238                        .icon_color(Color::Muted)
239                        .icon_position(IconPosition::Start)
240                        .on_click(|_, cx| {
241                            cx.dispatch_action(ShowConfiguration.boxed_clone());
242                        }),
243                )
244                .into_any(),
245        )
246    }
247}
248
249impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
250    fn render(self, cx: &mut WindowContext) -> impl IntoElement {
251        let selected_provider = LanguageModelRegistry::read_global(cx)
252            .active_provider()
253            .map(|m| m.id());
254        let selected_model = LanguageModelRegistry::read_global(cx)
255            .active_model()
256            .map(|m| m.id());
257
258        let all_models = LanguageModelRegistry::global(cx)
259            .read(cx)
260            .providers()
261            .iter()
262            .flat_map(|provider| {
263                let provider_id = provider.id();
264                let provider_icon = provider.icon();
265                let selected_model = selected_model.clone();
266                let selected_provider = selected_provider.clone();
267
268                provider.provided_models(cx).into_iter().map(move |model| {
269                    let model = model.clone();
270
271                    ModelInfo {
272                        model: model.clone(),
273                        provider_icon,
274                        availability: model.availability(),
275                        is_selected: selected_model.as_ref() == Some(&model.id())
276                            && selected_provider.as_ref() == Some(&provider_id),
277                    }
278                })
279            })
280            .collect::<Vec<_>>();
281
282        let delegate = ModelPickerDelegate {
283            fs: self.fs.clone(),
284            all_models: all_models.clone(),
285            filtered_models: all_models,
286            selected_index: 0,
287        };
288
289        let picker_view = cx.new_view(|cx| {
290            let picker = Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()));
291            picker
292        });
293
294        PopoverMenu::new("model-switcher")
295            .menu(move |_cx| Some(picker_view.clone()))
296            .trigger(self.trigger)
297            .attach(gpui::AnchorCorner::BottomLeft)
298    }
299}